In [1]:
import os
import shutil

# All DB directories live under $DATA/PatImgXAI_data/db3.0.0/. os.path.join inserts a
# separator only when needed, so the paths are correct whether or not $DATA ends with
# one (plain string concatenation yielded ".../dataPatImgXAI_data/..." when it did not).
# The trailing "" component keeps the trailing separator the original paths had.
db_root_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data", "db3.0.0")
db_S_dir = os.path.join(db_root_dir, "S", "")
db_L_dir = os.path.join(db_root_dir, "L", "")
db_M_dir = os.path.join(db_root_dir, "M", "")
db_patterns_dir = os.path.join(db_root_dir, "patterns", "")
for _db_dir in (db_S_dir, db_L_dir, db_M_dir, db_patterns_dir):
    os.makedirs(_db_dir, exist_ok=True)

# Parameters forwarded by generate_all_datasets to
# create_dataset_generic_rule_extract_sample (test_size, valid_size,
# dataset_pos_samples_nb, dataset_neg_samples_nb, sample_nb_per_class).
test_datasets_sizes = 500
valid_datasets_sizes = 500
full_datasets_pos_samples_nb = 7500
full_datasets_neg_samples_nb = 7500
sample_nb_per_class = 100

# Value passed as n_jobs to the image / dataset generation helpers.
N_JOBS = 20

In [2]:
# How many random images to generate: per full-size DB (S, M, L) and for the pattern DB.
NBGEN_full_per_size = 600_000
NBGEN_patterns = 1_000

# Grid divisions (number of cells along x and y) of the full images, one pair per DB size.
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15
X_DIVISIONS_S, Y_DIVISIONS_S = 10, 10
X_DIVISIONS_M, Y_DIVISIONS_M = 12, 12

# Grid divisions of the pattern images.
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# Image sizes in pixels: full images, then pattern images.
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability that a shape is drawn in any given grid cell.
SHAPE_PROB = 0.5

# Shapes and colors that can be drawn.
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [3]:
import numpy as np
from xaipatimg.datagen.dbimg import generate_uuid
import os

def generate_db(db, x_divisions, y_divisions, to_generate, img_size):
    """Add ``to_generate`` randomly generated images with unique content to ``db``.

    Every cell ``(i, j)`` of an ``x_divisions`` x ``y_divisions`` grid receives, with
    probability ``SHAPE_PROB``, a shape drawn from ``SHAPES`` in a color drawn from
    ``COLORS``. A drawn content that is already present in ``db`` -- including entries
    that were in ``db`` before this call (e.g. loaded with ``load_db``) -- is discarded
    and redrawn. Random draws happen in the same order as before, so a seeded run
    produces the same sequence of candidate contents.

    Args:
        db: dict mapping an image id to its description; filled in place.
        x_divisions: number of grid cells along the first ``pos`` coordinate.
        y_divisions: number of grid cells along the second ``pos`` coordinate.
        to_generate: number of new unique images to add.
        img_size: value stored as ``"size"`` in every new entry.

    Returns:
        The same ``db`` object.

    Raises:
        ValueError: if ``to_generate`` exceeds the number of distinct contents the
            grid can hold (the rejection loop would otherwise never terminate).
            This is a basic bound; it does not subtract entries already in ``db``.
    """

    def content_key(value):
        # Hashable canonical form of a content list: lists and tuples are treated
        # alike and dict key order is ignored. Entries read back from disk may store
        # ``pos`` as a list rather than a tuple (TODO confirm against load_db), which
        # made str(content) an unreliable identity.
        if isinstance(value, dict):
            return tuple(sorted((key, content_key(item)) for key, item in value.items()))
        if isinstance(value, (list, tuple)):
            return tuple(content_key(item) for item in value)
        return value

    # Each cell is either empty (possible iff SHAPE_PROB < 1) or holds one of
    # len(SHAPES) * len(COLORS) symbols (possible iff SHAPE_PROB > 0).
    states_per_cell = (1 if SHAPE_PROB < 1 else 0) + (len(SHAPES) * len(COLORS) if SHAPE_PROB > 0 else 0)
    max_distinct = states_per_cell ** (x_divisions * y_divisions)
    if to_generate > max_distinct:
        raise ValueError(
            f"Cannot generate {to_generate} unique images on a {x_divisions}x{y_divisions} "
            f"grid: at most {max_distinct} distinct contents exist."
        )

    # Seed the dedup set with what is already in the DB, not just this call's output.
    seen = {content_key(entry["content"]) for entry in db.values() if "content" in entry}

    duplicate_count = 0
    last_reported = None
    while to_generate > 0:
        # Report every 10000 images, once per value: a rejected duplicate leaves
        # to_generate unchanged and previously re-printed the same line.
        if to_generate % 10000 == 0 and to_generate != last_reported:
            print(f"{to_generate} to generate yet")
            last_reported = to_generate

        content = []
        for i in range(x_divisions):
            for j in range(y_divisions):
                if np.random.random() < SHAPE_PROB:
                    content.append({
                        "shape": np.random.choice(SHAPES),
                        "pos": (i, j),
                        "color": np.random.choice(COLORS)
                    })

        key = content_key(content)
        if key in seen:
            duplicate_count += 1
            continue

        imgid = generate_uuid()
        db[imgid] = {
            "path": os.path.join("img", imgid + ".png"),
            "division": (x_divisions, y_divisions),
            "size": img_size,
            "content": content
        }

        seen.add(key)
        to_generate -= 1

    print("unique generated in DB : " + str(len(db)))
    print("duplicates avoided : " + str(duplicate_count))
    return db

