In [1]:
import os
import shutil

# Root directory of the database, located under the directory named by the DATA environment variable
# (a KeyError is raised if DATA is not set).
# The paths are built with os.path.join so that they stay valid whether or not DATA ends with a path
# separator; the trailing "" component keeps the trailing separator on each directory path.
_db_root_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data", "db3.0.0")
db_S_dir = os.path.join(_db_root_dir, "S", "")
db_L_dir = os.path.join(_db_root_dir, "L", "")
db_M_dir = os.path.join(_db_root_dir, "M", "")
db_patterns_dir = os.path.join(_db_root_dir, "patterns", "")

# Create the four database directories if they do not exist yet.
for _db_dir in (db_S_dir, db_L_dir, db_M_dir, db_patterns_dir):
    os.makedirs(_db_dir, exist_ok=True)

# Dataset sizing parameters. They are not used in the visible part of the notebook;
# presumably consumed by later dataset-generation cells — TODO confirm.
test_datasets_sizes = 500
valid_datasets_sizes = 500
full_datasets_pos_samples_nb = 7500
full_datasets_neg_samples_nb = 7500
sample_nb_per_class = 100

# NOTE(review): looks like the number of parallel jobs used by the generation steps — confirm.
N_JOBS = 20

In [2]:
# Image generation counts: full images (per grid size) and pattern images
NBGEN_full_per_size, NBGEN_patterns = 600_000, 1_000

# Grid divisions (x and y) of the full images, for the L, S and M variants
X_DIVISIONS_L = Y_DIVISIONS_L = 15
X_DIVISIONS_S = Y_DIVISIONS_S = 10
X_DIVISIONS_M = Y_DIVISIONS_M = 12

# Grid divisions (x and y) of the patterns
X_DIVISIONS_PATTERNS = Y_DIVISIONS_PATTERNS = 2

# Image sizes in pixels, for full images and for pattern images
img_size = (700,) * 2
img_size_patterns = (300,) * 2

# Probability of drawing a geometrical shape at each position of the grid
SHAPE_PROB = 0.5

# Available shapes and colors
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",  # Purple
    "#E0B000",  # Yellow
    "#0C90C0",  # Blue
]

In [3]:
from xaipatimg.datagen.dbimg import load_db
# Load the patterns database stored in db_patterns_dir. The result is iterated with .items() further
# down, each value exposing a "content" list of symbol entries.
db_patterns = load_db(db_patterns_dir)


In [4]:
import numpy as np
# Keys of the patterns made of exactly 3 symbols that show 3 distinct shapes and 2 distinct colors,
# and whose two same-colored symbols do not sit on a diagonal of the 2x2 pattern grid.
pattern_3sym_2col_keys = []

for pattern_key, pattern in db_patterns.items():
    # Only patterns with exactly three symbols are candidates
    if len(pattern["content"]) != 3:
        continue

    seen_colors = set()
    seen_shapes = set()
    # Color of the symbol found at each grid position; positions without a symbol keep the empty string
    color_grid = np.full((2, 2), "", dtype="U100")
    for symbol in pattern["content"]:
        seen_colors.add(symbol["color"])
        seen_shapes.add(symbol["shape"])
        row, col = symbol["pos"][0], symbol["pos"][1]
        color_grid[row][col] = symbol["color"]

    # Neither diagonal of the grid may hold two identical colors
    main_diagonal_differs = color_grid[0][0] != color_grid[1][1]
    anti_diagonal_differs = color_grid[0][1] != color_grid[1][0]

    if len(seen_colors) == 2 and len(seen_shapes) == 3 and main_diagonal_differs and anti_diagonal_differs:
        pattern_3sym_2col_keys.append(pattern_key)

In [5]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

