In [32]:
import filecmp
import json
import os
import shutil

# Root directory of the generated databases, under $DATA.
# os.path.join lets $DATA be given with or without a trailing separator (the previous plain
# string concatenation silently produced a wrong path without it); the final "" component keeps
# the trailing separator on every directory path, as before.
db_root_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data", "db3.0.0")
db_S_dir = os.path.join(db_root_dir, "S", "")
db_L_dir = os.path.join(db_root_dir, "L", "")
db_M_dir = os.path.join(db_root_dir, "M", "")
db_patterns_dir = os.path.join(db_root_dir, "patterns", "")
os.makedirs(db_S_dir, exist_ok=True)
os.makedirs(db_L_dir, exist_ok=True)
os.makedirs(db_M_dir, exist_ok=True)
os.makedirs(db_patterns_dir, exist_ok=True)

# Dataset sizes, passed by generate_all_datasets to create_dataset_generic_rule_extract_sample.
test_datasets_sizes = 500
valid_datasets_sizes = 500
full_datasets_pos_samples_nb = 7500
full_datasets_neg_samples_nb = 7500
sample_nb_per_class = 100

In [33]:
# --- Generation budget ---
NBGEN_full_per_size = 200_000   # images generated for each full-size database (S, M, L)
NBGEN_patterns = 1_000          # images generated for the patterns database

# --- Grid divisions (number of cells along x and along y) ---
# Full images
X_DIVISIONS_S, Y_DIVISIONS_S = 10, 10
X_DIVISIONS_M, Y_DIVISIONS_M = 12, 12
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15
# Patterns
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# --- Image sizes in pixels ---
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability that any given grid cell receives a geometrical shape.
SHAPE_PROB = 0.5

# Available shapes and colors.
SHAPES = ['circle', 'square', 'triangle']
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [34]:
import numpy as np
from xaipatimg.datagen.dbimg import generate_uuid
import os

def content_key(content):
    """Return a canonical, hashable signature of an image ``content`` list.

    ``str(content)`` is not a reliable key. Freshly generated entries hold ``numpy.str_``
    values and ``(i, j)`` tuples, while entries loaded back from disk may hold plain ``str``
    values and ``[i, j]`` lists (TODO confirm against ``load_db``). The same layout could
    therefore stringify differently. JSON encoding with sorted keys maps both forms to the
    same string.
    """
    return json.dumps(content, sort_keys=True, default=str)


def generate_db(db, x_divisions, y_divisions, to_generate, img_size):
    """Add ``to_generate`` new random images with unique content to ``db``.

    Each cell of the ``x_divisions`` x ``y_divisions`` grid independently receives, with
    probability ``SHAPE_PROB``, a shape drawn uniformly from ``SHAPES`` with a color drawn
    uniformly from ``COLORS``. A drawn layout is discarded and redrawn when it is identical to
    one already present in ``db``, whether that entry existed before the call or was generated
    earlier in it.

    Args:
        db: dict mapping image id -> image description; updated in place.
        x_divisions: number of grid cells along x.
        y_divisions: number of grid cells along y.
        to_generate: number of new unique images to add.
        img_size: image size in pixels, stored in every new entry.

    Returns:
        The same ``db`` object, for convenience.

    Note:
        Loops until ``to_generate`` unique layouts have been found. It never terminates if
        fewer unique layouts remain possible for the given grid.
    """
    # Seed with the entries already in db so that extending a loaded DB cannot
    # re-create layouts it already contains.
    seen = {content_key(entry["content"]) for entry in db.values()}

    duplicate_count = 0
    while to_generate > 0:
        content = []
        for i in range(x_divisions):
            for j in range(y_divisions):
                if np.random.random() < SHAPE_PROB:
                    content.append({
                        "shape": np.random.choice(SHAPES),
                        "pos": (i, j),
                        "color": np.random.choice(COLORS)
                    })

        key = content_key(content)
        if key in seen:
            duplicate_count += 1
            continue

        imgid = generate_uuid()
        db[imgid] = {
            "path": os.path.join("img", imgid + ".png"),
            "division" : (x_divisions, y_divisions),
            "size": img_size,
            "content": content
        }

        seen.add(key)
        to_generate -= 1

    print("unique generated in DB : " + str(len(db)))
    print("duplicates avoided : " + str(duplicate_count))
    return db

In [35]:
import tqdm

