In [22]:
import os

# Root folder of the project data. DATA must be set in the environment
# (os.environ raises KeyError otherwise).
# os.path.join is used instead of string concatenation so the paths stay
# correct whether or not DATA ends with a path separator.
data_root = os.path.join(os.environ["DATA"], "PatImgXAI_data")

# Database directories. The trailing "" keeps the trailing separator that the
# previous string-built paths had.
db_S_dir = os.path.join(data_root, "db3.0.0_S", "")
db_L_dir = os.path.join(data_root, "db3.0.0_L", "")
db_patterns_dir = os.path.join(data_root, "db3.0.0", "patterns", "")
os.makedirs(db_S_dir, exist_ok=True)
os.makedirs(db_L_dir, exist_ok=True)

# Dataset sizes, forwarded by generate_all_datasets to the dataset builder as
# test_size, valid_size, dataset_pos_samples_nb, dataset_neg_samples_nb and
# sample_nb_per_class respectively.
test_datasets_sizes = 1000
valid_datasets_sizes = 1000
full_datasets_pos_samples_nb = 5000
full_datasets_neg_samples_nb = 5000
sample_nb_per_class = 100

In [23]:
# How many images to generate: per full-size database, and for the pattern database
NBGEN_full_per_size = 400_000
NBGEN_patterns = 100

# Grid divisions (x, y) of the full images, for the L and the S databases
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15
X_DIVISIONS_S, Y_DIVISIONS_S = 10, 10

# Grid divisions (x, y) of the pattern images
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# Image sizes in pixels: full images, then pattern images
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability of placing a geometrical shape at each grid position
SHAPE_PROB = 0.5

# Shapes and colours a symbol can take
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [24]:
import numpy as np
from xaipatimg.datagen.dbimg import generate_uuid
import os

def _content_key(content):
    """Return a hashable key identifying an image content list.

    The key does not depend on how the symbols are represented (tuple or list
    position, plain or NumPy string) nor on the order of the symbols, so a
    freshly generated content and the same content read back from storage
    compare equal.
    """
    return tuple(sorted(
        (tuple(symbol["pos"]), str(symbol["shape"]), str(symbol["color"]))
        for symbol in content
    ))


def generate_db(db, x_divisions, y_divisions, to_generate, img_size):
    """Add ``to_generate`` randomly generated, unique image descriptions to ``db``.

    For every cell of the ``x_divisions`` x ``y_divisions`` grid a symbol is
    drawn with probability ``SHAPE_PROB``; its shape and colour are picked from
    ``SHAPES`` and ``COLORS``. A drawn content is rejected when it equals a
    content already present in ``db`` or one generated earlier in this call.

    Parameters
    ----------
    db : dict
        Database to extend in place, mapping image id to its description.
        Every existing entry must have a "content" list whose items carry the
        "shape", "pos" and "color" keys.
    x_divisions, y_divisions : int
        Grid dimensions; stored under "division" in each new entry.
    to_generate : int
        Number of new unique entries to add.
    img_size : tuple
        Stored as is under "size" in each new entry.

    Returns
    -------
    dict
        The same ``db`` object, extended.

    Notes
    -----
    The loop only stops once ``to_generate`` unique contents have been drawn.
    Asking for more contents than the grid can produce never returns.
    """
    # Seed with what is already stored so that uniqueness holds over the whole
    # database, not only over the entries created by this call.
    seen_contents = {_content_key(entry["content"]) for entry in db.values()}

    duplicate_count = 0
    while to_generate > 0:
        content = []
        for i in range(x_divisions):
            for j in range(y_divisions):
                if np.random.random() < SHAPE_PROB:
                    content.append({
                        # np.random.choice returns a NumPy string scalar;
                        # store a plain str instead.
                        "shape": str(np.random.choice(SHAPES)),
                        "pos": (i, j),
                        "color": str(np.random.choice(COLORS)),
                    })

        key = _content_key(content)
        if key in seen_contents:
            duplicate_count += 1
            continue

        imgid = generate_uuid()
        db[imgid] = {
            "path": os.path.join("img", imgid + ".png"),
            "division": (x_divisions, y_divisions),
            "size": img_size,
            "content": content,
        }

        seen_contents.add(key)
        to_generate -= 1

    print("unique generated in DB : " + str(len(db)))
    print("duplicates avoided : " + str(duplicate_count))
    return db

In [25]:
import tqdm

def check_for_duplicates(db):
    """Print the number of entries of ``db`` whose content repeats an earlier entry.

    Two entries are considered identical when the string form of their
    "content" value is the same. Nothing is returned; the count is printed.
    """
    seen_signatures = set()
    nb_duplicates = 0

    for _, entry in tqdm.tqdm(db.items()):
        signature = str(entry["content"])
        if signature in seen_signatures:
            nb_duplicates += 1
            continue
        seen_signatures.add(signature)

    print(nb_duplicates)