In [4]:
import tqdm

def check_for_duplicates(db):
    """Count the entries of ``db`` whose ``"content"`` duplicates an earlier entry.

    Contents are compared through a hashable canonical form in which lists and
    tuples are equivalent and dict key order is ignored. As a result, an entry whose
    ``pos`` is a list (e.g. read back from disk -- TODO confirm against load_db)
    matches a freshly generated entry whose ``pos`` is a tuple. Comparing
    ``str(content)`` missed such duplicates.

    Args:
        db: dict mapping an image id to a description containing ``"content"``.

    Returns:
        int: the number of duplicate entries. It is also printed, as before.
    """

    def content_key(value):
        # Recursively turn dicts/lists/tuples into comparable, hashable tuples.
        if isinstance(value, dict):
            return tuple(sorted((key, content_key(item)) for key, item in value.items()))
        if isinstance(value, (list, tuple)):
            return tuple(content_key(item) for item in value)
        return value

    seen = set()
    nb_duplicates = 0
    for entry in tqdm.tqdm(db.values()):
        key = content_key(entry["content"])
        if key in seen:
            nb_duplicates += 1
        else:
            seen.add(key)

    print(nb_duplicates)
    return nb_duplicates

In [None]:
from xaipatimg.datagen.dbimg import load_db


In [None]:
from xaipatimg.datagen.dbimg import load_db

# db_S = load_db(db_S_dir)
# db_L = load_db(db_L_dir)
# db_M = load_db(db_M_dir)

In [None]:

# Load the pattern DB stored in db_patterns_dir. Its entries are filtered by key in the
# pattern-selection cell and looked up by "pattern_id" in generate_all_datasets.
db_patterns = load_db(db_patterns_dir)



#### Generate full DB

In [8]:

# db_L = generate_db(db_L, X_DIVISIONS_L, Y_DIVISIONS_L, NBGEN_full_per_size, img_size)

In [9]:

# check_for_duplicates(db_L)

In [10]:
# from xaipatimg.datagen.genimg import gen_img_and_save_db
# gen_img_and_save_db(db_L, db_L_dir, overwrite=True, n_jobs=N_JOBS)
# db_L = None

In [11]:
# db_S = generate_db(db_S, X_DIVISIONS_S, Y_DIVISIONS_S, NBGEN_full_per_size, img_size)