def check_for_duplicates(db):
    """Count the entries of ``db`` whose content duplicates an earlier entry, and print the count.

    Contents are compared through a JSON encoding with sorted keys. This detects the same
    layout whether an entry stores positions as tuples or lists, and strings as ``str`` or
    ``numpy.str_``, which can differ between freshly generated entries and entries loaded
    from disk. A plain ``str(content)`` comparison misses such cross-duplicates.

    Args:
        db: dict mapping image id -> image description with a ``"content"`` field.

    Returns:
        The number of duplicate entries (also printed, as before), so callers can assert on it.
    """
    seen = set()
    nb_duplicates = 0

    for entry in tqdm.tqdm(db.values()):
        key = json.dumps(entry["content"], sort_keys=True, default=str)
        if key in seen:
            nb_duplicates += 1
        else:
            seen.add(key)

    print(nb_duplicates)
    return nb_duplicates

In [36]:
from xaipatimg.datagen.dbimg import load_db


In [37]:
from xaipatimg.datagen.dbimg import load_db

# db_S = load_db(db_S_dir)
# db_L = load_db(db_L_dir)
# db_M = load_db(db_M_dir)

In [38]:

# Patterns database. Its entries are filtered below into pattern_3sym_2col_keys, and
# generate_all_datasets injects the selected pattern contents into the pattern-finding rules.
db_patterns = load_db(db_patterns_dir)



#### Generate full DB

In [39]:

# db_L = generate_db(db_L, X_DIVISIONS_L, Y_DIVISIONS_L, NBGEN_full_per_size, img_size)

In [40]:

# check_for_duplicates(db_L)

In [41]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
# gen_img_and_save_db(db_L, db_L_dir, overwrite=True, n_jobs=20)
# db_L = None

In [42]:
# db_S = generate_db(db_S, X_DIVISIONS_S, Y_DIVISIONS_S, NBGEN_full_per_size, img_size)


In [43]:
# check_for_duplicates(db_S)

In [44]:
# gen_img_and_save_db(db_S, db_S_dir, overwrite=True, n_jobs=20)
# db_S = None

In [45]:
# db_M = generate_db(db_M, X_DIVISIONS_M, Y_DIVISIONS_M, NBGEN_full_per_size, img_size)


In [46]:
# check_for_duplicates(db_M)

In [47]:
# gen_img_and_save_db(db_M, db_M_dir, overwrite=True, n_jobs=20)
# db_M = None

#### Generate DB of patterns

In [48]:
#db_patterns = generate_db(db_patterns, X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS, NBGEN_patterns, img_size_patterns)

In [49]:
#check_for_duplicates(db_patterns)

In [50]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
#gen_img_and_save_db(db_patterns, db_patterns_dir, overwrite=True, draw_coordinates=False, n_jobs=20)

## Interface prototype v5

In [51]:
# Output directory of the prototype-v5 datasets built from the L database.
datasets_path_L = os.path.join(db_L_dir, "datasets", "01_protov5")

In [52]:
# Output directory of the prototype-v5 datasets built from the S database.
datasets_path_S = os.path.join(db_S_dir, "datasets", "01_protov5")

In [53]:
# Output directory of the prototype-v5 datasets built from the M database.
datasets_path_M = os.path.join(db_M_dir, "datasets", "01_protov5")

In [54]:
def _has_3_symbols_3_shapes_2_colors(content):
    """True when ``content`` holds exactly 3 symbols with 3 distinct shapes and exactly 2 distinct colors."""
    if len(content) != 3:
        return False
    distinct_colors = {symbol["color"] for symbol in content}
    distinct_shapes = {symbol["shape"] for symbol in content}
    return len(distinct_colors) == 2 and len(distinct_shapes) == 3


# Keys of the patterns made of 3 symbols of 3 different shapes and 2 different colors, in the
# iteration order of db_patterns. The rules select their pattern by index into this list.
pattern_3sym_2col_keys = [
    key
    for key, pattern in db_patterns.items()
    if _has_3_symbols_3_shapes_2_colors(pattern["content"])
]

In [55]:
# Number of candidate patterns (3 symbols, 3 distinct shapes, 2 distinct colors).
print(len(pattern_3sym_2col_keys))

63


In [56]:
from xaipatimg.datagen.gendataset import create_dataset_generic_rule_extract_sample
import tqdm

