In [1]:
import os
import shutil

# Root of the db3.2.0 image databases. os.path.join (instead of string concatenation)
# builds a valid path whether or not $DATA ends with a path separator.
DB_ROOT_DIR = os.path.join(os.environ["DATA"], "PatImgXAI_data", "db3.2.0")

# One directory per grid size plus one for the patterns. The trailing "" keeps the
# trailing separator these paths always had.
db_S_dir = os.path.join(DB_ROOT_DIR, "S", "")
db_L_dir = os.path.join(DB_ROOT_DIR, "L", "")
db_M_dir = os.path.join(DB_ROOT_DIR, "M", "")
db_XS_dir = os.path.join(DB_ROOT_DIR, "XS", "")
db_patterns_dir = os.path.join(DB_ROOT_DIR, "patterns", "")
for _db_dir in (db_S_dir, db_L_dir, db_M_dir, db_XS_dir, db_patterns_dir):
    os.makedirs(_db_dir, exist_ok=True)

# Sizes of the test and validation splits of every generated dataset
test_datasets_sizes = 500
valid_datasets_sizes = 500
# Number of positive / negative instances requested for each full dataset
full_datasets_pos_samples_nb = 25000
full_datasets_neg_samples_nb = 25000

# Smaller values for quick runs:
# full_datasets_pos_samples_nb = 1000
# full_datasets_neg_samples_nb = 1000
# Per-class sample count forwarded to the dataset generation helper (sample_nb_per_class)
sample_nb_per_class = 100

# Number of parallel jobs passed as n_jobs to the image/dataset generation helpers
N_JOBS = 20

In [2]:
# Number of images generated
# How many images to generate: per full-size DB, and for the pattern DB
NBGEN_full_per_size = 5_000_000
NBGEN_patterns = 1_000

# Grid division (x, y) of the full images, one setting per DB size
X_DIVISIONS_XS = Y_DIVISIONS_XS = 5
X_DIVISIONS_S = Y_DIVISIONS_S = 9
X_DIVISIONS_M = Y_DIVISIONS_M = 12
X_DIVISIONS_L = Y_DIVISIONS_L = 15

# Grid division (x, y) of the patterns
X_DIVISIONS_PATTERNS = Y_DIVISIONS_PATTERNS = 2

# Image sizes in pixels
img_size = (700,) * 2
img_size_patterns = (300,) * 2

# Probability that a given grid cell receives a symbol
SHAPE_PROB = 0.5

# Symbol codes drawn uniformly at generation time
# (presumably circle/square/triangle and purple/yellow/blue — confirm against the image renderer)
SHAPES = list("cst")
COLORS = list("pyb")

In [3]:
from xaipatimg.datagen.dbimg import generate_uuid
import os
import numpy as np

def generate_db(db, x_divisions, y_divisions, to_generate, img_size,
                shape_prob=None, shapes=None, colors=None, progress_every=1000000):
    """Add ``to_generate`` random grid images to ``db``, skipping content already drawn in this call.

    Each grid cell (i, j) receives a symbol with probability ``shape_prob``. The symbol's shape and
    colour are drawn with ``np.random.choice`` from ``shapes`` / ``colors``. A candidate whose content
    was already produced during this call is discarded and counted as a duplicate.

    Args:
        db: mapping-like store (e.g. a JSONDB) supporting item assignment, ``len()`` and ``flush()``.
        x_divisions: number of grid positions along x.
        y_divisions: number of grid positions along y.
        to_generate: number of new entries to add.
        img_size: image size in pixels, stored as-is in every entry.
        shape_prob: per-cell symbol probability; defaults to the global ``SHAPE_PROB``.
        shapes: shape codes to draw from; defaults to the global ``SHAPES``.
        colors: colour codes to draw from; defaults to the global ``COLORS``.
        progress_every: print progress and flush ``db`` once each time the remaining count reaches
            a multiple of this value (must be positive).

    Returns:
        ``db``, mutated in place.

    Notes:
        - Only entries created in this call are deduplicated. Content already stored in ``db`` is
          not checked.
        - Uniqueness is tested on Python ``hash`` values to bound memory, so an (unlikely) hash
          collision would discard a genuinely new image.
        - The "unique generated in DB" line reports ``len(db)``, i.e. the total size of the DB.
    """
    shape_prob = SHAPE_PROB if shape_prob is None else shape_prob
    shapes = SHAPES if shapes is None else shapes
    colors = COLORS if colors is None else colors

    seen_hashes = set()
    duplicate_count = 0
    last_reported = None

    while to_generate > 0:
        # Report/flush only once per milestone. Without the guard, a duplicate redrawn while sitting
        # on a milestone would print again and flush the whole DB again.
        if to_generate % progress_every == 0 and to_generate != last_reported:
            print(f"{to_generate} to generate yet")
            db.flush()
            last_reported = to_generate

        content = []
        for i in range(x_divisions):
            for j in range(y_divisions):
                if np.random.random() < shape_prob:
                    content.append({
                        "shp": np.random.choice(shapes),
                        "pos": (i, j),
                        "col": np.random.choice(colors)
                    })

        hashed = hash(str(content))
        if hashed in seen_hashes:
            duplicate_count += 1
            continue
        seen_hashes.add(hashed)

        imgid = generate_uuid()
        db[imgid] = {
            "path": os.path.join("img", imgid + ".png"),
            "div": (x_divisions, y_divisions),
            "size": img_size,
            "cnt": content
        }
        to_generate -= 1

    db.flush()
    print("unique generated in DB : " + str(len(db)))
    print("duplicate avoided : " + str(duplicate_count))
    return db

