In [2]:
import os

db_S_dir = os.environ["DATA"] + "PatImgXAI_data/db3.0.0/S/"
db_L_dir = os.environ["DATA"] + "PatImgXAI_data/db3.0.0/L/"
db_M_dir = os.environ["DATA"] + "PatImgXAI_data/db3.0.0/M/"
db_patterns_dir = os.environ["DATA"] + "PatImgXAI_data/db3.0.0/patterns/"
os.makedirs(db_S_dir, exist_ok=True)
os.makedirs(db_L_dir, exist_ok=True)
os.makedirs(db_M_dir, exist_ok=True)
os.makedirs(db_patterns_dir, exist_ok=True)

test_datasets_sizes = 1000
valid_datasets_sizes = 1000
full_datasets_pos_samples_nb = 5000
full_datasets_neg_samples_nb = 5000
sample_nb_per_class = 100

In [35]:
# Number of images generated
NBGEN_full_per_size = 150000
NBGEN_patterns = 1000

# Grid division for full image
X_DIVISIONS_L = 15
Y_DIVISIONS_L = 15
X_DIVISIONS_S = 10
Y_DIVISIONS_S = 10
X_DIVISIONS_M = 12
Y_DIVISIONS_M = 12

# Grid division of patterns
X_DIVISIONS_PATTERNS = 2
Y_DIVISIONS_PATTERNS = 2

# Size of the images in pixels
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability to generate a geometrical shape at each position in the grid
SHAPE_PROB = 0.5

# Define available shapes
SHAPES = ['circle', 'square', 'triangle']
COLORS  = ["#A33E9A", "#E0B000", "#0C90C0"] # Purple, Yellow, Blue

In [4]:
import numpy as np
from xaipatimg.datagen.dbimg import generate_uuid
import os

def generate_db(db, x_divisions, y_divisions, to_generate, img_size):
    unique_content_generated = {}

    duplicate_count = 0
    while to_generate > 0:
        content = []
        for i in range(x_divisions):
            for j in range(y_divisions):
                if np.random.random() < SHAPE_PROB:
                    content.append({
                        "shape": np.random.choice(SHAPES),
                        "pos": (i, j),
                        "color": np.random.choice(COLORS)
                    })

        if str(content) in unique_content_generated:
            duplicate_count += 1
            continue

        imgid = generate_uuid()
        db[imgid] = {
            "path": os.path.join("img", imgid + ".png"),
            "division" : (x_divisions, y_divisions),
            "size": img_size,
            "content": content
        }

        unique_content_generated[str(content)] = True
        to_generate -= 1

    print("unique generated in DB : " + str(len(db)))
    print("duplicates avoided : " + str(duplicate_count))
    return db

In [5]:
import tqdm

def check_for_duplicates(db):
    content_dict = {}
    nb_duplicates = 0

    for k, v in tqdm.tqdm(db.items()):
        if str(v["content"]) in content_dict:
            nb_duplicates += 1
        else:
            content_dict[str(v["content"])] = True

    print(nb_duplicates)

In [6]:
from xaipatimg.datagen.dbimg import load_db


In [7]:
from xaipatimg.datagen.dbimg import load_db

db_S = load_db(db_S_dir)
db_L = load_db(db_L_dir)
db_M = load_db(db_M_dir)

In [8]:

db_patterns = load_db(db_patterns_dir)



#### Generate full DB

In [9]:

db_L = generate_db(db_L, X_DIVISIONS_L, Y_DIVISIONS_L, NBGEN_full_per_size, img_size)

unique generated in DB : 10000
duplicates avoided : 0


In [10]:

check_for_duplicates(db_L)

100%|██████████| 10000/10000 [00:01<00:00, 8695.33it/s]

0





In [11]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
gen_img_and_save_db(db_L, db_L_dir, overwrite=True, n_jobs=20)
db_L = None

100%|██████████| 10000/10000 [00:15<00:00, 651.56it/s]


In [12]:
db_S = generate_db(db_S, X_DIVISIONS_S, Y_DIVISIONS_S, NBGEN_full_per_size, img_size)


unique generated in DB : 10000
duplicates avoided : 0


In [13]:
check_for_duplicates(db_S)

100%|██████████| 10000/10000 [00:00<00:00, 21930.15it/s]

0





In [14]:
gen_img_and_save_db(db_S, db_S_dir, overwrite=True, n_jobs=20)
db_S = None

100%|██████████| 10000/10000 [00:11<00:00, 836.45it/s]


In [15]:
db_M = generate_db(db_M, X_DIVISIONS_M, Y_DIVISIONS_M, NBGEN_full_per_size, img_size)


unique generated in DB : 10000
duplicates avoided : 0


In [16]:
check_for_duplicates(db_M)