def _build_rules_data_L():
    """Assemble the list of rule descriptions for the L grid (X_DIVISIONS_L x Y_DIVISIONS_L).

    Each rule is a dict with the keys "name", "gen_fun", "gen_kwargs", "question", "target_acc",
    "shown_acc", "samples_interface", "pos_llm_scaffold" and "neg_llm_scaffold"; pattern rules carry an
    additional trailing "pattern_id" key taken from pattern_3sym_2col_keys.

    :return: the list of rule dicts, in their definition order.
    """
    hex_by_color_name = {"purple": "#A33E9A", "yellow": "#E0B000", "blue": "#0C90C0"}
    shape_count = 13  # value of "N" shared by every counting rule of this list

    def make_rule(name, gen_fun, gen_kwargs, question, pos_scaffold="", neg_scaffold="", **trailing_fields):
        # Fields shared by every rule; trailing_fields are appended after the scaffold entries
        rule = {
            "name": name,
            "gen_fun": gen_fun,
            "gen_kwargs": gen_kwargs,
            "question": question,
            "target_acc": 0.85,
            "shown_acc": 0.85,
            "samples_interface": 13,
            "pos_llm_scaffold": pos_scaffold,
            "neg_llm_scaffold": neg_scaffold,
        }
        rule.update(trailing_fields)
        return rule

    def pattern_rule(name, pattern_index, rotations):
        # Rule built on generic_rule_pattern_exactly_1_time_exclude_more for one of the selected patterns
        if rotations:
            question = "Is the pattern or any of its left or right rotations in the image?"
        else:
            question = "Is the pattern in the image?"
        gen_kwargs = {
            "x_division_full": X_DIVISIONS_L,
            "y_division_full": Y_DIVISIONS_L,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": rotations,
        }
        return make_rule(name, generic_rule_pattern_exactly_1_time_exclude_more, gen_kwargs, question,
                         pattern_id=pattern_3sym_2col_keys[pattern_index])

    def count_rule(name, color_name, shape, restrict, pos_scaffold="", neg_scaffold=""):
        # Rule built on generic_rule_N_times_color_shape_exactly for a given color and shape
        gen_kwargs = {
            "x_division": X_DIVISIONS_L,
            "y_division": Y_DIVISIONS_L,
            "shape": shape,
            "color": hex_by_color_name[color_name],
            "N": shape_count,
            "restrict_plus_minus_1": restrict,
        }
        question = "Does the number of {} {}s equal to {} in the image?".format(color_name, shape, shape_count)
        return make_rule(name, generic_rule_N_times_color_shape_exactly, gen_kwargs, question,
                         pos_scaffold, neg_scaffold)

    def every_row_rule(name, shape):
        # Rule built on generic_rule_shape_in_every_row for a given shape
        gen_kwargs = {
            "y_division": Y_DIVISIONS_L,
            "shape": shape,
            "exclude_two_rows_missing": True,
        }
        question = "Is there a {} in every row (1, 2, ...) of the image ?".format(shape)
        return make_rule(name, generic_rule_shape_in_every_row, gen_kwargs, question)

    return [
        pattern_rule("hard1_find_pattern_rot", 0, True),

        count_rule(
            "hard2_blue_circle_N", "blue", "circle", True,
            pos_scaffold="The image contains 13 blue circles. They are highlighted below. | A1;B2;B3;B5;C1;C3;C4;D1;D2;D3;D4;E2;E7",
            neg_scaffold="The image contains 12 blue circles instead of 13. They are highlighted below. | A1;B2;B3;B5;C1;C3;C4;D1;D2;D3;D4;E2",
        ),
        count_rule("hard2bis_blue_circle_N_norestrict", "blue", "circle", False),

        pattern_rule("hard3_find_pattern_rot", 1, True),

        count_rule("hard4_purple_triangle_N", "purple", "triangle", True),
        count_rule("hard4bis_purple_triangle_N_norestrict", "purple", "triangle", False),
        count_rule("hard4_yellow_triangle_N", "yellow", "triangle", True),
        count_rule("hard4bis_yellow_triangle_N_norestrict", "yellow", "triangle", False),
        count_rule("hard4_yellow_square_N", "yellow", "square", True),
        count_rule("hard4bis_yellow_square_N_norestrict", "yellow", "square", False),
        count_rule("hard4_blue_circle_N", "blue", "circle", True),
        count_rule("hard4bis_blue_circle_N_norestrict", "blue", "circle", False),

        pattern_rule("hard5_find_pattern", 2, False),
        pattern_rule("hard6_find_pattern", 3, False),

        every_row_rule("hard7_square_allrows", "square"),
        every_row_rule("hard8_square_allrows", "square"),
    ]


