In [1]:
import os
import shutil

# Root directory of all data, read from the DATA environment variable (a KeyError is raised if it is not set).
data_root_dir = os.environ["DATA"]

# Directories of the image databases and of the patterns database.
# os.path.join is used instead of string concatenation so that the paths are valid whether or not DATA ends with a
# path separator. The trailing "/" of each directory is kept on purpose, as in the previous string-built paths.
db_S_dir = os.path.join(data_root_dir, "PatImgXAI_data/db3.2.0/S/")
db_L_dir = os.path.join(data_root_dir, "PatImgXAI_data/db3.2.0/L/")
db_M_dir = os.path.join(data_root_dir, "PatImgXAI_data/db3.2.0/M/")
db_XS_dir = os.path.join(data_root_dir, "PatImgXAI_data/db3.2.0/XS/")
db_patterns_dir = os.path.join(data_root_dir, "PatImgXAI_data/db3.2.0/patterns/")

# Make sure every directory exists (db_XS_dir included, so that all databases are handled the same way).
for _db_dir in (db_S_dir, db_L_dir, db_M_dir, db_XS_dir, db_patterns_dir):
    os.makedirs(_db_dir, exist_ok=True)

# Dataset sizing settings. They are not read by the cells reviewed here; the meanings given below are inferred from
# the variable names only — TODO confirm against the code that consumes them.

# Presumably the number of samples of the test and validation splits.
test_datasets_sizes, valid_datasets_sizes = 500, 500

# Presumably the number of positive and negative samples of a full dataset.
full_datasets_pos_samples_nb, full_datasets_neg_samples_nb = 7500, 7500

sample_nb_per_class = 100

# Presumably the number of parallel jobs.
N_JOBS = 20

In [2]:
# --- Number of images generated (full images per grid size, and pattern images) ---
NBGEN_full_per_size = 5000000
NBGEN_patterns = 1000

# --- Grid divisions (x, y) of the full images, one pair per database size ---
X_DIVISIONS_XS, Y_DIVISIONS_XS = 6, 6
X_DIVISIONS_S, Y_DIVISIONS_S = 8, 8
X_DIVISIONS_M, Y_DIVISIONS_M = 11, 11
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15

# --- Grid divisions (x, y) of the pattern images ---
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# --- Size of the images in pixels ---
img_size = (700, 700)
img_size_patterns = (300, 300)

# --- Probability to generate a geometrical shape at each position in the grid ---
SHAPE_PROB = 0.5

# --- Codes of the available shapes and colors ---
SHAPES = ["c", "s", "t"]
COLORS = ["p", "y", "b"]

In [3]:
from xaipatimg.datagen.dbimg import load_db
# Load the patterns database from db_patterns_dir. The returned object is used further down as a mapping: it is
# iterated with .items() and each value exposes a "cnt" collection of symbol entries.
db_patterns = load_db(db_patterns_dir)


In [4]:
import numpy as np
# Keys of the patterns that contain exactly 3 symbols, of 3 different shapes and 2 different colors, and in which the
# two symbols sharing a color are not located on a diagonal of the 2x2 pattern grid.
pattern_3sym_2col_keys = []

for pattern_key, pattern in db_patterns.items():
    symbols = pattern["cnt"]

    # Only patterns made of exactly 3 symbols are candidates.
    if len(symbols) != 3:
        continue

    distinct_colors = {symbol["col"] for symbol in symbols}
    distinct_shapes = {symbol["shp"] for symbol in symbols}

    # 2x2 matrix of the color found at each position; a position without symbol keeps the empty string.
    color_matrix = np.full((2, 2), "", dtype="U100")
    for symbol in symbols:
        color_matrix[symbol["pos"][0]][symbol["pos"][1]] = symbol["col"]

    same_color_on_diagonal = (
        color_matrix[0][0] == color_matrix[1][1]
        or color_matrix[0][1] == color_matrix[1][0]
    )

    if len(distinct_colors) == 2 and len(distinct_shapes) == 3 and not same_color_on_diagonal:
        pattern_3sym_2col_keys.append(pattern_key)

In [5]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# Rules of the L database. Every rule shares the same generator function, grid-division kwargs, question and
# accuracy settings; the rules only differ by their name and by the pattern they refer to, given below as
# (rule name, index of the pattern key in pattern_3sym_2col_keys).
rules_data_L = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_L,
            "y_division_full": Y_DIVISIONS_L,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 13,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    for rule_name, pattern_idx in (
        ("hard1_find_pattern_rot", 0),
        ("hard3_find_pattern_rot", 1),
    )
]

In [6]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# Rules of the S database. Every rule shares the same generator function, grid-division kwargs, question and
# accuracy settings; the rules only differ by their name and by the pattern they refer to, given below as
# (rule name, index of the pattern key in pattern_3sym_2col_keys).
rules_data_S = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_S,
            "y_division_full": Y_DIVISIONS_S,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 6,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    for rule_name, pattern_idx in (
        ("easy1_find_pattern_rot", 4),
        ("easy3_find_pattern_rot", 5),
    )
]

