In [1]:
import os

# Root directory of the generated image database (image paths are stored relative to it
# under "img/", and per-rule datasets are written under "datasets/").
# $DATA must be set to the data root. os.path.join inserts the separator when $DATA has
# no trailing slash (plain "+" concatenation silently produced ".../dataPatImgXAI_data");
# the trailing "" keeps the trailing "/" the original path string had.
db_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data", "db2.0.0", "")
os.makedirs(db_dir, exist_ok=True)

# Size (number of images) of the test and validation split built for each rule.
test_datasets_sizes = 1000
valid_datasets_sizes = 1000
# Number of positive / negative images requested for each rule's full dataset.
full_datasets_pos_samples_nb = 5000
full_datasets_neg_samples_nb = 5000
# Number of example images per class placed in each rule's sample folder (presumably
# one folder per class — TODO confirm in create_dataset_generic_rule_extract_sample).
sample_nb_per_class = 100

In [2]:
# ---- Synthetic image generation settings -----------------------------------

NBGEN = 1_000_000          # total number of images to generate

# Each image is laid out on an X_DIVISIONS x Y_DIVISIONS grid of cells.
X_DIVISIONS, Y_DIVISIONS = 6, 6

img_size = (700, 700)      # rendered image size, in pixels

SHAPE_PROB = 0.5           # probability that a given grid cell receives a shape

# Vocabulary a filled grid cell draws from.
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [3]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_exist_row_with_only_color_and_col_with_only_shape