rules_data_L = _build_rules_data_L()

In [6]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

def _build_rules_data_S():
    """Assemble the list of rule descriptions for the S grid (X_DIVISIONS_S x Y_DIVISIONS_S).

    Each rule is a dict with the keys "name", "gen_fun", "gen_kwargs", "question", "target_acc",
    "shown_acc", "samples_interface", "pos_llm_scaffold" and "neg_llm_scaffold"; pattern rules carry an
    additional trailing "pattern_id" key taken from pattern_3sym_2col_keys.

    The question of a pattern rule is derived from its "consider_rotations" flag, so that a rule generated
    with rotations always asks about rotations. This corrects "easy7_find_pattern_rot", which was generated
    with consider_rotations=True but asked "Is the pattern in the image?".

    :return: the list of rule dicts, in their definition order.
    """
    hex_by_color_name = {"purple": "#A33E9A", "yellow": "#E0B000", "blue": "#0C90C0"}
    shape_count = 6  # value of "N" shared by every counting rule of this list

    def make_rule(name, gen_fun, gen_kwargs, question, **trailing_fields):
        # Fields shared by every rule; trailing_fields are appended after the scaffold entries
        rule = {
            "name": name,
            "gen_fun": gen_fun,
            "gen_kwargs": gen_kwargs,
            "question": question,
            "target_acc": 0.85,
            "shown_acc": 0.85,
            "samples_interface": 13,
            "pos_llm_scaffold": "",
            "neg_llm_scaffold": "",
        }
        rule.update(trailing_fields)
        return rule

    def pattern_rule(name, pattern_index, rotations):
        # Rule built on generic_rule_pattern_exactly_1_time_exclude_more for one of the selected patterns.
        # The question text follows the rotation flag to keep both consistent.
        if rotations:
            question = "Is the pattern or any of its left or right rotations in the image?"
        else:
            question = "Is the pattern in the image?"
        gen_kwargs = {
            "x_division_full": X_DIVISIONS_S,
            "y_division_full": Y_DIVISIONS_S,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": rotations,
        }
        return make_rule(name, generic_rule_pattern_exactly_1_time_exclude_more, gen_kwargs, question,
                         pattern_id=pattern_3sym_2col_keys[pattern_index])

    def count_rule(name, color_name, shape, restrict):
        # Rule built on generic_rule_N_times_color_shape_exactly for a given color and shape
        gen_kwargs = {
            "x_division": X_DIVISIONS_S,
            "y_division": Y_DIVISIONS_S,
            "shape": shape,
            "color": hex_by_color_name[color_name],
            "N": shape_count,
            "restrict_plus_minus_1": restrict,
        }
        question = "Does the number of {} {}s equal to {} in the image?".format(color_name, shape, shape_count)
        return make_rule(name, generic_rule_N_times_color_shape_exactly, gen_kwargs, question)

    def every_row_rule(name, shape):
        # Rule built on generic_rule_shape_in_every_row for a given shape
        gen_kwargs = {
            "y_division": Y_DIVISIONS_S,
            "shape": shape,
            "exclude_two_rows_missing": True,
        }
        question = "Is there a {} in every row (1, 2, ...) of the image ?".format(shape)
        return make_rule(name, generic_rule_shape_in_every_row, gen_kwargs, question)

    return [
        pattern_rule("easy1_find_pattern_rot", 4, True),

        count_rule("easy2_yellow_triangle_N", "yellow", "triangle", True),
        count_rule("easy2bis_yellow_triangle_N_norestrict", "yellow", "triangle", False),

        pattern_rule("easy3_find_pattern_rot", 5, True),

        count_rule("easy4_purple_circle_N", "purple", "circle", True),
        count_rule("easy4bis_purple_circle_N_norestrict", "purple", "circle", False),

        pattern_rule("easy5_find_pattern", 6, False),
        pattern_rule("easy6_find_pattern", 7, False),
        # Generated with rotations: now asks the rotation question like the other *_rot rules
        pattern_rule("easy7_find_pattern_rot", 10, True),
        pattern_rule("easy8_find_pattern", 11, False),

        count_rule("easy9_blue_square_N", "blue", "square", True),

        every_row_rule("easy10_circle_allrows", "circle"),
        every_row_rule("easy11_circle_allrows", "circle"),
        every_row_rule("easy12_circle_allrows", "circle"),
    ]


