In [1]:
import os

# Root data directory, taken from the DATA environment variable
# (raises KeyError if it is not set).
data_root = os.environ["DATA"]

# Paths are built with os.path.join rather than string concatenation so they
# are correct whether or not DATA ends with a path separator. The final ""
# component keeps the trailing separator that the concatenated paths had.
db_dir = os.path.join(data_root, "PatImgXAI_data", "db3.0.0", "")
# The patterns DB lives in a sub-directory of the full DB directory.
db_patterns_dir = os.path.join(db_dir, "patterns", "")
os.makedirs(db_dir, exist_ok=True)

# Dataset sizes passed to the dataset creation call later in this notebook.
test_datasets_sizes = 100
valid_datasets_sizes = 100
full_datasets_pos_samples_nb = 1000
full_datasets_neg_samples_nb = 1000
sample_nb_per_class = 100

In [2]:
# --- How many images to generate ---
NBGEN_full_per_size = 250_000
NBGEN_patterns = 100

# --- Grid divisions of the full images (the _L and _S variants) ---
X_DIVISIONS_L = Y_DIVISIONS_L = 15
X_DIVISIONS_S = Y_DIVISIONS_S = 10

# --- Grid divisions of the pattern images ---
X_DIVISIONS_PATTERNS = Y_DIVISIONS_PATTERNS = 2

# --- Image sizes, in pixels ---
img_size = (700, 700)
img_size_patterns = (300, 300)

# Chance that a given grid position receives a geometrical shape.
SHAPE_PROB = 0.5

# Shapes and colors a generated symbol can take.
SHAPES = ['circle', 'square', 'triangle']
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [3]:
import numpy as np
from xaipatimg.datagen.dbimg import generate_uuid
import os

def generate_db(db, x_divisions, y_divisions, to_generate, img_size):
    """Add ``to_generate`` randomly drawn, unique image descriptions to ``db``.

    For every position of the ``x_divisions`` x ``y_divisions`` grid, a symbol
    is added with probability ``SHAPE_PROB``; its shape and color are drawn
    with ``np.random.choice`` from ``SHAPES`` and ``COLORS``. A content that
    is already known is discarded and a new one is drawn.

    Uniqueness is enforced against the contents drawn during this call AND
    against the entries already present in ``db``, so calling the function on
    a DB that was loaded or filled earlier cannot introduce duplicates.

    Args:
        db: dict mapping an image id to its entry; it is mutated in place.
            Every existing entry must have a ``"content"`` key.
        x_divisions: number of grid positions along the first ``pos`` index.
        y_divisions: number of grid positions along the second ``pos`` index.
        to_generate: number of new unique entries to add.
        img_size: value stored under the ``"size"`` key of each new entry.

    Returns:
        The same ``db`` object, with the new entries added.

    Note:
        The loop only ends once enough unique contents have been drawn, so it
        does not terminate if fewer than ``to_generate`` new distinct contents
        exist for the requested grid.
    """
    # Function-scope import so this definition does not depend on the import
    # lines of the cell it lives in.
    import json

    def content_key(content):
        # JSON text is the same for tuples and lists and for str and NumPy
        # string scalars, so a freshly drawn content and the same content read
        # back from a stored DB (where positions are lists) share one key.
        # Keying on str(content) would treat them as different.
        return json.dumps(content, sort_keys=True, default=str)

    # Start from what the DB already holds, not from an empty set.
    seen_keys = {content_key(entry["content"]) for entry in db.values()}

    duplicate_count = 0
    while to_generate > 0:
        content = []
        for i in range(x_divisions):
            for j in range(y_divisions):
                if np.random.random() < SHAPE_PROB:
                    content.append({
                        # np.random.choice returns a NumPy string scalar;
                        # store a plain str so entries hold built-in types only.
                        "shape": str(np.random.choice(SHAPES)),
                        "pos": (i, j),
                        "color": str(np.random.choice(COLORS)),
                    })

        key = content_key(content)
        if key in seen_keys:
            duplicate_count += 1
            continue

        imgid = generate_uuid()
        db[imgid] = {
            "path": os.path.join("img", imgid + ".png"),
            "division": (x_divisions, y_divisions),
            "size": img_size,
            "content": content
        }

        seen_keys.add(key)
        to_generate -= 1

    print("unique generated in DB : " + str(len(db)))
    print("duplicates avoided : " + str(duplicate_count))
    return db

In [4]:
import tqdm

def check_for_duplicates(db):
    """Print how many entries of ``db`` repeat the content of an earlier entry.

    Contents are compared through a canonical JSON key rather than ``str()``:
    the JSON text is the same for tuples and lists and for str and NumPy
    string scalars, so an entry read back from a stored DB and an identical
    entry generated in the current session are recognised as duplicates.

    Args:
        db: dict mapping an image id to an entry that has a ``"content"`` key.

    Returns:
        None. The number of duplicates is printed.
    """
    # Function-scope import so this definition does not depend on the import
    # lines of the cell it lives in.
    import json

    seen_keys = set()
    nb_duplicates = 0

    # Only the entries are needed, not their ids.
    for entry in tqdm.tqdm(db.values()):
        key = json.dumps(entry["content"], sort_keys=True, default=str)
        if key in seen_keys:
            nb_duplicates += 1
        else:
            seen_keys.add(key)

    print(nb_duplicates)