In [26]:
from xaipatimg.datagen.dbimg import load_db


In [27]:
from xaipatimg.datagen.dbimg import load_db
#
# db_S = load_db(db_S_dir)
# db_L = load_db(db_L_dir)

In [28]:

# Load the pattern database from db_patterns_dir. Later cells read db_patterns
# to select pattern ids and to fetch each pattern's "content".
db_patterns = load_db(db_patterns_dir)


#### Generate full DB

In [29]:

# db_L = generate_db(db_L, X_DIVISIONS_L, Y_DIVISIONS_L, NBGEN_full_per_size, img_size)

In [30]:
# check_for_duplicates(db_L)

In [31]:
# check_for_duplicates(db_S)

In [32]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
# gen_img_and_save_db(db_L, db_L_dir, overwrite=True, n_jobs=20)
# db_L is left as None: the load / generate / save steps for the L database are
# commented out in this run.
# NOTE(review): presumably also meant to drop the reference to the DB once the
# images are saved - confirm.
db_L = None

In [33]:
# db_S = generate_db(db_S, X_DIVISIONS_S, Y_DIVISIONS_S, NBGEN_full_per_size, img_size)


In [34]:
# check_for_duplicates(db_S)

In [35]:
# gen_img_and_save_db(db_S, db_S_dir, overwrite=True, n_jobs=20)
# db_S is left as None: the load / generate / save steps for the S database are
# commented out in this run.
# NOTE(review): presumably also meant to drop the reference to the DB once the
# images are saved - confirm.
db_S = None

#### Generate DB of patterns

In [38]:
# db_patterns = generate_db(db_patterns, X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS, NBGEN_patterns, img_size_patterns)

In [39]:
# check_for_duplicates(db_patterns)

In [40]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
# gen_img_and_save_db(db_patterns, db_patterns_dir, overwrite=True, draw_coordinates=False, n_jobs=20)

## Interface prototype v5

In [41]:
# Folder of the prototype-v5 datasets of the L database; passed to
# generate_all_datasets as the datasets directory.
datasets_path_L = os.path.join(db_L_dir, "datasets", "01_protov5")

In [42]:
# Folder of the prototype-v5 datasets of the S database; passed to
# generate_all_datasets as the datasets directory.
datasets_path_S = os.path.join(db_S_dir, "datasets", "01_protov5")


In [43]:
# Ids of the patterns made of exactly two, respectively exactly three, symbols.
# The lists follow the iteration order of db_patterns; the rule definitions
# pick patterns from them by position, so that order matters.
pattern_2sym_keys = [
    pattern_id for pattern_id, entry in db_patterns.items()
    if len(entry["content"]) == 2
]
pattern_3sym_keys = [
    pattern_id for pattern_id, entry in db_patterns.items()
    if len(entry["content"]) == 3
]


In [44]:
from xaipatimg.datagen.gendataset import create_dataset_generic_rule_extract_sample
import tqdm

def generate_all_datasets(rules_data, db_dir, datasets_path):
    """Run the dataset builder once per rule description in ``rules_data``.

    Each rule is a dict providing "name", "pattern_id", "gen_fun",
    "gen_kwargs" and "filter_on_dim". The content of the pattern designated by
    "pattern_id" is looked up in the module-level ``db_patterns`` and stored
    under ``rule["gen_kwargs"]["pattern_content"]``. This modifies the
    caller's dict in place.

    The train / test / valid CSV names and the sample folder are derived from
    the rule name; sizes come from the module-level dataset size settings.
    """
    for rule in tqdm.tqdm(rules_data):
        dataset_name = rule["name"]
        train_sample_dir = os.path.join(datasets_path, f"{dataset_name}_train")

        # Inject the pattern content into the rule's own kwargs (in place).
        pattern_entry = db_patterns[rule["pattern_id"]]
        rule["gen_kwargs"]["pattern_content"] = pattern_entry["content"]

        create_dataset_generic_rule_extract_sample(
            db_dir,
            datasets_dir_path=datasets_path,
            csv_name_train=dataset_name + "_train.csv",
            csv_name_test=dataset_name + "_test.csv",
            csv_name_valid=dataset_name + "_valid.csv",
            test_size=test_datasets_sizes,
            valid_size=valid_datasets_sizes,
            dataset_pos_samples_nb=full_datasets_pos_samples_nb,
            dataset_neg_samples_nb=full_datasets_neg_samples_nb,
            sample_path=train_sample_dir,
            sample_nb_per_class=sample_nb_per_class,
            generic_rule_fun=rule["gen_fun"],
            filter_on_dim=rule["filter_on_dim"],
            **rule["gen_kwargs"]
        )