In [12]:
# check_for_duplicates(db_S)

In [13]:
# gen_img_and_save_db(db_S, db_S_dir, overwrite=True, n_jobs=N_JOBS)
# db_S = None

In [14]:
# db_M = generate_db(db_M, X_DIVISIONS_M, Y_DIVISIONS_M, NBGEN_full_per_size, img_size)


In [15]:
# check_for_duplicates(db_M)

In [16]:
# gen_img_and_save_db(db_M, db_M_dir, overwrite=True, n_jobs=N_JOBS)
# db_M = None

#### Generate DB of patterns

In [17]:
# db_patterns = generate_db(db_patterns, X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS, NBGEN_patterns, img_size_patterns)

In [18]:
# check_for_duplicates(db_patterns)

In [19]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
# gen_img_and_save_db(db_patterns, db_patterns_dir, overwrite=True, draw_coordinates=False,
#                     empty_cell_as_question_mark=True, n_jobs=N_JOBS)

## Interface prototype v5

In [20]:
# Output directory of the "01_protov5" datasets extracted from the L database.
datasets_path_L = os.path.join(db_L_dir, "datasets", "01_protov5")

In [21]:
# Output directory of the "01_protov5" datasets extracted from the S database.
datasets_path_S = os.path.join(db_S_dir, "datasets", "01_protov5")

In [22]:
# Output directory of the "01_protov5" datasets extracted from the M database.
datasets_path_M = os.path.join(db_M_dir, "datasets", "01_protov5")

In [23]:
import numpy as np
# Select the pattern DB entries usable as "find the pattern" targets. An entry qualifies
# when it has exactly 3 symbols of 3 distinct shapes and 2 distinct colors, and the two
# symbols sharing a color are not on a diagonal of the 2x2 pattern grid.
def _is_3sym_2col_without_diagonal_pair(content):
    """Return True if a 2x2 pattern content satisfies the selection criteria above."""
    if len(content) != 3:
        return False
    colors = {item["color"] for item in content}
    shapes = {item["shape"] for item in content}
    # 2x2 grid of colors; cells without a symbol keep "" and therefore never match a
    # real color in the diagonal comparison below.
    grid = np.full((2, 2), "", dtype="U100")
    for item in content:
        grid[item["pos"][0]][item["pos"][1]] = item["color"]
    on_diagonal = grid[0][0] == grid[1][1] or grid[0][1] == grid[1][0]
    return len(colors) == 2 and len(shapes) == 3 and not on_diagonal


pattern_3sym_2col_keys = [
    key
    for key, pattern in db_patterns.items()
    if _is_3sym_2col_without_diagonal_pair(pattern["content"])
]

In [5]:
# Number of selectable patterns; the rule lists below index pattern_3sym_2col_keys.
nb_pattern_3sym_2col = len(pattern_3sym_2col_keys)
print(nb_pattern_3sym_2col)

NameError: name 'pattern_3sym_2col_keys' is not defined

In [25]:
from xaipatimg.datagen.gendataset import create_dataset_generic_rule_extract_sample
import tqdm