rules_data_S = _build_rules_data_S()

In [7]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
def _rule_entry_m(name, gen_fun, gen_kwargs, question, **extra_fields):
    """Build one rule dict, filling in the fields that are identical for every rule of this list."""
    entry = {
        "name": name,
        "gen_fun": gen_fun,
        "gen_kwargs": gen_kwargs,
        "question": question,
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 13,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
    }
    # Optional trailing fields (only "pattern_id" is used below) come after the shared ones.
    entry.update(extra_fields)
    return entry


def _pattern_rule_m(name, pattern_key_index, consider_rotations):
    """Build a "find the pattern" rule for the pattern stored at the given index of pattern_3sym_2col_keys."""
    if consider_rotations:
        question = "Is the pattern or any of its left or right rotations in the image?"
    else:
        question = "Is the pattern in the image?"
    gen_kwargs = {
        "x_division_full": X_DIVISIONS_M,
        "y_division_full": Y_DIVISIONS_M,
        "x_division_pattern": X_DIVISIONS_PATTERNS,
        "y_division_pattern": Y_DIVISIONS_PATTERNS,
        "consider_rotations": consider_rotations,
    }
    return _rule_entry_m(name, generic_rule_pattern_exactly_1_time_exclude_more, gen_kwargs, question,
                         pattern_id=pattern_3sym_2col_keys[pattern_key_index])


# Hex codes of the colors referred to by name in the counting rules.
_COLOR_HEX_M = {"purple": "#A33E9A", "yellow": "#E0B000", "blue": "#0C90C0"}


def _count_rule_pair_m(prefix, color_name, shape):
    """
    Build the two "exactly 8 symbols of that color and shape" rules for one (color, shape) couple: first the
    variant with restrict_plus_minus_1 set to True, then the "bis"/"norestrict" variant with it set to False.
    """
    pair = []
    for restrict in (True, False):
        if restrict:
            name = f"{prefix}_{color_name}_{shape}_N"
        else:
            name = f"{prefix}bis_{color_name}_{shape}_N_norestrict"
        gen_kwargs = {
            "x_division": X_DIVISIONS_M,
            "y_division": Y_DIVISIONS_M,
            "shape": shape,
            "color": _COLOR_HEX_M[color_name],
            "N": 8,
            "restrict_plus_minus_1": restrict,
        }
        question = f"Does the number of {color_name} {shape}s equal to 8 in the image?"
        pair.append(_rule_entry_m(name, generic_rule_N_times_color_shape_exactly, gen_kwargs, question))
    return pair


def _triangle_rows_rule_m(name):
    """Build a "triangle in every row" rule."""
    gen_kwargs = {
        "y_division": Y_DIVISIONS_M,
        "shape": "triangle",
        "exclude_two_rows_missing": True,
    }
    return _rule_entry_m(name, generic_rule_shape_in_every_row, gen_kwargs,
                         "Is there a triangle in every row (1, 2, ...) of the image ?")