100%|██████████| 10000/10000 [00:00<00:00, 14613.96it/s]

0





In [17]:
gen_img_and_save_db(db_M, db_M_dir, overwrite=True, n_jobs=20)
db_M = None

100%|██████████| 10000/10000 [00:12<00:00, 769.84it/s]


#### Generate DB of patterns

In [18]:
db_patterns = generate_db(db_patterns, X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS, NBGEN_patterns, img_size_patterns)

unique generated in DB : 1300
duplicates avoided : 1009


In [19]:
check_for_duplicates(db_patterns)

100%|██████████| 1300/1300 [00:00<00:00, 350401.34it/s]

41





In [20]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
gen_img_and_save_db(db_patterns, db_patterns_dir, overwrite=True, draw_coordinates=False, n_jobs=20)

100%|██████████| 1300/1300 [00:00<00:00, 10992.74it/s]


## Interface prototype v5

In [21]:
datasets_path_L = os.path.join(db_L_dir, "datasets", "01_protov5")

In [22]:
datasets_path_S = os.path.join(db_S_dir, "datasets", "01_protov5")

In [23]:
datasets_path_M = os.path.join(db_M_dir, "datasets", "01_protov5")

In [24]:
pattern_3sym_2col_keys = []

# Extracting list of patterns that contain 3 symbols of 3 different shapes and 2 different colors
for k, v in db_patterns.items():
    if len(v["content"]) == 3:
        img_col_d = {}
        img_shape_d = {}
        for entry in v["content"]:
            img_col_d[entry["color"]] = True
            img_shape_d[entry["shape"]] = True

        if len(img_col_d.keys()) == 2 and len(img_shape_d.keys()) == 3:
            pattern_3sym_2col_keys.append(k)

In [25]:
print(len(pattern_3sym_2col_keys))

73


In [36]:
from xaipatimg.datagen.gendataset import create_dataset_generic_rule_extract_sample
import tqdm

def generate_all_datasets(rules_data, db_dir, datasets_path):
    for rule_line in tqdm.tqdm(rules_data):
        name = rule_line["name"]
        sample_path = os.path.join(datasets_path, f"{name}_train")
        if "pattern_id" in rule_line:
            rule_line["gen_kwargs"]["pattern_content"] = db_patterns[rule_line["pattern_id"]]["content"]
        create_dataset_generic_rule_extract_sample(db_dir, datasets_dir_path=datasets_path, csv_name_train=name+"_train.csv",
                                                   csv_name_test=name+"_test.csv", csv_name_valid=name+"_valid.csv",
                                                   test_size=test_datasets_sizes,
                                                   valid_size=valid_datasets_sizes, dataset_pos_samples_nb=full_datasets_pos_samples_nb,
                                                   dataset_neg_samples_nb=full_datasets_neg_samples_nb,
                                                   sample_path=sample_path, sample_nb_per_class=sample_nb_per_class,
                                                   generic_rule_fun=rule_line["gen_fun"],
                                                   **rule_line["gen_kwargs"])

## DB L datasets generation

In [37]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

rules_data_L = [

    {"name": "hard1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                                                                                                     "y_division_full": Y_DIVISIONS_L,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[0]},

    {"name": "hard2_blue_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
                                                                                                "y_division": Y_DIVISIONS_L,
                                                                                                "shape": "circle",
                                                                                                "color": "#0C90C0",
                                                                                                "N": 13,
                                                                                                "restrict_plus_minus_1": True},
     "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    {"name": "hard3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                                                                                                     "y_division_full": Y_DIVISIONS_L,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[1]},

    {"name": "hard4_purple_square_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
                                                                                                "y_division": Y_DIVISIONS_L,
                                                                                                "shape": "square",
                                                                                                "color": "#A33E9A",
                                                                                                "N": 13,
                                                                                                "restrict_plus_minus_1": True},
     "question": "Does the number of purple squares equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
]

