In [1]:
import os

# Root directory of the image database, located under the directory named by
# the DATA environment variable (a KeyError is raised if DATA is not set).
# os.path.join is used instead of string concatenation so that the path is
# correct whether or not DATA ends with a path separator; the final "" keeps
# the trailing separator that the concatenated form used to provide.
db_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data", "db2.0.0", "")
# Create the directory (and any missing parents); no error if it already exists.
os.makedirs(db_dir, exist_ok=True)

# Sizes forwarded as test_size / valid_size when the datasets are created.
test_datasets_sizes = 1_000
valid_datasets_sizes = 1_000

# Forwarded as dataset_pos_samples_nb / dataset_neg_samples_nb.
full_datasets_pos_samples_nb = 5_000
full_datasets_neg_samples_nb = 5_000

# Forwarded as sample_nb_per_class for the extracted sample.
sample_nb_per_class = 100

In [2]:
# How many distinct images the generation loop must produce.
NBGEN = 1_000_000

# Every image is split into a grid of X_DIVISIONS x Y_DIVISIONS cells.
X_DIVISIONS, Y_DIVISIONS = 6, 6

# Image size, in pixels.
img_size = (700, 700)

# Chance, for each grid cell, that a geometrical shape is placed in it.
SHAPE_PROB = 0.5

# Shapes that can be drawn, and the colors (hex codes) they can take.
SHAPES = ["circle", "square", "triangle"]
COLORS = ["#A33E9A", "#E0B000", "#0C90C0"]

In [3]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape

# One entry per dataset to build. "name" is the prefix of the generated CSV
# files and sample directory, "gen_fun" is the rule function handed to the
# dataset builder and "gen_kwargs" are the extra keyword arguments forwarded
# with it.
rules_data = [
    dict(
        name=rule_name,
        gen_fun=generic_rule_exist_row_with_only_shape,
        gen_kwargs=dict(shape=shape, y_division=Y_DIVISIONS),
    )
    for rule_name, shape in (
        ("easy_1_row_circles", "circle"),
        ("easy_1_row_triangles", "triangle"),
    )
]

In [4]:
from xaipatimg.datagen.dbimg import load_db

# Obtain the database object for db_dir.
# NOTE(review): load_db is defined outside this notebook; the cells below use
# its result as a mapping from image id to an image description (item
# assignment, .items(), len()) — confirm against load_db.
db = load_db(db_dir)

In [5]:
import json
import os

import numpy as np
from xaipatimg.datagen.dbimg import generate_uuid


def _content_key(content):
    """Return a canonical string identifying the content of one image.

    The key is the JSON serialisation of ``content`` with sorted keys. JSON
    writes tuples and lists the same way and writes ``str`` subclasses as plain
    strings, so two equal contents get the same key whatever container or
    string type they are stored with.
    """
    return json.dumps(content, sort_keys=True)


# Start from the contents already stored in the database, so that running this
# cell on a non-empty database cannot add an image that is already present.
known_content_keys = {_content_key(entry["content"]) for _, entry in db.items()}

# NOTE(review): no random seed is set in this cell, so every run draws
# different images.
to_generate = NBGEN
duplicate_count = 0
while to_generate > 0:
    # Visit every grid cell; each one receives a shape with probability
    # SHAPE_PROB. The numpy scalars returned by np.random.choice are cast to
    # str so that the stored records only hold built-in types.
    content = [
        {
            "shape": str(np.random.choice(SHAPES)),
            "pos": (i, j),
            "color": str(np.random.choice(COLORS)),
        }
        for i in range(X_DIVISIONS)
        for j in range(Y_DIVISIONS)
        if np.random.random() < SHAPE_PROB
    ]

    key = _content_key(content)
    if key in known_content_keys:
        # Same content as an existing image: draw again without counting it.
        duplicate_count += 1
        continue

    imgid = generate_uuid()
    db[imgid] = {
        "path": os.path.join("img", imgid + ".png"),
        "division": (X_DIVISIONS, Y_DIVISIONS),
        "size": img_size,
        "content": content,
    }

    known_content_keys.add(key)
    to_generate -= 1

print("unique generated in DB : " + str(len(db)))
print("duplicates avoided : " + str(duplicate_count))

unique generated in DB : 1000000
duplicates avoided : 0


In [6]:
import json

import tqdm

# Sanity check: count the database entries whose content is identical to the
# content of an entry seen earlier in the iteration.
seen_content_keys = set()
nb_duplicates = 0

for _imgid, entry in tqdm.tqdm(db.items()):
    # Canonical key: JSON writes tuples and lists the same way and str
    # subclasses as plain strings, so equal contents compare equal whatever
    # container or string type they are stored with (str() would not).
    key = json.dumps(entry["content"], sort_keys=True)
    if key in seen_content_keys:
        nb_duplicates += 1
    else:
        seen_content_keys.add(key)

print(nb_duplicates)