def generate_all_datasets(rules_data, db_dir, datasets_path):
    """Build the train/test/valid datasets of every rule in ``rules_data``.

    For each rule description, calls ``create_dataset_generic_rule_extract_sample`` on the
    database in ``db_dir``. It passes ``datasets_path`` as the datasets directory,
    ``<name>_train.csv`` / ``<name>_test.csv`` / ``<name>_valid.csv`` as CSV names, and
    ``<datasets_path>/<name>_train`` as the sample path. Sizes come from the notebook-level
    constants ``test_datasets_sizes``, ``valid_datasets_sizes``,
    ``full_datasets_pos_samples_nb``, ``full_datasets_neg_samples_nb`` and ``sample_nb_per_class``.

    When a rule carries a ``"pattern_id"``, the content of that pattern in ``db_patterns`` is
    injected into the rule's ``gen_kwargs`` as ``"pattern_content"``. The rule dict is
    modified in place.
    """
    for rule in tqdm.tqdm(rules_data):
        name = rule["name"]
        gen_kwargs = rule["gen_kwargs"]
        if "pattern_id" in rule:
            gen_kwargs["pattern_content"] = db_patterns[rule["pattern_id"]]["content"]

        create_dataset_generic_rule_extract_sample(
            db_dir,
            datasets_dir_path=datasets_path,
            csv_name_train=name + "_train.csv",
            csv_name_test=name + "_test.csv",
            csv_name_valid=name + "_valid.csv",
            test_size=test_datasets_sizes,
            valid_size=valid_datasets_sizes,
            dataset_pos_samples_nb=full_datasets_pos_samples_nb,
            dataset_neg_samples_nb=full_datasets_neg_samples_nb,
            sample_path=os.path.join(datasets_path, f"{name}_train"),
            sample_nb_per_class=sample_nb_per_class,
            generic_rule_fun=rule["gen_fun"],
            **gen_kwargs,
        )

## DB L datasets generation

In [57]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# Fields shared by every rule of the L database.
_rule_defaults_L = {
    "target_acc": 0.85,
    "shown_acc": 0.85,
    "samples_interface": 13,
    "pos_llm_scaffold": "",
    "neg_llm_scaffold": "",
}

rules_data_L = [
    {"name": "hard1_find_pattern_rot",
     "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
     "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                    "y_division_full": Y_DIVISIONS_L,
                    "x_division_pattern": X_DIVISIONS_PATTERNS,
                    "y_division_pattern": Y_DIVISIONS_PATTERNS,
                    "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?",
     **_rule_defaults_L,
     "pattern_id": pattern_3sym_2col_keys[0]},

    {"name": "hard2_blue_circle_N",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_L,
                    "y_division": Y_DIVISIONS_L,
                    "shape": "circle",
                    "color": "#0C90C0",
                    "N": 16,
                    "restrict_plus_minus_1": True},
     "question": "Does the number of blue circles equal to 16 in the image?",
     **_rule_defaults_L},

    {"name": "hard2bis_blue_circle_N_norestrict",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_L,
                    "y_division": Y_DIVISIONS_L,
                    "shape": "circle",
                    "color": "#0C90C0",
                    "N": 16,
                    "restrict_plus_minus_1": False},
     "question": "Does the number of blue circles equal to 16 in the image?",
     **_rule_defaults_L},

    {"name": "hard3_find_pattern_rot",
     "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
     "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                    "y_division_full": Y_DIVISIONS_L,
                    "x_division_pattern": X_DIVISIONS_PATTERNS,
                    "y_division_pattern": Y_DIVISIONS_PATTERNS,
                    "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?",
     **_rule_defaults_L,
     "pattern_id": pattern_3sym_2col_keys[1]},

    {"name": "hard4_purple_square_N",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_L,
                    "y_division": Y_DIVISIONS_L,
                    "shape": "square",
                    "color": "#A33E9A",
                    "N": 16,
                    "restrict_plus_minus_1": True},
     "question": "Does the number of purple squares equal to 16 in the image?",
     **_rule_defaults_L},

    {"name": "hard4bis_purple_square_N_norestrict",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_L,
                    "y_division": Y_DIVISIONS_L,
                    "shape": "square",
                    "color": "#A33E9A",
                    "N": 16,
                    "restrict_plus_minus_1": False},
     "question": "Does the number of purple squares equal to 16 in the image?",
     **_rule_defaults_L},
]

In [58]:
#generate_all_datasets(rules_data_L, db_L_dir, datasets_path_L)

## DB S datasets generation

In [59]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# Fields shared by every rule of the S database.
_rule_defaults_S = {
    "target_acc": 0.85,
    "shown_acc": 0.85,
    "samples_interface": 13,
    "pos_llm_scaffold": "",
    "neg_llm_scaffold": "",
}