In [7]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
# Rules of the M database. Every rule shares the same generator function, grid-division kwargs, question and
# accuracy settings; the rules only differ by their name and by the pattern they refer to, given below as
# (rule name, index of the pattern key in pattern_3sym_2col_keys).
rules_data_M = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_M,
            "y_division_full": Y_DIVISIONS_M,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 13,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    for rule_name, pattern_idx in (
        ("med1_find_pattern_rot", 2),
        ("med3_find_pattern_rot", 3),
        ("med5_find_pattern_rot", 6),
    )
]


In [8]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# Rules of the XS database: a single rule, bound to the pattern key at index 7 of pattern_3sym_2col_keys.
rules_data_XS = [
    {
        "name": "xeasy1_find_pattern_rot",
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_XS,
            "y_division_full": Y_DIVISIONS_XS,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.5,
        "shown_acc": 0.5,
        "samples_interface": 6,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[7],
    },
]

In [9]:
# Directory of the "01_expv1" datasets inside each database directory.
_datasets_subpath = ("datasets", "01_expv1")

datasets_path_L = os.path.join(db_L_dir, *_datasets_subpath)
datasets_path_S = os.path.join(db_S_dir, *_datasets_subpath)
datasets_path_M = os.path.join(db_M_dir, *_datasets_subpath)
datasets_path_XS = os.path.join(db_XS_dir, *_datasets_subpath)


In [10]:
from xaipatimg.datagen.jsondb import JSONDB
from os import makedirs
import pandas as pd
import tqdm
from pathlib import Path



def _get_all_img_ids_from_db_csv(dataset_csv):
    ids_list = []
    for path in dataset_csv["path"]:
        ids_list.append(Path(path).stem)
    return ids_list

def reduce_db(db_dir, datasets_path, dataset_names):
    """
    Save a minimal version of the DB that only contains the entries needed by the given datasets.
    For each dataset name, the files "<name>_train.csv", "<name>_test.csv" and "<name>_valid.csv" of datasets_path
    are read and the image ids they reference are collected. Every entry of <db_dir>/db.json whose key is one of
    these ids is written to <db_dir>/min/db.json. The directory <db_dir>/min/img is created but this function does
    not copy any image file into it.
    :param db_dir: directory of the full DB (the one holding db.json)
    :param datasets_path: directory that contains the CSV files of the datasets
    :param dataset_names: names of the datasets to extract the data for
    :return: None
    """

    min_db_dir = os.path.join(db_dir, "min")
    # Creating min/img also creates the min directory itself when it is missing.
    os.makedirs(os.path.join(min_db_dir, "img"), exist_ok=True)

    # Ids of all the images referenced by at least one split of one of the datasets
    needed_img_ids = set()
    for dataset_name in dataset_names:
        for split_suffix in ("_train.csv", "_test.csv", "_valid.csv"):
            split_table = pd.read_csv(os.path.join(datasets_path, dataset_name + split_suffix))
            needed_img_ids.update(_get_all_img_ids_from_db_csv(split_table))

    full_db = JSONDB(os.path.join(db_dir, "db.json"))
    reduced_db = JSONDB(os.path.join(min_db_dir, "db.json"))

    # Scanning the whole DB and keeping only the entries that are needed
    for img_id, img_entry in tqdm.tqdm(full_db.items()):
        if img_id in needed_img_ids:
            reduced_db[img_id] = img_entry
            # NOTE(review): the copy of the matching image file is currently disabled:
            # shutil.copy(os.path.join(db_dir, "img", f"{k}.png"), os.path.join(min_db_dir, "img", f"{k}.png"))

    reduced_db.flush()


In [11]:
# Build the reduced S database from the datasets of all the rules listed in rules_data_S.
reduce_db(
    db_S_dir,
    datasets_path_S,
    [rule["name"] for rule in rules_data_S],
)

loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/S/db.json


2317500343it [11:36, 3325054.05it/s]


loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/S/min/db.json


5000000it [06:21, 13110.21it/s]
1830897it [00:00, 3177470.21it/s]


In [12]:
# Build the reduced L database from the datasets of all the rules listed in rules_data_L.
reduce_db(
    db_L_dir,
    datasets_path_L,
    [rule["name"] for rule in rules_data_L],
)

loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/L/db.json


6277437797it [32:44, 3196196.21it/s]


loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/L/min/db.json


5000000it [15:47, 5275.33it/s]
3583805it [00:01, 3201343.20it/s]


In [13]:
# Build the reduced M database from the datasets of all the rules listed in rules_data_M.
reduce_db(
    db_M_dir,
    datasets_path_M,
    [rule["name"] for rule in rules_data_M],
)

loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/M/db.json


4050080423it [20:47, 3247633.38it/s]


loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/M/min/db.json


5000000it [10:42, 7787.62it/s]
2374438it [00:00, 3201027.89it/s]


In [14]:
# Build the reduced XS database from the datasets of all the rules listed in rules_data_XS.
reduce_db(
    db_XS_dir,
    datasets_path_XS,
    [rule["name"] for rule in rules_data_XS],
)

loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/XS/db.json


23317995it [00:07, 3311532.00it/s]


loading keys for /home/docker/data/PatImgXAI_data/db3.2.0/XS/min/db.json


150000it [00:03, 44171.23it/s]
328041it [00:00, 2902983.11it/s]