100%|██████████| 1000000/1000000 [00:38<00:00, 25679.41it/s]

0





In [7]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
# Options of the image generation call, named so they are easy to tune.
# NOTE(review): gen_img_and_save_db is defined outside this notebook; from its
# name it presumably renders the images described in db and saves the database
# under db_dir — confirm. The recorded run took about 12 minutes.
image_generation_options = {"overwrite": True, "n_jobs": 20}
gen_img_and_save_db(db, db_dir, **image_generation_options)

100%|██████████| 1000000/1000000 [12:23<00:00, 1345.24it/s]


In [8]:
from xaipatimg.datagen.gendataset import create_dataset_generic_rule_extract_sample
import tqdm
# Build one dataset per entry of rules_data. Every call uses the same split
# sizes and sample counts; only the file names, the sample directory, the rule
# function and its keyword arguments change from one rule to the next.
for rule_spec in tqdm.tqdm(rules_data):
    rule_name = rule_spec["name"]
    dataset_options = dict(
        csv_name_train=rule_name + "_train.csv",
        csv_name_test=rule_name + "_test.csv",
        csv_name_valid=rule_name + "_valid.csv",
        test_size=test_datasets_sizes,
        valid_size=valid_datasets_sizes,
        dataset_pos_samples_nb=full_datasets_pos_samples_nb,
        dataset_neg_samples_nb=full_datasets_neg_samples_nb,
        # The sample is written next to the CSV files, under <db_dir>/datasets.
        sample_path=os.path.join(db_dir, "datasets", f"{rule_name}_train"),
        sample_nb_per_class=sample_nb_per_class,
        generic_rule_fun=rule_spec["gen_fun"],
    )
    create_dataset_generic_rule_extract_sample(
        db_dir, **dataset_options, **rule_spec["gen_kwargs"]
    )

  0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/1000000 [00:00<?, ?it/s][A
  1%|          | 6100/1000000 [00:00<00:16, 60986.24it/s][A
  1%|          | 12199/1000000 [00:00<00:16, 60842.42it/s][A
  2%|▏         | 18284/1000000 [00:00<00:16, 60557.47it/s][A
  2%|▏         | 24368/1000000 [00:00<00:16, 60666.57it/s][A
  3%|▎         | 30454/1000000 [00:00<00:15, 60733.00it/s][A
  4%|▎         | 36577/1000000 [00:00<00:15, 60898.77it/s][A
  4%|▍         | 42752/1000000 [00:00<00:15, 61173.29it/s][A
  5%|▍         | 48880/1000000 [00:00<00:15, 61203.35it/s][A
  6%|▌         | 55001/1000000 [00:00<00:15, 60760.90it/s][A
  6%|▌         | 61078/1000000 [00:01<00:15, 60643.63it/s][A
  7%|▋         | 67178/1000000 [00:01<00:15, 60748.36it/s][A
  7%|▋         | 73307/1000000 [00:01<00:15, 60909.43it/s][A
  8%|▊         | 79399/1000000 [00:01<00:15, 60767.52it/s][A
  9%|▊         | 85486/1000000 [00:01<00:15, 60795.29it/s][A
  9%|▉         | 91638/1000000 [00:01<00:14, 610

Total number of positive instances found in database : 361872
Total number of negative instances found in database : 638128



8001it [00:00, 273147.48it/s]
 50%|█████     | 1/2 [00:52<00:52, 52.20s/it]
  0%|          | 0/1000000 [00:00<?, ?it/s][A
  1%|          | 6004/1000000 [00:00<00:16, 60027.89it/s][A
  1%|          | 12007/1000000 [00:00<00:16, 59916.07it/s][A
  2%|▏         | 18037/1000000 [00:00<00:16, 60087.04it/s][A
  2%|▏         | 24046/1000000 [00:00<00:16, 60033.06it/s][A
  3%|▎         | 30050/1000000 [00:00<00:16, 59896.27it/s][A
  4%|▎         | 36068/1000000 [00:00<00:16, 59989.30it/s][A
  4%|▍         | 42067/1000000 [00:00<00:16, 59416.02it/s][A
  5%|▍         | 48023/1000000 [00:00<00:16, 59459.38it/s][A
  5%|▌         | 53986/1000000 [00:00<00:15, 59511.66it/s][A
  6%|▌         | 59939/1000000 [00:01<00:15, 59513.22it/s][A
  7%|▋         | 65942/1000000 [00:01<00:15, 59669.77it/s][A
  7%|▋         | 71968/1000000 [00:01<00:15, 59845.94it/s][A
  8%|▊         | 77953/1000000 [00:01<00:15, 59717.49it/s][A
  8%|▊         | 83925/1000000 [00:01<00:15, 59543.78it/s][A
  9%|▉   

Total number of positive instances found in database : 361940
Total number of negative instances found in database : 638060



8001it [00:00, 267031.31it/s]
100%|██████████| 2/2 [01:46<00:00, 53.36s/it]