def generate_all_datasets(rules_data, db_dir, datasets_path):
    """Build the train/test/valid datasets of every rule in ``rules_data``.

    For each rule, ``create_dataset_generic_rule_extract_sample`` is called on
    ``db_dir`` with:

    - ``datasets_dir_path=datasets_path``;
    - CSV names ``<name>_train.csv``, ``<name>_test.csv`` and ``<name>_valid.csv``;
    - ``sample_path`` set to ``<datasets_path>/<name>_train``;
    - the rule's ``gen_fun`` as ``generic_rule_fun`` and its ``gen_kwargs`` as extra
      keyword arguments.

    Split sizes, sample counts and the job count come from the notebook-level config.

    A rule that carries a ``"pattern_id"`` gets the content of that ``db_patterns``
    entry stored in its ``gen_kwargs`` under ``"pattern_content"``. The rule dict
    itself is modified in place.
    """
    for rule in tqdm.tqdm(rules_data):
        rule_name = rule["name"]
        train_sample_dir = os.path.join(datasets_path, f"{rule_name}_train")
        if "pattern_id" in rule:
            pattern_entry = db_patterns[rule["pattern_id"]]
            rule["gen_kwargs"]["pattern_content"] = pattern_entry["content"]

        split_csv_names = dict(
            csv_name_train=rule_name + "_train.csv",
            csv_name_test=rule_name + "_test.csv",
            csv_name_valid=rule_name + "_valid.csv",
        )
        create_dataset_generic_rule_extract_sample(
            db_dir,
            datasets_dir_path=datasets_path,
            test_size=test_datasets_sizes,
            valid_size=valid_datasets_sizes,
            dataset_pos_samples_nb=full_datasets_pos_samples_nb,
            dataset_neg_samples_nb=full_datasets_neg_samples_nb,
            sample_path=train_sample_dir,
            sample_nb_per_class=sample_nb_per_class,
            generic_rule_fun=rule["gen_fun"],
            n_jobs=N_JOBS,
            **split_csv_names,
            **rule["gen_kwargs"],
        )

## DB L datasets generation