In [5]:
from xaipatimg.datagen.dbimg import load_db

# Load both DBs through the project's load_db helper. The rest of this
# notebook uses the results as dicts (``.items()``, lookup by key).
db = load_db(db_dir)
db_patterns = load_db(db_patterns_dir)

#### Generate full DB

In [6]:

# db = generate_db(db, X_DIVISIONS_L, Y_DIVISIONS_L, NBGEN_full_per_size, img_size)
# db = generate_db(db, X_DIVISIONS_S, Y_DIVISIONS_S, NBGEN_full_per_size, img_size)

In [7]:
# Sanity check on the full DB: prints the number of entries whose content
# repeats that of an earlier entry.
check_for_duplicates(db)

100%|██████████| 500000/500000 [00:45<00:00, 10871.40it/s]


0


In [8]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
# gen_img_and_save_db(db, db_dir, overwrite=True, n_jobs=20)

#### Generate DB of patterns

In [9]:
# db_patterns = generate_db(db_patterns, X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS, NBGEN_patterns, img_size_patterns)

In [10]:
# Sanity check on the patterns DB: prints the number of entries whose content
# repeats that of an earlier entry.
check_for_duplicates(db_patterns)

100%|██████████| 100/100 [00:00<00:00, 155690.57it/s]

0





In [11]:
from xaipatimg.datagen.genimg import gen_img_and_save_db
# gen_img_and_save_db(db_patterns, db_patterns_dir, overwrite=True, draw_coordinates=False, n_jobs=20)

## Interface prototype v5

In [12]:
# Directory of this prototype's datasets, located under the full DB directory.
datasets_path = os.path.join(db_dir, "datasets", "01_protov5")

In [13]:
# Keys of the patterns DB, bucketed by the number of symbols in the pattern.
pattern_2sym_keys = []
pattern_3sym_keys = []
keys_by_symbol_count = {2: pattern_2sym_keys, 3: pattern_3sym_keys}
for k, v in db_patterns.items():
    # Patterns with any other symbol count have no bucket and are skipped.
    bucket = keys_by_symbol_count.get(len(v["content"]))
    if bucket is not None:
        bucket.append(k)


In [14]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more

# One dict per rule. Only "easy3_2sym_rot" is active; the other rules are kept
# as commented-out entries so they can be re-enabled.
rules_data = [
    # {"name": "easy1_2sym", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": {"x_division_full": X_DIVISIONS_S, "y_division_full": Y_DIVISIONS_S,
    #                 "x_division_pattern": X_DIVISIONS_PATTERNS, "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                 "consider_rotations": False},
    #  "question": "Is the pattern in the image?", "target_acc" : 1.0, "shown_acc" : 1.0, "samples_interface": 10,
    #  "pos_llm_scaffold": "", "neg_llm_scaffold": "", "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S],
    #  "pattern_id": pattern_2sym_keys[0]},
    #
    # {"name": "easy2_3sym", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": {"x_division_full": X_DIVISIONS_S, "y_division_full": Y_DIVISIONS_S,
    #                 "x_division_pattern": X_DIVISIONS_PATTERNS, "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                 "consider_rotations": False},
    #  "question": "Is the pattern in the image?", "target_acc" : 1.0, "shown_acc" : 1.0, "samples_interface": 10,
    #  "pos_llm_scaffold": "", "neg_llm_scaffold": "", "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S],
    #  "pattern_id": pattern_3sym_keys[0]},

    {
        "name": "easy3_2sym_rot",
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_S,
            "y_division_full": Y_DIVISIONS_S,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or its left/right rotations in the image?",
        "target_acc": 1.0,
        "shown_acc": 1.0,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S],
        "pattern_id": pattern_2sym_keys[4],
    },

    # {"name": "hard1_2sym", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": {"x_division_full": X_DIVISIONS_L, "y_division_full": Y_DIVISIONS_L,
    #                 "x_division_pattern": X_DIVISIONS_PATTERNS, "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                 "consider_rotations": False},
    #  "question": "Is the pattern in the image?", "target_acc" : 1.0, "shown_acc" : 1.0, "samples_interface": 10,
    #  "pos_llm_scaffold": "", "neg_llm_scaffold": "", "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L],
    #  "pattern_id": pattern_2sym_keys[2]},
    #
    # {"name": "hard2_3sym", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": {"x_division_full": X_DIVISIONS_L, "y_division_full": Y_DIVISIONS_L,
    #                 "x_division_pattern": X_DIVISIONS_PATTERNS, "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                 "consider_rotations": False},
    #  "question": "Is the pattern in the image?", "target_acc" : 1.0, "shown_acc" : 1.0, "samples_interface": 10,
    #  "pos_llm_scaffold": "", "neg_llm_scaffold": "", "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L],
    #  "pattern_id": pattern_3sym_keys[1]},
    #
    # {"name": "hard3_2sym_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
    #  "gen_kwargs": {"x_division_full": X_DIVISIONS_L, "y_division_full": Y_DIVISIONS_L,
    #                 "x_division_pattern": X_DIVISIONS_PATTERNS, "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                 "consider_rotations": True},
    #  "question": "Is the pattern or its left/right rotations in the image?", "target_acc" : 1.0, "shown_acc" : 1.0,
    #  "samples_interface": 10, "pos_llm_scaffold": "", "neg_llm_scaffold": "",
    #  "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L], "pattern_id": pattern_2sym_keys[3]},
]