In [4]:

def check_for_duplicates(db):
    """Count entries of ``db`` whose image description duplicates an earlier entry.

    Every field except ``"path"`` is compared. The path embeds each entry's unique id, so
    hashing the whole record made every entry look unique and no duplicate could ever be found.
    Entries are serialised with ``json.dumps`` so that in-memory entries (tuple positions) and
    entries reloaded from JSON (list positions) compare equal.

    Args:
        db: mapping of image id -> entry dict, as produced by ``generate_db``.

    Returns:
        int: number of possible duplicates. Detection uses hash values, so a collision could
        inflate the count slightly.
    """
    import json
    # Imported here because the notebook's own tqdm import lives in a later cell.
    try:
        from tqdm import tqdm as progress
    except ImportError:  # the progress bar is purely cosmetic
        def progress(iterable):
            return iterable

    seen_hashes = set()
    nb_possible_duplicates = 0

    for _, entry in progress(db.items()):
        description = {field: value for field, value in entry.items() if field != "path"}
        hashed = hash(json.dumps(description, sort_keys=True, default=str))
        if hashed in seen_hashes:
            nb_possible_duplicates += 1
        else:
            seen_hashes.add(hashed)

    print(f"Number of possible duplicate images : {nb_possible_duplicates}")
    return nb_possible_duplicates

In [5]:
from xaipatimg.datagen.jsondb import JSONDB

# Open the S-size image DB (db.json under db_S_dir); in the recorded run, loading its keys took ~12 min.
db_S = JSONDB(os.path.join(db_S_dir, "db.json"))


loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/S/db.json


2317500343it [11:47, 3273448.96it/s]


In [6]:
from xaipatimg.datagen.jsondb import JSONDB

# Open the XS-size image DB (db.json under db_XS_dir).
db_XS = JSONDB(os.path.join(db_XS_dir, "db.json"))

loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/XS/db.json


23317995it [00:07, 3036991.65it/s]


In [7]:
# Open the L- and M-size image DBs.
# NOTE(review): in the recorded run, loading db_L ran for over 13 minutes and was interrupted
# (KeyboardInterrupt). As a result neither db_L nor db_M was defined, which presumably explains
# the NameError in the L datasets cell. Loading each DB in its own cell would let one be retried
# without the other.
db_L = JSONDB(os.path.join(db_L_dir, "db.json"))
db_M = JSONDB(os.path.join(db_M_dir, "db.json"))

loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/L/db.json


2632441808it [13:36, 3222186.00it/s]

KeyboardInterrupt



In [6]:

# Open the pattern DB (db.json under db_patterns_dir).
# NOTE(review): the recorded output of this cell prints a db3.1.0 path while db_patterns_dir points
# to db3.2.0, so that output looks stale — re-run to refresh it.
db_patterns = JSONDB(os.path.join(db_patterns_dir, "db.json"))



loading keys for /home/docker/data/PatImgXAI_data/db3.1.0/patterns/db.json


47493it [00:00, 2194534.38it/s]


#### Generate full DB

In [7]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
# db_L = generate_db(db_L, X_DIVISIONS_L, Y_DIVISIONS_L, NBGEN_full_per_size, img_size)

In [8]:
# db_S = generate_db(db_S, X_DIVISIONS_S, Y_DIVISIONS_S, NBGEN_full_per_size, img_size)

In [9]:
# db_M = generate_db(db_M, X_DIVISIONS_M, Y_DIVISIONS_M, NBGEN_full_per_size, img_size)


In [33]:
# db_XS = generate_db(db_XS, X_DIVISIONS_XS, Y_DIVISIONS_XS, 150000, img_size)


21607768it [00:06, 3378214.81it/s]


unique generated in DB : 100000
duplicate avoided : 0


#### Generate DB of patterns

In [11]:
# db_patterns = generate_db(db_patterns, X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS, NBGEN_patterns, img_size_patterns)
# gen_img_and_save_db(db_patterns, db_patterns_dir, overwrite=True, draw_coordinates=False, empty_cell_as_question_mark=True, n_jobs=N_JOBS, do_save_db=False)