In [None]:
generate_all_datasets(rules_data_L, db_L_dir, datasets_path_L)

  0%|          | 0/3 [00:00<?, ?it/s]
  0%|          | 0/10000 [00:00<?, ?it/s][A
  4%|▍         | 398/10000 [00:00<00:02, 3972.12it/s][A
  8%|▊         | 796/10000 [00:00<00:02, 3794.77it/s][A
 12%|█▏        | 1177/10000 [00:00<00:02, 3725.65it/s][A
 16%|█▌        | 1551/10000 [00:00<00:02, 3728.86it/s][A
 19%|█▉        | 1925/10000 [00:00<00:02, 3716.55it/s][A
 23%|██▎       | 2301/10000 [00:00<00:02, 3730.47it/s][A
 27%|██▋       | 2697/10000 [00:00<00:01, 3804.27it/s][A
 31%|███       | 3108/10000 [00:00<00:01, 3899.94it/s][A
 35%|███▌      | 3509/10000 [00:00<00:01, 3932.51it/s][A
 39%|███▉      | 3903/10000 [00:01<00:01, 3919.57it/s][A
 43%|████▎     | 4296/10000 [00:01<00:01, 3876.72it/s][A
 47%|████▋     | 4684/10000 [00:01<00:01, 3871.42it/s][A
 51%|█████     | 5074/10000 [00:01<00:01, 3877.43it/s][A
 55%|█████▍    | 5464/10000 [00:01<00:01, 3881.74it/s][A
 59%|█████▊    | 5853/10000 [00:01<00:01, 3878.96it/s][A
 62%|██████▏   | 6241/10000 [00:01<00:00, 3868.88

Total number of positive instances found in database : 1159
Total number of negative instances found in database : 2166



1601it [00:00, 226256.97it/s]
 33%|███▎      | 1/3 [00:03<00:06,  3.44s/it]
  0%|          | 0/10000 [00:00<?, ?it/s][A
  0%|          | 31/10000 [00:00<00:32, 302.82it/s][A
  1%|          | 62/10000 [00:00<00:34, 287.55it/s][A
  1%|          | 92/10000 [00:00<00:33, 292.23it/s][A
  1%|          | 123/10000 [00:00<00:33, 297.58it/s][A
  2%|▏         | 154/10000 [00:00<00:32, 299.50it/s][A
  2%|▏         | 185/10000 [00:00<00:32, 299.32it/s][A
  2%|▏         | 215/10000 [00:00<00:32, 297.79it/s][A
  2%|▏         | 245/10000 [00:00<00:32, 296.71it/s][A
  3%|▎         | 275/10000 [00:00<00:32, 296.96it/s][A
  3%|▎         | 305/10000 [00:01<00:32, 296.44it/s][A
  3%|▎         | 335/10000 [00:01<00:32, 297.17it/s][A
  4%|▎         | 365/10000 [00:01<00:33, 291.07it/s][A
  4%|▍         | 395/10000 [00:01<00:33, 288.56it/s][A
  4%|▍         | 425/10000 [00:01<00:33, 290.09it/s][A
  5%|▍         | 455/10000 [00:01<00:33, 282.22it/s][A
  5%|▍         | 484/10000 [00:01<00:33, 

## DB S datasets generation

In [38]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

rules_data_S = [

    {"name": "easy1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
                                                                                                     "y_division_full": Y_DIVISIONS_S,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[2]},

    {"name": "easy2_yellow_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_S,
                                                                                                "y_division": Y_DIVISIONS_S,
                                                                                                "shape": "triangle",
                                                                                                "color": "#E0B000",
                                                                                                "N": 6,
                                                                                                "restrict_plus_minus_1": True},
     "question": "Does the number of yellow triangles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    {"name": "easy3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
                                                                                                     "y_division_full": Y_DIVISIONS_S,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[3]},

    {"name": "easy4_purple_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_S,
                                                                                                  "y_division": Y_DIVISIONS_S,
                                                                                                  "shape": "circle",
                                                                                                  "color": "#A33E9A",
                                                                                                  "N": 6,
                                                                                                  "restrict_plus_minus_1": True},
     "question": "Does the number of purple circles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
]

In [39]:
generate_all_datasets(rules_data_S, db_S_dir, datasets_path_S)

## DB M datasets generation

In [1]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
rules_data_M = [

    {"name": "med1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                                                                                                     "y_division_full": Y_DIVISIONS_L,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[4]},

    {"name": "med2_yellow_square_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
                                                                                                "y_division": Y_DIVISIONS_L,
                                                                                                "shape": "square",
                                                                                                "color": "#E0B000",
                                                                                                "N": 8,
                                                                                                "restrict_plus_minus_1": True},
     "question": "Does the number of yellow squares equal to 8 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    {"name": "med3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                                                                                                     "y_division_full": Y_DIVISIONS_L,
                                                                                                     "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                     "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                     "consider_rotations": True,
                                                                                                            },
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[5]},

    {"name": "med4_blue_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
                                                                                                "y_division": Y_DIVISIONS_L,
                                                                                                "shape": "triangle",
                                                                                                "color": "#0C90C0",
                                                                                                "N": 8,
                                                                                                "restrict_plus_minus_1": True},
     "question": "Does the number of blue triangles equal to 8 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
]


NameError: name 'X_DIVISIONS_L' is not defined

In [41]:
generate_all_datasets(rules_data_M, db_M_dir, datasets_path_M)

  0%|          | 0/4 [00:01<?, ?it/s]


KeyboardInterrupt: 