In [15]:
from xaipatimg.datagen.gendataset import create_dataset_generic_rule_extract_sample
import tqdm

# Call the dataset creation helper once per rule; the train/test/valid CSV
# names and the sample directory are all derived from the rule name.
for rule_line in tqdm.tqdm(rules_data):
    name = rule_line["name"]
    sample_path = os.path.join(datasets_path, f"{name}_train")

    # The content of the rule's pattern is stored in the rule's own kwargs
    # dict (so rules_data is mutated) and forwarded with the other gen_kwargs.
    rule_line["gen_kwargs"]["pattern_content"] = db_patterns[rule_line["pattern_id"]]["content"]

    create_dataset_generic_rule_extract_sample(
        db_dir,
        datasets_dir_path=datasets_path,
        csv_name_train=f"{name}_train.csv",
        csv_name_test=f"{name}_test.csv",
        csv_name_valid=f"{name}_valid.csv",
        test_size=test_datasets_sizes,
        valid_size=valid_datasets_sizes,
        dataset_pos_samples_nb=full_datasets_pos_samples_nb,
        dataset_neg_samples_nb=full_datasets_neg_samples_nb,
        sample_path=sample_path,
        sample_nb_per_class=sample_nb_per_class,
        generic_rule_fun=rule_line["gen_fun"],
        filter_on_dim=rule_line["filter_on_dim"],
        **rule_line["gen_kwargs"],
    )

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/500000 [00:00<?, ?it/s][A
  0%|          | 85/500000 [00:00<09:51, 845.71it/s][A
  0%|          | 170/500000 [00:00<09:59, 834.29it/s][A
  0%|          | 261/500000 [00:00<09:36, 866.78it/s][A
  0%|          | 357/500000 [00:00<09:15, 899.24it/s][A
  0%|          | 447/500000 [00:00<10:06, 823.25it/s][A
  0%|          | 531/500000 [00:00<10:23, 801.71it/s][A
  0%|          | 612/500000 [00:00<10:51, 766.17it/s][A
  0%|          | 693/500000 [00:00<10:42, 776.59it/s][A
  0%|          | 772/500000 [00:00<11:03, 752.97it/s][A
  0%|          | 848/500000 [00:01<11:29, 724.00it/s][A
  0%|          | 931/500000 [00:01<11:03, 752.46it/s][A
  0%|          | 1017/500000 [00:01<10:38, 782.09it/s][A
  0%|          | 1104/500000 [00:01<10:18, 806.89it/s][A
  0%|          | 1188/500000 [00:01<10:12, 814.35it/s][A
  0%|          | 1270/500000 [00:01<10:33, 786.69it/s][A
  0%|          | 1350/500000 [00:01<11:50, 701.87it/s][A


Total number of positive instances found in database : 80040
Total number of negative instances found in database : 125158



1801it [00:00, 23641.83it/s]
100%|██████████| 1/1 [10:18<00:00, 618.35s/it]


In [16]:
# Inspect the second pattern (index 1) among those with exactly two symbols.
db_patterns[pattern_2sym_keys[1]]

{'path': 'img/0ce79c12b31211f092e2d4d8534cb0f8.png',
 'division': [2, 2],
 'size': [300, 300],
 'content': [{'shape': 'square', 'pos': [0, 0], 'color': '#E0B000'},
  {'shape': 'square', 'pos': [0, 1], 'color': '#E0B000'}]}

In [17]:
 # Inspect the first entry of the patterns DB, in dict iteration order.
 list(db_patterns.values())[0]

{'path': 'img/0ce78dc6b31211f092e2d4d8534cb0f8.png',
 'division': [2, 2],
 'size': [300, 300],
 'content': [{'shape': 'triangle', 'pos': [0, 1], 'color': '#A33E9A'},
  {'shape': 'square', 'pos': [1, 0], 'color': '#A33E9A'}]}

In [18]:
# b1, b2 = generic_rule_pattern_exactly_1_time_exclude_more(db["becf1daab30e11f092e2d4d8534cb0f8"],
#                                                  db_patterns["0ce78dc6b31211f092e2d4d8534cb0f8"],
#                                                  10, 10, 2, 2)

In [19]:
# pattern_2sym_keys holds DB keys (strings), so the entry has to be looked up
# in the dict itself; a list of the values can only be indexed by integers.
db_patterns[pattern_2sym_keys[1]]