rules_data_S = [
    {"name": "easy1_find_pattern_rot",
     "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
     "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
                    "y_division_full": Y_DIVISIONS_S,
                    "x_division_pattern": X_DIVISIONS_PATTERNS,
                    "y_division_pattern": Y_DIVISIONS_PATTERNS,
                    "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?",
     **_rule_defaults_S,
     "pattern_id": pattern_3sym_2col_keys[2]},

    {"name": "easy2_yellow_triangle_N",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_S,
                    "y_division": Y_DIVISIONS_S,
                    "shape": "triangle",
                    "color": "#E0B000",
                    "N": 8,
                    "restrict_plus_minus_1": True},
     "question": "Does the number of yellow triangles equal to 8 in the image?",
     **_rule_defaults_S},

    {"name": "easy2bis_yellow_triangle_N_norestrict",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_S,
                    "y_division": Y_DIVISIONS_S,
                    "shape": "triangle",
                    "color": "#E0B000",
                    "N": 8,
                    "restrict_plus_minus_1": False},
     "question": "Does the number of yellow triangles equal to 8 in the image?",
     **_rule_defaults_S},

    {"name": "easy3_find_pattern_rot",
     "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
     "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
                    "y_division_full": Y_DIVISIONS_S,
                    "x_division_pattern": X_DIVISIONS_PATTERNS,
                    "y_division_pattern": Y_DIVISIONS_PATTERNS,
                    "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?",
     **_rule_defaults_S,
     "pattern_id": pattern_3sym_2col_keys[3]},

    {"name": "easy4_purple_circle_N",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_S,
                    "y_division": Y_DIVISIONS_S,
                    "shape": "circle",
                    "color": "#A33E9A",
                    "N": 8,
                    "restrict_plus_minus_1": True},
     "question": "Does the number of purple circles equal to 8 in the image?",
     **_rule_defaults_S},

    {"name": "easy4bis_purple_circle_N_norestrict",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_S,
                    "y_division": Y_DIVISIONS_S,
                    "shape": "circle",
                    "color": "#A33E9A",
                    "N": 8,
                    "restrict_plus_minus_1": False},
     "question": "Does the number of purple circles equal to 8 in the image?",
     **_rule_defaults_S},
]

In [60]:
#generate_all_datasets(rules_data_S, db_S_dir, datasets_path_S)

## DB M datasets generation

In [61]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
# Rules for the M database. The M database is generated with X_DIVISIONS_M x Y_DIVISIONS_M grid
# divisions, so its rules must use the M divisions. The L divisions (X_DIVISIONS_L /
# Y_DIVISIONS_L) were a copy-paste leftover from rules_data_L.
rules_data_M = [

    # {"name": "med1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": {"x_division_full": X_DIVISIONS_M,
    #                 "y_division_full": Y_DIVISIONS_M,
    #                 "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                 "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                 "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[4]},
    #
    # {"name": "med2_yellow_square_N", "gen_fun": generic_rule_N_times_color_shape_exactly,
    #  "gen_kwargs": {"x_division": X_DIVISIONS_M,
    #                 "y_division": Y_DIVISIONS_M,
    #                 "shape": "square",
    #                 "color": "#E0B000",
    #                 "N": 11,
    #                 "restrict_plus_minus_1": True},
    #  "question": "Does the number of yellow squares equal to 11 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "med2bis_yellow_square_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly,
    #  "gen_kwargs": {"x_division": X_DIVISIONS_M,
    #                 "y_division": Y_DIVISIONS_M,
    #                 "shape": "square",
    #                 "color": "#E0B000",
    #                 "N": 11,
    #                 "restrict_plus_minus_1": False},
    #  "question": "Does the number of yellow squares equal to 11 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "med3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": {"x_division_full": X_DIVISIONS_M,
    #                 "y_division_full": Y_DIVISIONS_M,
    #                 "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                 "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                 "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[5]},
    #
    # {"name": "med4_blue_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly,
    #  "gen_kwargs": {"x_division": X_DIVISIONS_M,
    #                 "y_division": Y_DIVISIONS_M,
    #                 "shape": "triangle",
    #                 "color": "#0C90C0",
    #                 "N": 11,
    #                 "restrict_plus_minus_1": True},
    #  "question": "Does the number of blue triangles equal to 11 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    {"name": "med4bis_blue_triangle_N_norestrict",
     "gen_fun": generic_rule_N_times_color_shape_exactly,
     "gen_kwargs": {"x_division": X_DIVISIONS_M,
                    "y_division": Y_DIVISIONS_M,
                    "shape": "triangle",
                    "color": "#0C90C0",
                    "N": 11,
                    "restrict_plus_minus_1": False},
     "question": "Does the number of blue triangles equal to 11 in the image?",
     "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13,
     "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
]