In [12]:
from xaipatimg.datagen.genimg import gen_img_and_save_db

In [13]:
# check_for_duplicates(db_L)
# check_for_duplicates(db_S)
# check_for_duplicates(db_M)
# check_for_duplicates(db_patterns)


## Interface prototype v5

In [14]:
# Output directory of the "01_expv1" datasets built from the L DB.
datasets_path_L = os.path.join(db_L_dir, "datasets", "01_expv1")

In [15]:
# Output directory of the "01_expv1" datasets built from the S DB.
datasets_path_S = os.path.join(db_S_dir, "datasets", "01_expv1")

In [34]:
# Output directories of the "01_expv1" datasets built from the M and XS DBs
datasets_path_M, datasets_path_XS = (
    os.path.join(size_db_dir, "datasets", "01_expv1") for size_db_dir in (db_M_dir, db_XS_dir)
)


In [35]:
import numpy as np
# Select the patterns with exactly 3 symbols that use 3 distinct shapes and 2 distinct colours,
# where the two same-coloured symbols are not on a diagonal of the 2x2 pattern grid.
# Empty cells hold "", which never equals a colour code, so only two placed symbols can match.
pattern_3sym_2col_keys = []
for pattern_key, pattern in db_patterns.items():
    symbols = pattern["cnt"]
    if len(symbols) != 3:
        continue

    grid_colors = np.full((2, 2), "", dtype="U100")
    for sym in symbols:
        r, c = sym["pos"][0], sym["pos"][1]
        grid_colors[r, c] = sym["col"]
    diagonal_clash = grid_colors[0, 0] == grid_colors[1, 1] or grid_colors[0, 1] == grid_colors[1, 0]

    n_colors = len({sym["col"] for sym in symbols})
    n_shapes = len({sym["shp"] for sym in symbols})
    if n_colors == 2 and n_shapes == 3 and not diagonal_clash:
        pattern_3sym_2col_keys.append(pattern_key)

In [36]:
print(len(pattern_3sym_2col_keys))
# The rule definitions in this notebook pick patterns up to index 7. Fail here with a clear
# message rather than with an IndexError inside one of the rules cells.
assert len(pattern_3sym_2col_keys) >= 8, (
    f"only {len(pattern_3sym_2col_keys)} qualifying patterns found, at least 8 are needed"
)

47


In [37]:
from xaipatimg.datagen.gendataset import create_dataset_generic_rule_extract_sample
import tqdm

def generate_all_datasets(rules_data, db_dir, db, datasets_path):
    """Create the train/test/valid CSV datasets for every rule in ``rules_data``.

    Args:
        rules_data: list of rule dicts with at least ``"name"``, ``"gen_fun"`` and ``"gen_kwargs"``.
            An optional ``"pattern_id"`` key selects an entry of ``db_patterns``; its ``"cnt"`` is
            forwarded to the rule as ``pattern_content``.
        db_dir: directory of the image DB.
        db: the loaded image DB.
        datasets_path: directory receiving the ``<name>_train/_test/_valid.csv`` files.

    Split sizes, sample counts and ``N_JOBS`` come from the configuration cell. The rule dicts are
    not modified: ``pattern_content`` is added to a copy of ``gen_kwargs``. The previous version
    wrote it into the caller's dict, leaving hidden state behind between runs.
    """
    for rule_line in tqdm.tqdm(rules_data):
        name = rule_line["name"]
        sample_path = os.path.join(datasets_path, f"{name}_train")

        gen_kwargs = dict(rule_line["gen_kwargs"])
        if "pattern_id" in rule_line:
            gen_kwargs["pattern_content"] = db_patterns[rule_line["pattern_id"]]["cnt"]

        create_dataset_generic_rule_extract_sample(db_dir, db, datasets_dir_path=datasets_path,
                                                   csv_name_train=name + "_train.csv",
                                                   csv_name_test=name + "_test.csv",
                                                   csv_name_valid=name + "_valid.csv",
                                                   test_size=test_datasets_sizes,
                                                   valid_size=valid_datasets_sizes,
                                                   dataset_pos_samples_nb=full_datasets_pos_samples_nb,
                                                   dataset_neg_samples_nb=full_datasets_neg_samples_nb,
                                                   sample_path=sample_path, sample_nb_per_class=sample_nb_per_class,
                                                   generic_rule_fun=rule_line["gen_fun"], n_jobs=N_JOBS,
                                                   extract_sample=False,
                                                   **gen_kwargs)

## DB L datasets generation

In [20]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# "Find the pattern (with rotations)" rules on the L grid, one per selected pattern.
rules_data_L = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_L,
            "y_division_full": Y_DIVISIONS_L,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 13,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    # (dataset name, index into pattern_3sym_2col_keys)
    for rule_name, pattern_idx in [("hard1_find_pattern_rot", 0), ("hard3_find_pattern_rot", 1)]
]