# Rules defined on the M grid (X_DIVISIONS_M x Y_DIVISIONS_M); the order of the entries is significant for
# anything that indexes into this list.
rules_data_M = [
    _pattern_rule_m("med1_find_pattern_rot", 8, consider_rotations=True),

    *_count_rule_pair_m("med2", "yellow", "square"),
    *_count_rule_pair_m("med2", "purple", "triangle"),
    *_count_rule_pair_m("med2", "purple", "square"),
    *_count_rule_pair_m("med2", "yellow", "triangle"),
    *_count_rule_pair_m("med2", "purple", "circle"),

    _pattern_rule_m("med3_find_pattern_rot", 9, consider_rotations=True),

    *_count_rule_pair_m("med4", "blue", "triangle"),

    _pattern_rule_m("med5_find_pattern", 10, consider_rotations=False),
    _pattern_rule_m("med6_find_pattern", 11, consider_rotations=False),

    _triangle_rows_rule_m("med7_triangle_allrows"),
    _triangle_rows_rule_m("med8_triangle_allrows"),
]


In [8]:
# Every size-specific DB keeps the CSV files of its datasets under <db dir>/datasets/01_protov5.
datasets_path_L, datasets_path_S, datasets_path_M = (
    os.path.join(size_db_dir, "datasets", "01_protov5")
    for size_db_dir in (db_L_dir, db_S_dir, db_M_dir)
)

In [13]:
from xaipatimg.datagen.jsondb import JSONDB
import tqdm
from os import makedirs
import pandas as pd
from pathlib import Path



def _get_all_img_ids_from_db_csv(dataset_csv):
    ids_list = []
    for path in dataset_csv["path"]:
        ids_list.append(Path(path).stem)
    return ids_list

def reduce_db(db_dir, datasets_path, dataset_names):
    """
    Save a minimal version of the DB that contains all the data for the datasets given as a list.

    The reduced DB is written to <db_dir>/min: a db.json file holding only the entries referenced by the
    datasets, and an img directory holding a copy of the corresponding <id>.png files.
    :param db_dir: root directory of the full DB (where db.json and the img directory are read from)
    :param datasets_path: directory containing the <name>_train.csv, <name>_test.csv and <name>_valid.csv files
    :param dataset_names: names of the datasets to extract the data for
    :return: None
    """

    min_db_dir = os.path.join(db_dir, "min")
    # Creating the img sub-directory also creates min_db_dir itself.
    os.makedirs(os.path.join(min_db_dir, "img"), exist_ok=True)

    # Collecting the ids of every image referenced by any split of any requested dataset. A set is used since
    # the same image may be referenced several times and only membership matters.
    needed_ids = set()
    for dataset_name in dataset_names:
        for split in ("train", "test", "valid"):
            split_csv = pd.read_csv(os.path.join(datasets_path, f"{dataset_name}_{split}.csv"))
            needed_ids.update(_get_all_img_ids_from_db_csv(split_csv))

    full_db = JSONDB(os.path.join(db_dir, "db.json"))
    reduced_db = JSONDB(os.path.join(min_db_dir, "db.json"))

    # Copying the entry and the image file of every needed id. Without this loop the reduced DB is flushed
    # without having received any entry and no image is copied.
    for img_id, img_entry in tqdm.tqdm(full_db.items()):
        if img_id in needed_ids:
            reduced_db[img_id] = img_entry
            shutil.copy(os.path.join(db_dir, "img", f"{img_id}.png"),
                        os.path.join(min_db_dir, "img", f"{img_id}.png"))

    reduced_db.flush()


In [14]:
# Writing the reduced DB for the datasets listed in rules_data_S.
reduce_db(
    db_dir=db_S_dir,
    datasets_path=datasets_path_S,
    dataset_names=[rule["name"] for rule in rules_data_S],
)

600000it [02:53, 3459.97it/s]


In [15]:
# Writing the reduced DB for the datasets listed in rules_data_L.
reduce_db(
    db_dir=db_L_dir,
    datasets_path=datasets_path_L,
    dataset_names=[rule["name"] for rule in rules_data_L],
)

600000it [04:23, 2276.45it/s]


In [16]:
# Writing the reduced DB for the datasets listed in rules_data_M.
reduce_db(
    db_dir=db_M_dir,
    datasets_path=datasets_path_M,
    dataset_names=[rule["name"] for rule in rules_data_M],
)

600000it [04:54, 2036.12it/s]