## DB L datasets generation

In [45]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_pattern_exactly_N_times

# Rule descriptions for the datasets built from the L database
# (X_DIVISIONS_L x Y_DIVISIONS_L grid). Entries kept as comments are disabled
# for this run; only the uncommented entry is generated.
rules_data_L = [
    # {"name": "hard1_2sym", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": dict(x_division_full=X_DIVISIONS_L, y_division_full=Y_DIVISIONS_L,
    #                     x_division_pattern=X_DIVISIONS_PATTERNS, y_division_pattern=Y_DIVISIONS_PATTERNS,
    #                     consider_rotations=False),
    #  "question": "Is the pattern in the image?", "target_acc": 0.85, "shown_acc": 0.85,
    #  "samples_interface": 10, "pos_llm_scaffold": "", "neg_llm_scaffold": "",
    #  "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L], "pattern_id": pattern_2sym_keys[3]},
    #
    # {"name": "hard2_3sym", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": dict(x_division_full=X_DIVISIONS_L, y_division_full=Y_DIVISIONS_L,
    #                     x_division_pattern=X_DIVISIONS_PATTERNS, y_division_pattern=Y_DIVISIONS_PATTERNS,
    #                     consider_rotations=False),
    #  "question": "Is the pattern in the image?", "target_acc": 0.85, "shown_acc": 0.85,
    #  "samples_interface": 10, "pos_llm_scaffold": "", "neg_llm_scaffold": "",
    #  "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L], "pattern_id": pattern_3sym_keys[1]},
    #
    # NOTE(review): the next disabled entry uses the key "w" where every other
    # entry has "neg_llm_scaffold" - probably a typo; check before re-enabling.
    # {"name": "hard3_2sym_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": dict(x_division_full=X_DIVISIONS_L, y_division_full=Y_DIVISIONS_L,
    #                     x_division_pattern=X_DIVISIONS_PATTERNS, y_division_pattern=Y_DIVISIONS_PATTERNS,
    #                     consider_rotations=True),
    #  "question": "Is the pattern or its left/right rotations in the image?", "target_acc": 0.85, "shown_acc": 0.85,
    #  "samples_interface": 10, "pos_llm_scaffold": "", "w": "",
    #  "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L], "pattern_id": pattern_2sym_keys[4]},

    {
        "name": "hard4_2sym_3times",
        "gen_fun": generic_rule_pattern_exactly_N_times,
        "gen_kwargs": dict(
            x_division_full=X_DIVISIONS_L,
            y_division_full=Y_DIVISIONS_L,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            N=3,
        ),
        "question": "Is the pattern exactly 3 times in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L],
        "pattern_id": pattern_2sym_keys[5],
    },
]