In [21]:
# Build the L datasets. Requires db_L from the DB loading cell; the recorded NameError presumably
# came from that load having been interrupted.
generate_all_datasets(rules_data_L, db_L_dir, db_L, datasets_path_L)

NameError: name 'db_L' is not defined

## DB S datasets generation

In [23]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# "Find the pattern (with rotations)" rules on the S grid, one per selected pattern.
rules_data_S = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_S,
            "y_division_full": Y_DIVISIONS_S,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 6,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    # (dataset name, index into pattern_3sym_2col_keys)
    for rule_name, pattern_idx in [("easy1_find_pattern_rot", 4), ("easy3_find_pattern_rot", 5)]
]

In [24]:
# Build the S datasets.
# NOTE(review): the recorded progress bar shows 1/1 rules, while rules_data_S defines 2 — that output
# looks older than the current rule list.
generate_all_datasets(rules_data_S, db_S_dir, db_S, datasets_path_S)

  0%|          | 0/1 [00:00<?, ?it/s]

Total number of positive instances found in database : 25000
Total number of negative instances found in database : 25000


100%|██████████| 1/1 [33:06<00:00, 1986.46s/it]


## DB M datasets generation

In [None]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
# "Find the pattern (with rotations)" rules on the M grid, one per selected pattern.
rules_data_M = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_M,
            "y_division_full": Y_DIVISIONS_M,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 13,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    # (dataset name, index into pattern_3sym_2col_keys)
    for rule_name, pattern_idx in [
        ("med1_find_pattern_rot", 2),
        ("med3_find_pattern_rot", 3),
        ("med5_find_pattern_rot", 6),
    ]
]


In [None]:
# Build the M datasets (this cell has no recorded execution yet). Requires db_M from the DB loading cell.
generate_all_datasets(rules_data_M, db_M_dir, db_M, datasets_path_M)

In [38]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# "Find the pattern (with rotations)" rule on the XS grid.
rules_data_XS = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_XS,
            "y_division_full": Y_DIVISIONS_XS,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.5,
        "shown_acc": 0.5,
        "samples_interface": 6,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    # (dataset name, index into pattern_3sym_2col_keys)
    for rule_name, pattern_idx in [("xeasy1_find_pattern_rot", 7)]
]

In [39]:
# generate_all_datasets(rules_data_XS, db_XS_dir, db_XS, datasets_path_XS)


100%|██████████| 1/1 [01:17<00:00, 77.93s/it]

Total number of positive instances found in database : 1000
Total number of negative instances found in database : 1000





In [None]:
import shutil

def setup_training_file_N_shape_color_rules(restrict_training_file, norestrict_training_file):
    """
    Replace a restricted training CSV by its unrestricted counterpart, keeping the original as an archive.

    The training file built with the +1/-1 restriction around the target number N is archived under
    ``<restrict_training_file>.old``. It is then overwritten with a copy of the training file built
    without that restriction. This is done because the models (ResNet18) are unable to learn the task
    when the only instances in the training dataset contain N-1, N or N+1 occurrences of the symbol.
    The test and validation datasets are left untouched.

    The archive is only written when it does not exist yet. Without this guard, a second call would
    overwrite the archived restricted file with the already-substituted unrestricted one, losing the
    original for good.

    Args:
        restrict_training_file: path of the restricted training CSV; replaced in place.
        norestrict_training_file: path of the unrestricted training CSV to copy over it.
    """
    archive_file = restrict_training_file + ".old"
    if not os.path.exists(archive_file):
        shutil.copyfile(restrict_training_file, archive_file)
    shutil.copyfile(norestrict_training_file, restrict_training_file)



In [None]:
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_M, "med2_yellow_square_N_train.csv"),
#                                         os.path.join(datasets_path_M, "med2bis_yellow_square_N_norestrict_train.csv"))
#
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_M, "med4_blue_triangle_N_train.csv"),
#                 os.path.join(datasets_path_M, "med4bis_blue_triangle_N_norestrict_train.csv"))
#
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_S, "easy2_yellow_triangle_N_train.csv"),
#                 os.path.join(datasets_path_S, "easy2bis_yellow_triangle_N_norestrict_train.csv"))
#
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_S, "easy4_purple_circle_N_train.csv"),
#                 os.path.join(datasets_path_S, "easy4bis_purple_circle_N_norestrict_train.csv"))
#
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_L, "hard2_blue_circle_N_train.csv"),
#                 os.path.join(datasets_path_L, "hard2bis_blue_circle_N_norestrict_train.csv"))
#
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_L, "hard4_purple_triangle_N_train.csv"),
#                 os.path.join(datasets_path_L, "hard4bis_purple_triangle_N_norestrict_train.csv"))