# One entry per classification rule:
#   name       - prefix of the CSV files ("<name>_train.csv", ...) and sample folder built for the rule
#   gen_fun    - rule function passed, together with gen_kwargs, to create_dataset_generic_rule_extract_sample
#   gen_kwargs - rule parameters (colors are hex codes from COLORS)
#   question   - natural-language statement of the rule (not used by this notebook)
# NOTE: the questions hard-code a 6x6 grid (rows 1..6, columns A..F); update them if
# X_DIVISIONS / Y_DIVISIONS change.
# The commented-out entries are currently disabled rules; uncomment them to build their datasets too.
rules_data = [
    {"name": "easy_1_6_blue", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#0C90C0", "N": 6, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 6 blue symbols?"},
    {"name": "easy_2_row_circle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "circle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only circles?"},
    # {"name": "easy_3_7_purple", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#A33E9A", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 7 purple symbols?"},
    # {"name": "easy_4_row_triangle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only triangles?"},
    # N fixed from 5 to 7 to match the rule name and question ("exactly 7 yellow symbols").
    # {"name": "easy_5_7_yellow", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#E0B000", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 7 yellow symbols?"},
    # {"name": "easy_6_row_square", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only squares?"},
    #
    # {"name": "hard_1_blue_square_plus_circle_8", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#0C90C0", "shape1": "square", "shape2": "circle", "N": 8, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS,},
    #  "question": "In the image, does the number of blue squares plus (+) the number of circles equal to 8?"},
    # {"name": "hard_2_row_purple_col_triangle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#A33E9A", "shape": "triangle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only purple symbols, and one column (A, ..., F) containing only triangles?"},
    # {"name": "hard_3_yellow_circle_plus_triangle_9", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#E0B000", "shape1": "circle", "shape2": "triangle", "N": 9, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, does the number of yellow circles plus (+) the number of triangles equal to 9?"},
    # {"name": "hard_4_row_yellow_col_circle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#E0B000", "shape": "circle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only yellow symbols, and one column (A, ..., F) containing only circles?"},
    # {"name": "hard_5_purple_triangle_plus_square_7", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#A33E9A", "shape1": "triangle", "shape2": "square", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, does the number of purple triangles plus (+) the number of squares equal to 7?"},
    # {"name": "hard_6_row_blue_col_square", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#0C90C0", "shape": "square" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only blue symbols, and one column (A, ..., F) containing only squares?"},
]

In [4]:
from xaipatimg.datagen.dbimg import load_db

# Load the image database stored under db_dir. The result is used as a mapping
# image-id -> metadata entry; the generation cell below adds new entries to it in place.
# NOTE(review): behavior when no DB exists yet in db_dir (presumably an empty mapping) — confirm in load_db.
db = load_db(db_dir)

In [5]:
import numpy as np
from xaipatimg.datagen.dbimg import generate_uuid
import os

# Generate NBGEN new images with random, pairwise-distinct grid contents and register them in `db`.
# NOTE: not idempotent — every run of this cell adds NBGEN more entries to `db`.

# Seeded generator so a fresh kernel draws the same sequence of contents.
# (Image ids come from generate_uuid(), which this seed presumably does not control.)
GEN_SEED = 42
gen_rng = np.random.default_rng(GEN_SEED)


def _content_key(content):
    """Return a hashable, representation-independent key for an image's content list.

    Shapes/colors are normalised to ``str`` and positions to tuples, so contents built in
    this session compare equal to contents loaded back from disk (where ``pos`` may come
    back as a list and strings as plain ``str`` — TODO confirm against load_db's format).
    """
    return tuple((str(item["shape"]), tuple(item["pos"]), str(item["color"])) for item in content)


def _random_content(rng):
    """Draw the content of one image.

    Every grid cell (i, j) independently receives, with probability SHAPE_PROB, a shape
    drawn uniformly from SHAPES and a color drawn uniformly from COLORS.
    """
    content = []
    for i in range(X_DIVISIONS):
        for j in range(Y_DIVISIONS):
            if rng.random() < SHAPE_PROB:
                content.append({
                    # str() so the DB stores plain Python strings rather than numpy.str_.
                    "shape": str(rng.choice(SHAPES)),
                    "pos": (i, j),
                    "color": str(rng.choice(COLORS)),
                })
    return content


# Start from what is already in the DB, so a new image never duplicates an entry
# that load_db() returned (previously only duplicates within this run were avoided).
seen_contents = {_content_key(entry["content"]) for entry in db.values()}
to_generate = NBGEN
duplicate_count = 0
while to_generate > 0:
    content = _random_content(gen_rng)
    key = _content_key(content)
    if key in seen_contents:
        duplicate_count += 1
        continue

    imgid = generate_uuid()
    db[imgid] = {
        "path": os.path.join("img", imgid + ".png"),
        "division": (X_DIVISIONS, Y_DIVISIONS),
        "size": img_size,
        "content": content,
    }
    seen_contents.add(key)
    to_generate -= 1

# len(db) counts every entry, including any that load_db() returned.
print("unique generated in DB : " + str(len(db)))
print("duplicates avoided : " + str(duplicate_count))

unique generated in DB : 1000000
duplicates avoided : 0


In [6]:
import tqdm

# Sanity check: count DB entries whose content duplicates an earlier entry (expected: 0).
# Contents are compared through a normalised key (str shape/color, tuple pos) rather than
# str(content), whose text differs between numpy.str_/str values and tuple/list positions,
# so entries created in this session and entries loaded from disk are compared on equal terms.
content_keys_seen = set()
nb_duplicates = 0

for entry in tqdm.tqdm(db.values()):
    key = tuple((str(item["shape"]), tuple(item["pos"]), str(item["color"])) for item in entry["content"])
    if key in content_keys_seen:
        nb_duplicates += 1
    else:
        content_keys_seen.add(key)

print(nb_duplicates)

100%|██████████| 1000000/1000000 [00:38<00:00, 25679.41it/s]

0





In [7]:
from xaipatimg.datagen.genimg import gen_img_and_save_db

# Rendering every image is expensive (about 12 minutes for 1M images on the original run),
# so it is opt-in rather than hidden in a commented-out line. Set REGENERATE_IMAGES = True
# right after (re)generating the DB contents above: entries added to `db` in memory are
# presumably only persisted to db_dir by this call (per its name) — TODO confirm.
REGENERATE_IMAGES = False
if REGENERATE_IMAGES:
    gen_img_and_save_db(db, db_dir, overwrite=True, n_jobs=20)

100%|██████████| 1000000/1000000 [12:23<00:00, 1345.24it/s]


In [5]:
from xaipatimg.datagen.gendataset import create_dataset_generic_rule_extract_sample
import tqdm

# For every enabled rule, build its train/test/valid CSV splits from the DB in db_dir and
# a per-class sample folder under "<db_dir>/datasets/<name>_train".
for rule in tqdm.tqdm(rules_data):
    name = rule["name"]
    split_csv = {split: f"{name}_{split}.csv" for split in ("train", "test", "valid")}
    create_dataset_generic_rule_extract_sample(
        db_dir,
        csv_name_train=split_csv["train"],
        csv_name_test=split_csv["test"],
        csv_name_valid=split_csv["valid"],
        test_size=test_datasets_sizes,
        valid_size=valid_datasets_sizes,
        dataset_pos_samples_nb=full_datasets_pos_samples_nb,
        dataset_neg_samples_nb=full_datasets_neg_samples_nb,
        sample_path=os.path.join(db_dir, "datasets", f"{name}_train"),
        sample_nb_per_class=sample_nb_per_class,
        generic_rule_fun=rule["gen_fun"],
        # Rule-specific parameters (color, shape, N, divisions, ...) are forwarded as-is.
        **rule["gen_kwargs"],
    )

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/1000000 [00:00<?, ?it/s][A
  0%|          | 4499/1000000 [00:00<00:22, 44981.25it/s][A
  1%|          | 8998/1000000 [00:00<00:22, 44836.09it/s][A
  1%|▏         | 13498/1000000 [00:00<00:21, 44906.71it/s][A
  2%|▏         | 17989/1000000 [00:00<00:21, 44845.84it/s][A
  2%|▏         | 22474/1000000 [00:00<00:21, 44823.42it/s][A
  3%|▎         | 26957/1000000 [00:00<00:21, 44253.60it/s][A
  3%|▎         | 31441/1000000 [00:00<00:21, 44440.02it/s][A
  4%|▎         | 35938/1000000 [00:00<00:21, 44604.44it/s][A
  4%|▍         | 40442/1000000 [00:00<00:21, 44738.63it/s][A
  4%|▍         | 44930/1000000 [00:01<00:21, 44779.88it/s][A
  5%|▍         | 49409/1000000 [00:01<00:21, 44289.96it/s][A
  5%|▌         | 53840/1000000 [00:01<00:21, 44247.59it/s][A
  6%|▌         | 58266/1000000 [00:01<00:21, 44197.24it/s][A
  6%|▋         | 62706/1000000 [00:01<00:21, 44256.30it/s][A
  7%|▋         | 67133/1000000 [00:01<00:21, 4414

Total number of positive instances found in database : 144633
Total number of negative instances found in database : 855367



0it [00:00, ?it/s][A
25it [00:00, 247.44it/s][A
61it [00:00, 310.74it/s][A
93it [00:00, 271.45it/s][A
121it [00:00, 266.99it/s][A
157it [00:00, 297.81it/s][A
8001it [00:00, 10690.08it/s]
 20%|██        | 1/5 [01:06<04:25, 66.31s/it]
  0%|          | 0/1000000 [00:00<?, ?it/s][A
  0%|          | 1195/1000000 [00:00<01:23, 11938.11it/s][A
  0%|          | 2389/1000000 [00:00<01:24, 11874.29it/s][A
  0%|          | 3577/1000000 [00:00<01:25, 11687.46it/s][A
  0%|          | 4776/1000000 [00:00<01:24, 11801.77it/s][A
  1%|          | 5962/1000000 [00:00<01:24, 11819.05it/s][A
  1%|          | 7160/1000000 [00:00<01:23, 11871.61it/s][A
  1%|          | 8352/1000000 [00:00<01:23, 11886.29it/s][A
  1%|          | 9541/1000000 [00:00<01:24, 11738.10it/s][A
  1%|          | 10725/1000000 [00:00<01:24, 11769.44it/s][A
  1%|          | 11918/1000000 [00:01<01:23, 11817.06it/s][A
  1%|▏         | 13107/1000000 [00:01<01:23, 11838.96it/s][A
  1%|▏         | 14292/1000000 [00:01<0

Total number of positive instances found in database : 140854
Total number of negative instances found in database : 859146



0it [00:00, ?it/s][A
25it [00:00, 239.61it/s][A
49it [00:00, 203.40it/s][A
75it [00:00, 222.35it/s][A
100it [00:00, 231.87it/s][A
130it [00:00, 250.00it/s][A
156it [00:00, 248.86it/s][A
8001it [00:00, 9477.96it/s]A
 40%|████      | 2/5 [03:17<05:13, 104.63s/it]
  0%|          | 0/1000000 [00:00<?, ?it/s][A
  0%|          | 4308/1000000 [00:00<00:23, 43071.42it/s][A
  1%|          | 8666/1000000 [00:00<00:22, 43366.75it/s][A
  1%|▏         | 13045/1000000 [00:00<00:22, 43558.78it/s][A
  2%|▏         | 17401/1000000 [00:00<00:22, 43541.41it/s][A
  2%|▏         | 21756/1000000 [00:00<00:22, 43463.01it/s][A
  3%|▎         | 26167/1000000 [00:00<00:22, 43681.52it/s][A
  3%|▎         | 30556/1000000 [00:00<00:22, 43748.23it/s][A
  3%|▎         | 34931/1000000 [00:00<00:22, 43643.11it/s][A
  4%|▍         | 39296/1000000 [00:00<00:22, 43575.79it/s][A
  4%|▍         | 43654/1000000 [00:01<00:22, 43123.30it/s][A
  5%|▍         | 48053/1000000 [00:01<00:21, 43384.32it/s][A
  5

Total number of positive instances found in database : 144170
Total number of negative instances found in database : 855830



0it [00:00, ?it/s][A
17it [00:00, 164.71it/s][A
43it [00:00, 217.53it/s][A
65it [00:00, 209.35it/s][A
92it [00:00, 229.76it/s][A
123it [00:00, 257.53it/s][A
155it [00:00, 277.70it/s][A
8001it [00:00, 9830.36it/s]A
 60%|██████    | 3/5 [04:24<02:54, 87.42s/it] 
  0%|          | 0/1000000 [00:00<?, ?it/s][A
  0%|          | 1186/1000000 [00:00<01:24, 11857.86it/s][A
  0%|          | 2372/1000000 [00:00<01:24, 11839.83it/s][A
  0%|          | 3556/1000000 [00:00<01:25, 11608.14it/s][A
  0%|          | 4743/1000000 [00:00<01:25, 11707.60it/s][A
  1%|          | 5915/1000000 [00:00<01:25, 11662.67it/s][A
  1%|          | 7084/1000000 [00:00<01:25, 11669.76it/s][A
  1%|          | 8266/1000000 [00:00<01:24, 11717.96it/s][A
  1%|          | 9453/1000000 [00:00<01:24, 11765.19it/s][A
  1%|          | 10630/1000000 [00:00<01:24, 11757.58it/s][A
  1%|          | 11818/1000000 [00:01<01:23, 11794.61it/s][A
  1%|▏         | 12998/1000000 [00:01<01:23, 11780.21it/s][A
  1%|▏    

Total number of positive instances found in database : 152648
Total number of negative instances found in database : 847352



0it [00:00, ?it/s][A
20it [00:00, 196.92it/s][A
55it [00:00, 285.15it/s][A
84it [00:00, 261.53it/s][A
111it [00:00, 251.76it/s][A
137it [00:00, 241.30it/s][A
162it [00:00, 235.07it/s][A
8001it [00:00, 9909.75it/s]A
 80%|████████  | 4/5 [06:30<01:42, 102.74s/it]
  0%|          | 0/1000000 [00:00<?, ?it/s][A
  0%|          | 4356/1000000 [00:00<00:22, 43547.37it/s][A
  1%|          | 8748/1000000 [00:00<00:22, 43765.83it/s][A
  1%|▏         | 13125/1000000 [00:00<00:22, 43584.19it/s][A
  2%|▏         | 17506/1000000 [00:00<00:22, 43670.87it/s][A
  2%|▏         | 21924/1000000 [00:00<00:22, 43852.06it/s][A
  3%|▎         | 26314/1000000 [00:00<00:22, 43866.72it/s][A
  3%|▎         | 30734/1000000 [00:00<00:22, 43974.58it/s][A
  4%|▎         | 35137/1000000 [00:00<00:21, 43989.66it/s][A
  4%|▍         | 39555/1000000 [00:00<00:21, 44046.88it/s][A
  4%|▍         | 43960/1000000 [00:01<00:21, 43971.47it/s][A
  5%|▍         | 48394/1000000 [00:01<00:21, 44083.76it/s][A
  5

Total number of positive instances found in database : 143712
Total number of negative instances found in database : 856288



0it [00:00, ?it/s][A
32it [00:00, 308.45it/s][A
63it [00:00, 251.36it/s][A
89it [00:00, 247.43it/s][A
116it [00:00, 254.05it/s][A
147it [00:00, 272.11it/s][A
8001it [00:00, 10824.68it/s]
100%|██████████| 5/5 [07:32<00:00, 90.47s/it] 