In [26]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# Rule definitions for the L database (X_DIVISIONS_L x Y_DIVISIONS_L grid), meant to be
# passed to generate_all_datasets (see the commented call in the next cell).
# Keys read by generate_all_datasets:
#   - "name": prefix of the <name>_train/_test/_valid.csv files and of the sample dir.
#   - "gen_fun" / "gen_kwargs": rule function and its keyword arguments, forwarded to
#     create_dataset_generic_rule_extract_sample.
#   - "pattern_id" (optional): key of a db_patterns entry whose content is injected as
#     gen_kwargs["pattern_content"]. The ids are taken from pattern_3sym_2col_keys,
#     so that list must be long enough for every index used here.
# The remaining keys ("question", "target_acc", "shown_acc", "samples_interface",
# "pos_llm_scaffold", "neg_llm_scaffold") are not read by generate_all_datasets;
# presumably they configure the interface prototype -- TODO confirm.
# Commented-out entries are disabled rule variants kept for reference.
rules_data_L = [
    #
    # {"name": "hard1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
    #                                                                                                  "y_division_full": Y_DIVISIONS_L,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[0]},
    #
    # NOTE(review): this rule's neg_llm_scaffold announces 12 blue circles, but its cell
    # list repeats B3 and D3 (12 items, only 10 distinct cells). Likely a typo; verify.
    {"name": "hard2_blue_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
                                                                                                "y_division": Y_DIVISIONS_L,
                                                                                                "shape": "circle",
                                                                                                "color": "#0C90C0",
                                                                                                "N": 13,
                                                                                                "restrict_plus_minus_1": True},
     "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "The image contains 13 blue circles. They are highlighted below. | A1;B1;C1;B2;C2;D2;E4;F4;G5;A6;B6;E6;D7", "neg_llm_scaffold": "The image contains 12 blue circles instead of 13. They are highlighted below. | A2;B3;D3;B3;C3;D3;E5;F6;A7;B7;E7;H9"},
    #
    # {"name": "hard2bis_blue_circle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "circle",
    #                                                                                             "color": "#0C90C0",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
    #                                                                                                  "y_division_full": Y_DIVISIONS_L,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[1]},

    # {"name": "hard4_purple_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#A33E9A",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of purple triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_purple_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#A33E9A",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of purple triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    # {"name": "hard4_yellow_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of yellow triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_yellow_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of yellow triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4_yellow_square_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "square",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of yellow squares equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_yellow_square_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "square",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of yellow squares equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4_blue_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "circle",
    #                                                                                             "color": "#0C90C0",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_blue_circle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "circle",
    #                                                                                             "color": "#0C90C0",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    # Active pattern rule: pattern_id requires pattern_3sym_2col_keys to have at least 3 entries.
    {"name": "hard5_find_pattern", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                                                                                                     "y_division_full": Y_DIVISIONS_L,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": False},
     "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "The pattern is found in the image and is highlighted below. | B5;C5;B6", "neg_llm_scaffold": "The pattern was not found in the image. |", "pattern_id": pattern_3sym_2col_keys[2]},
    #
    # {"name": "hard6_find_pattern", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
    #                                                                                                  "y_division_full": Y_DIVISIONS_L,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": False},
    #  "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[3]},

    # {"name": "hard7_square_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "square",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a square in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard8_square_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "square",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a square in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""}

]

In [27]:
# generate_all_datasets(rules_data_L, db_L_dir, datasets_path_L)

## DB S datasets generation

In [28]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# Rule definitions for the S database (X_DIVISIONS_S x Y_DIVISIONS_S grid), meant to be
# passed to generate_all_datasets (see the commented call in the next cell). Entries use
# the same keys as the L rules: "name", "gen_fun", "gen_kwargs", optional "pattern_id"
# (taken from pattern_3sym_2col_keys), plus interface-related keys that
# generate_all_datasets does not read. Commented-out entries are disabled variants.
# NOTE(review): the commented-out *_allrows entries use generic_rule_shape_in_every_row,
# which is imported only in the DB L cell, not in this one.
rules_data_S = [
    #
    # {"name": "easy1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                  "y_division_full": Y_DIVISIONS_S,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[4]},

    # Count rules come in pairs: the plain variant sets restrict_plus_minus_1=True and
    # the "bis" variant sets it to False, with everything else identical.
    {"name": "easy2_yellow_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_S,
                                                                                                "y_division": Y_DIVISIONS_S,
                                                                                                "shape": "triangle",
                                                                                                "color": "#E0B000",
                                                                                                "N": 6,
                                                                                                "restrict_plus_minus_1": True},
     "question": "Does the number of yellow triangles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    {"name": "easy2bis_yellow_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
                                                                                                "x_division": X_DIVISIONS_S,
                                                                                                "y_division": Y_DIVISIONS_S,
                                                                                                "shape": "triangle",
                                                                                                "color": "#E0B000",
                                                                                                "N": 6,
                                                                                                "restrict_plus_minus_1": False},
     "question": "Does the number of yellow triangles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "easy3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                  "y_division_full": Y_DIVISIONS_S,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[5]},
    #
    {"name": "easy4_purple_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_S,
                                                                                                  "y_division": Y_DIVISIONS_S,
                                                                                                  "shape": "circle",
                                                                                                  "color": "#A33E9A",
                                                                                                  "N": 6,
                                                                                                  "restrict_plus_minus_1": True},
     "question": "Does the number of purple circles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    {"name": "easy4bis_purple_circle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
                                                                                                  "x_division": X_DIVISIONS_S,
                                                                                                  "y_division": Y_DIVISIONS_S,
                                                                                                  "shape": "circle",
                                                                                                  "color": "#A33E9A",
                                                                                                  "N": 6,
                                                                                                  "restrict_plus_minus_1": False},
     "question": "Does the number of purple circles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    #
    # Active pattern rules: pattern ids 6 and 7 require pattern_3sym_2col_keys to have at least 8 entries.
    {"name": "easy5_find_pattern", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
                                                                                                     "y_division_full": Y_DIVISIONS_S,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": False},
     "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[6]},

    {"name": "easy6_find_pattern", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
                                                                                                     "y_division_full": Y_DIVISIONS_S,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": False},
     "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[7]},

    # NOTE(review): easy7 sets consider_rotations=True but its question does not mention
    # rotations, unlike the other *_find_pattern_rot rules.
    # {"name": "easy7_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                  "y_division_full": Y_DIVISIONS_S,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[10]},

    # {"name": "easy8_find_pattern", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                  "y_division_full": Y_DIVISIONS_S,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": False},
    #  "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[11]},
    #
    # {"name": "easy9_blue_square_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_S,
    #                                                                                               "y_division": Y_DIVISIONS_S,
    #                                                                                               "shape": "square",
    #                                                                                               "color": "#0C90C0",
    #                                                                                               "N": 6,
    #                                                                                               "restrict_plus_minus_1": True},
    #  "question": "Does the number of blue squares equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    # {"name": "easy10_circle_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_S,
    #                                                                                             "shape": "circle",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a circle in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "easy11_circle_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_S,
    #                                                                                             "shape": "circle",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a circle in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "easy12_circle_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_S,
    #                                                                                             "shape": "circle",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a circle in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

]

In [29]:
# generate_all_datasets(rules_data_S, db_S_dir, datasets_path_S)

## DB M datasets generation

In [30]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
def _rule_entry_M(name, gen_fun, gen_kwargs, question, **extra_fields):
    """Assemble one DB M rule description with the settings shared by every rule.

    Keys are produced in the same order as in the other rules lists: name, gen_fun, gen_kwargs,
    question, target_acc, shown_acc, samples_interface, pos_llm_scaffold, neg_llm_scaffold,
    followed by any extra field (e.g. "pattern_id").
    """
    entry = dict(name=name, gen_fun=gen_fun, gen_kwargs=gen_kwargs, question=question,
                 target_acc=0.85, shown_acc=0.85, samples_interface=13,
                 pos_llm_scaffold="", neg_llm_scaffold="")
    entry.update(extra_fields)
    return entry


def _n_count_rules_M(prefix, color_name, color_hex, shape, target_n=8):
    """Return the pair of "exactly N <color> <shape>s" rules on the M grid.

    The first rule ("<prefix>_<color>_<shape>_N") is generated with restrict_plus_minus_1=True,
    the second ("<prefix>bis_<color>_<shape>_N_norestrict") with restrict_plus_minus_1=False.
    Both rules ask the same question.
    """
    question = f"Does the number of {color_name} {shape}s equal to {target_n} in the image?"
    pair = []
    for restricted in (True, False):
        if restricted:
            rule_name = f"{prefix}_{color_name}_{shape}_N"
        else:
            rule_name = f"{prefix}bis_{color_name}_{shape}_N_norestrict"
        kwargs = {"x_division": X_DIVISIONS_M, "y_division": Y_DIVISIONS_M, "shape": shape,
                  "color": color_hex, "N": target_n, "restrict_plus_minus_1": restricted}
        pair.append(_rule_entry_M(rule_name, generic_rule_N_times_color_shape_exactly, kwargs, question))
    return pair


def _pattern_rule_M(name, pattern_id, consider_rotations=False):
    """Return a pattern-search rule (generic_rule_pattern_exactly_1_time_exclude_more) on the M grid."""
    if consider_rotations:
        question = "Is the pattern or any of its left or right rotations in the image?"
    else:
        question = "Is the pattern in the image?"
    kwargs = {"x_division_full": X_DIVISIONS_M, "y_division_full": Y_DIVISIONS_M,
              "x_division_pattern": X_DIVISIONS_PATTERNS, "y_division_pattern": Y_DIVISIONS_PATTERNS,
              "consider_rotations": consider_rotations}
    return _rule_entry_M(name, generic_rule_pattern_exactly_1_time_exclude_more, kwargs, question,
                         pattern_id=pattern_id)


# Active DB M rules; commented entries are disabled variants kept for easy re-activation.
rules_data_M = [
    # _pattern_rule_M("med1_find_pattern_rot", pattern_3sym_2col_keys[8], consider_rotations=True),
    *_n_count_rules_M("med2", "yellow", "#E0B000", "square"),
    # *_n_count_rules_M("med2", "purple", "#A33E9A", "triangle"),
    # *_n_count_rules_M("med2", "purple", "#A33E9A", "square"),
    # *_n_count_rules_M("med2", "yellow", "#E0B000", "triangle"),
    # *_n_count_rules_M("med2", "purple", "#A33E9A", "circle"),
    # _pattern_rule_M("med3_find_pattern_rot", pattern_3sym_2col_keys[9], consider_rotations=True),
    *_n_count_rules_M("med4", "blue", "#0C90C0", "triangle"),
    _pattern_rule_M("med5_find_pattern", pattern_3sym_2col_keys[10]),
    _pattern_rule_M("med6_find_pattern", pattern_3sym_2col_keys[11]),
    # Disabled "shape in every row" rules (med7_triangle_allrows and med8_triangle_allrows share these settings):
    # {"name": "med7_triangle_allrows", "gen_fun": generic_rule_shape_in_every_row,
    #  "gen_kwargs": {"y_division": Y_DIVISIONS_M, "shape": "triangle", "exclude_two_rows_missing": True},
    #  "question": "Is there a triangle in every row (1, 2, ...) of the image ?", "target_acc": 0.85,
    #  "shown_acc": 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
]


In [31]:
# Generate the datasets of every rule in rules_data_M from the M image database into datasets_path_M.
# NOTE(review): generate_all_datasets is defined in an earlier cell; the output below suggests one
# pass over the whole M database (NBGEN_full_per_size images) per rule — confirm.
generate_all_datasets(rules_data_M, db_M_dir, datasets_path_M)

  0%|          | 0/6 [00:00<?, ?it/s]
  6%|▌         | 33200/600000 [00:50<02:40, 3529.01it/s][A
  6%|▋         | 38320/600000 [00:50<01:53, 4944.52it/s][A
  7%|▋         | 43440/600000 [00:50<01:24, 6598.14it/s][A
  8%|▊         | 48560/600000 [00:50<01:04, 8530.28it/s][A
  9%|▉         | 53680/600000 [00:50<00:51, 10612.87it/s][A
 10%|▉         | 58800/600000 [00:51<00:42, 12689.56it/s][A
 11%|█         | 63920/600000 [00:51<00:37, 14452.52it/s][A
 12%|█▏        | 69040/600000 [00:51<00:33, 16089.55it/s][A
 12%|█▏        | 74160/600000 [00:51<00:30, 17520.44it/s][A
 13%|█▎        | 79280/600000 [00:52<00:28, 18479.63it/s][A
 14%|█▍        | 84400/600000 [00:52<00:26, 19140.80it/s][A
 15%|█▍        | 89520/600000 [00:52<00:25, 19972.34it/s][A
 16%|█▌        | 94640/600000 [00:52<00:25, 20077.34it/s][A
 17%|█▋        | 99760/600000 [00:53<00:24, 20357.00it/s][A
 17%|█▋        | 104880/600000 [00:53<00:23, 20779.89it/s][A
 18%|█▊        | 109861/600000 [00:53<00:19, 24880

Total number of positive instances found in database : 7500
Total number of negative instances found in database : 13782



0it [00:00, ?it/s][A
34it [00:00, 331.74it/s][A
70it [00:00, 341.19it/s][A
105it [00:00, 327.95it/s][A
139it [00:00, 328.33it/s][A
14001it [00:00, 23484.70it/s]
 83%|████████▎ | 5/6 [13:31<02:37, 157.88s/it]
  0%|          | 0/600000 [00:00<?, ?it/s][A
  0%|          | 440/600000 [00:00<02:18, 4318.94it/s][A
  0%|          | 1548/600000 [00:20<2:25:53, 68.36it/s][A
  0%|          | 1548/600000 [00:20<2:25:53, 68.36it/s][A
  0%|          | 1548/600000 [00:46<2:25:53, 68.36it/s][A
  0%|          | 2520/600000 [00:46<3:25:20, 48.50it/s][A
  0%|          | 2800/600000 [00:46<2:51:07, 58.16it/s][A
  1%|          | 5040/600000 [00:46<58:24, 169.78it/s] [A
  1%|▏         | 7600/600000 [00:47<28:14, 349.61it/s][A
  2%|▏         | 12571/600000 [00:47<11:36, 843.78it/s][A
  3%|▎         | 15366/600000 [00:47<07:58, 1222.79it/s][A
  3%|▎         | 18012/600000 [00:47<05:42, 1698.82it/s][A
  4%|▍         | 22960/600000 [00:47<03:19, 2888.35it/s][A
  5%|▍         | 28080/600000 [

Total number of positive instances found in database : 7500
Total number of negative instances found in database : 43938



14001it [00:00, 170865.58it/s]
100%|██████████| 6/6 [16:20<00:00, 163.47s/it]


In [32]:
import shutil

def setup_training_file_N_shape_color_rules(restrict_training_file, norestrict_training_file):
    """
    Replaces the +1/-1-restricted training file of an N-count rule with the unrestricted one.

    The training file generated with the +1/-1 restriction around the target number N is archived
    next to itself with a ".old" suffix, then overwritten with a copy of the training file generated
    without the restriction. The reason this operation is performed is that the models (ResNet18)
    are unable to learn to predict the task if the only instances in the training dataset contain
    N-1, N or N+1 occurrences of the symbol. The test and validation datasets are left untouched.

    The operation is idempotent: if the restricted file is already byte-identical to the
    unrestricted one (the swap was already done, e.g. the cell is re-run), nothing is copied, so
    the ".old" archive of the original restricted file is not overwritten.

    :param restrict_training_file: path of the training file generated with the +1/-1 restriction.
    :param norestrict_training_file: path of the training file generated without the restriction.
    :raises FileNotFoundError: if either file does not exist (detected before any file is modified).
    """
    import filecmp

    # filecmp.cmp stats both files first, so a missing file fails before anything is copied.
    if filecmp.cmp(restrict_training_file, norestrict_training_file, shallow=False):
        # Already swapped: archiving now would replace the real restricted file in ".old".
        return

    shutil.copyfile(restrict_training_file, restrict_training_file + ".old")
    shutil.copyfile(norestrict_training_file, restrict_training_file)



Switch the training datasets of the tasks that ask whether a symbol of a given shape/color appears exactly N times, so that they contain instances with any number of occurrences of the symbol (not only N-1, N or N+1, because with the latter restriction the training does not converge).

In [33]:
def setup_all_N_shape_color_rules(rules, datasets_path):
    """Swap in the unrestricted training file for every +1/-1-restricted N-count rule in `rules`.

    A rule generated with restrict_plus_minus_1=True and named "<prefix>_<rest>" is paired with its
    unrestricted counterpart "<prefix>bis_<rest>_norestrict" (the naming convention of the rules
    lists). Only pairs where both rules are present in `rules` are processed, so the swapped files
    always match the rules whose datasets were generated from that list.

    :param rules: list of rule dicts (as in rules_data_M).
    :param datasets_path: directory containing the "<rule name>_train.csv" files.
    :return: list of (restricted rule name, unrestricted rule name) pairs that were processed.
    """
    rule_names = {rule["name"] for rule in rules}
    swapped_pairs = []
    for rule in rules:
        if not rule.get("gen_kwargs", {}).get("restrict_plus_minus_1", False):
            continue
        prefix, _, rest = rule["name"].partition("_")
        norestrict_name = f"{prefix}bis_{rest}_norestrict"
        if norestrict_name not in rule_names:
            continue
        setup_training_file_N_shape_color_rules(
            os.path.join(datasets_path, rule["name"] + "_train.csv"),
            os.path.join(datasets_path, norestrict_name + "_train.csv"))
        swapped_pairs.append((rule["name"], norestrict_name))
    return swapped_pairs


setup_all_N_shape_color_rules(rules_data_M, datasets_path_M)
# The same helper applies to the other databases once their datasets have been generated, e.g.:
# setup_all_N_shape_color_rules(rules_data_S, datasets_path_S)
# setup_all_N_shape_color_rules(rules_data_L, datasets_path_L)