In [25]:
# Build the datasets of every active rule in rules_data_L from the L database.
# Long-running: the run recorded in this cell's output took about 21 min 40 s.
generate_all_datasets(rules_data_L, db_L_dir, datasets_path_L)

  0%|          | 0/1 [00:00<?, ?it/s]
 17%|█▋        | 69643/400000 [03:32<16:53, 325.91it/s][A
 17%|█▋        | 69676/400000 [03:32<16:52, 326.25it/s][A
 17%|█▋        | 69709/400000 [03:32<16:54, 325.68it/s][A
 17%|█▋        | 69742/400000 [03:33<16:55, 325.20it/s][A
 17%|█▋        | 69775/400000 [03:33<16:56, 324.82it/s][A
 17%|█▋        | 69808/400000 [03:33<16:59, 323.81it/s][A
 17%|█▋        | 69845/400000 [03:33<16:25, 334.94it/s][A
 17%|█▋        | 69879/400000 [03:33<16:38, 330.45it/s][A
 17%|█▋        | 69913/400000 [03:33<16:47, 327.61it/s][A
 17%|█▋        | 69950/400000 [03:33<16:14, 338.66it/s][A
 17%|█▋        | 69984/400000 [03:33<16:30, 333.08it/s][A
 18%|█▊        | 70018/400000 [03:33<16:41, 329.51it/s][A
 18%|█▊        | 70051/400000 [03:33<16:46, 327.97it/s][A
 18%|█▊        | 70084/400000 [03:34<16:55, 324.85it/s][A
 18%|█▊        | 70117/400000 [03:34<17:01, 322.83it/s][A
 18%|█▊        | 70153/400000 [03:34<16:29, 333.45it/s][A
 18%|█▊        | 7

Total number of positive instances found in database : 7847
Total number of negative instances found in database : 392153



0it [00:00, ?it/s][A
26it [00:00, 255.37it/s][A
52it [00:00, 248.26it/s][A
77it [00:00, 225.69it/s][A
100it [00:00, 226.52it/s][A
129it [00:00, 247.78it/s][A
154it [00:00, 236.32it/s][A
8001it [00:00, 9279.80it/s]A
100%|██████████| 1/1 [21:40<00:00, 1300.68s/it]


In [46]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_pattern_exactly_N_times

# Rule descriptions for the datasets built from the S database
# (X_DIVISIONS_S x Y_DIVISIONS_S grid). Entries kept as comments are disabled
# for this run; only the uncommented entry is generated.
rules_data_S = [
    # {"name": "easy1_2sym", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": dict(x_division_full=X_DIVISIONS_S, y_division_full=Y_DIVISIONS_S,
    #                     x_division_pattern=X_DIVISIONS_PATTERNS, y_division_pattern=Y_DIVISIONS_PATTERNS,
    #                     consider_rotations=False),
    #  "question": "Is the pattern in the image?", "target_acc": 0.85, "shown_acc": 0.85,
    #  "samples_interface": 10, "pos_llm_scaffold": "", "neg_llm_scaffold": "",
    #  "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S], "pattern_id": pattern_2sym_keys[0]},
    #
    # {"name": "easy2_3sym", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": dict(x_division_full=X_DIVISIONS_S, y_division_full=Y_DIVISIONS_S,
    #                     x_division_pattern=X_DIVISIONS_PATTERNS, y_division_pattern=Y_DIVISIONS_PATTERNS,
    #                     consider_rotations=False),
    #  "question": "Is the pattern in the image?", "target_acc": 0.85, "shown_acc": 0.85,
    #  "samples_interface": 10, "pos_llm_scaffold": "", "neg_llm_scaffold": "",
    #  "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S], "pattern_id": pattern_3sym_keys[0]},
    #
    # {"name": "easy3_2sym_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": dict(x_division_full=X_DIVISIONS_S, y_division_full=Y_DIVISIONS_S,
    #                     x_division_pattern=X_DIVISIONS_PATTERNS, y_division_pattern=Y_DIVISIONS_PATTERNS,
    #                     consider_rotations=True),
    #  "question": "Is the pattern or its left/right rotations in the image?", "target_acc": 0.85, "shown_acc": 0.85,
    #  "samples_interface": 10, "pos_llm_scaffold": "", "neg_llm_scaffold": "",
    #  "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S], "pattern_id": pattern_2sym_keys[1]},

    {
        "name": "easy4_2sym_2times",
        "gen_fun": generic_rule_pattern_exactly_N_times,
        "gen_kwargs": dict(
            x_division_full=X_DIVISIONS_S,
            y_division_full=Y_DIVISIONS_S,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            N=2,
        ),
        "question": "Is the pattern exactly 2 times in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S],
        "pattern_id": pattern_2sym_keys[2],
    },
]

In [47]:
# Build the datasets of every active rule in rules_data_S from the S database.
# Long-running: the run recorded in this cell's output took about 8 min 44 s.
generate_all_datasets(rules_data_S, db_S_dir, datasets_path_S)

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/400000 [00:00<?, ?it/s][A
  0%|          | 85/400000 [00:00<07:53, 843.89it/s][A
  0%|          | 170/400000 [00:00<08:20, 798.10it/s][A
  0%|          | 250/400000 [00:00<08:31, 781.19it/s][A
  0%|          | 329/400000 [00:00<08:42, 764.90it/s][A
  0%|          | 412/400000 [00:00<08:27, 786.84it/s][A
  0%|          | 493/400000 [00:00<08:23, 793.16it/s][A
  0%|          | 580/400000 [00:00<08:09, 815.90it/s][A
  0%|          | 668/400000 [00:00<07:59, 833.56it/s][A
  0%|          | 752/400000 [00:00<08:11, 812.91it/s][A
  0%|          | 839/400000 [00:01<08:02, 827.25it/s][A
  0%|          | 922/400000 [00:01<08:14, 807.36it/s][A
  0%|          | 1003/400000 [00:01<08:21, 795.54it/s][A
  0%|          | 1083/400000 [00:01<08:27, 786.52it/s][A
  0%|          | 1170/400000 [00:01<08:12, 810.37it/s][A
  0%|          | 1252/400000 [00:01<08:19, 797.97it/s][A
  0%|          | 1332/400000 [00:01<08:25, 788.58it/s][A


Total number of positive instances found in database : 9450
Total number of negative instances found in database : 390550



0it [00:00, ?it/s][A
58it [00:00, 578.37it/s][A
116it [00:00, 572.12it/s][A
8001it [00:00, 25712.05it/s]
100%|██████████| 1/1 [08:44<00:00, 524.67s/it]