In [62]:
# Build the train/test/valid datasets of every rule in rules_data_M from the M database.
generate_all_datasets(rules_data=rules_data_M, db_dir=db_M_dir, datasets_path=datasets_path_M)

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/200000 [00:00<?, ?it/s][A
  0%|          | 301/200000 [00:00<01:06, 3001.50it/s][A
  0%|          | 602/200000 [00:00<01:06, 2979.97it/s][A
  0%|          | 901/200000 [00:00<01:08, 2924.24it/s][A
  1%|          | 1204/200000 [00:00<01:07, 2963.40it/s][A
  1%|          | 1509/200000 [00:00<01:06, 2994.01it/s][A
  1%|          | 1813/200000 [00:00<01:05, 3009.28it/s][A
  1%|          | 2120/200000 [00:00<01:05, 3027.58it/s][A
  1%|          | 2423/200000 [00:00<01:05, 3016.11it/s][A
  1%|▏         | 2726/200000 [00:00<01:05, 3019.82it/s][A
  2%|▏         | 3031/200000 [00:01<01:05, 3026.78it/s][A
  2%|▏         | 3335/200000 [00:01<01:04, 3030.27it/s][A
  2%|▏         | 3639/200000 [00:01<01:04, 3032.08it/s][A
  2%|▏         | 3945/200000 [00:01<01:04, 3037.88it/s][A
  2%|▏         | 4249/200000 [00:01<01:04, 3037.33it/s][A
  2%|▏         | 4553/200000 [00:01<01:04, 3036.83it/s][A
  2%|▏         | 4857/200000 [00:0

Total number of positive instances found in database : 7500
Total number of negative instances found in database : 97693



0it [00:00, ?it/s][A
54it [00:00, 511.96it/s][A
14001it [00:00, 47677.48it/s]
100%|██████████| 1/1 [01:02<00:00, 62.23s/it]


In [63]:
import shutil

def setup_training_file_N_shape_color_rules(restrict_training_file, norestrict_training_file):
    """
    Replaces the training file of an "exactly N occurrences" rule with its unrestricted counterpart.

    The training file built with the +1/-1 restriction around the target number N contains only
    instances with N-1, N or N+1 occurrences of the symbol. This function archives it with a ``.old``
    suffix, then overwrites it with a copy of the training file built without that restriction. The
    reason is that the models (ResNet18) are unable to learn the task when the training instances
    only contain N-1, N or N+1 occurrences of the symbol. The test and validation datasets are left
    untouched.

    The operation is safe to re-run. When the ``.old`` archive already exists and
    ``restrict_training_file`` already has the same content as ``norestrict_training_file``, the
    swap has already been done, and the archive is kept as is. This prevents a second run from
    overwriting the backup of the original restricted file with the unrestricted one. A
    regenerated restricted file (content differing from the unrestricted one) is archived again.

    Raises:
        FileNotFoundError: if a file that has to be read does not exist.
    """
    archive_file = restrict_training_file + ".old"
    already_swapped = (os.path.exists(archive_file)
                       and filecmp.cmp(restrict_training_file, norestrict_training_file, shallow=False))

    if not already_swapped:
        shutil.copyfile(restrict_training_file, archive_file)
    shutil.copyfile(norestrict_training_file, restrict_training_file)



Swap the training datasets of the "exactly N occurrences" tasks (does the image contain exactly N symbols of a given shape and color?). After the swap, they contain instances with any number of occurrences of the symbol, not only N-1, N or N+1. With the restricted training set, learning does not converge.

In [64]:
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_M, "med2_yellow_square_N_train.csv"),
#                                         os.path.join(datasets_path_M, "med2bis_yellow_square_N_norestrict_train.csv"))

# Swap the med4 training file of the M database for its unrestricted counterpart.
# NOTE(review): med4_blue_triangle_N is commented out of rules_data_M, so its train CSV must
# already exist from a previous run.
setup_training_file_N_shape_color_rules(
    restrict_training_file=os.path.join(datasets_path_M, "med4_blue_triangle_N_train.csv"),
    norestrict_training_file=os.path.join(datasets_path_M, "med4bis_blue_triangle_N_norestrict_train.csv"),
)

# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_S, "easy2_yellow_triangle_N_train.csv"),
#                 os.path.join(datasets_path_S, "easy2bis_yellow_triangle_N_norestrict_train.csv"))
#
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_S, "easy4_purple_circle_N_train.csv"),
#                 os.path.join(datasets_path_S, "easy4bis_purple_circle_N_norestrict_train.csv"))
#
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_L, "hard2_blue_circle_N_train.csv"),
#                 os.path.join(datasets_path_L, "hard2bis_blue_circle_N_norestrict_train.csv"))
#
# setup_training_file_N_shape_color_rules(os.path.join(datasets_path_L, "hard4_purple_square_N_train.csv"),
#                 os.path.join(datasets_path_L, "hard4bis_purple_square_N_norestrict_train.csv"))