In [14]:
import os

# Root directory of all data. Paths below are built by plain string concatenation,
# so the DATA environment variable must end with a path separator.
data_root = os.environ["DATA"]

# One image database per grid size (XS, S, M, L) plus the database of patterns.
db_root = data_root + "PatImgXAI_data/db3.2.0/"
db_S_dir = db_root + "S/"
db_L_dir = db_root + "L/"
db_M_dir = db_root + "M/"
db_XS_dir = db_root + "XS/"
db_patterns_dir = db_root + "patterns/"

# Make sure every database directory exists (the XS one included, so that it is
# handled consistently with the other sizes).
for _db_dir_to_create in (db_S_dir, db_L_dir, db_M_dir, db_XS_dir, db_patterns_dir):
    os.makedirs(_db_dir_to_create, exist_ok=True)

# Dataset sizes
test_datasets_sizes = 500
valid_datasets_sizes = 500
full_datasets_pos_samples_nb = 25000
full_datasets_neg_samples_nb = 25000
sample_nb_per_class = 100

N_JOBS = 20

# Training configuration. `devices` lists one device per parallel training slot;
# a device may appear several times to host several concurrent trainings.
model_dir_root = data_root + "models/db3.2.0/01_expv1/"
devices = ["cuda:1", "cuda:0", "cuda:1", "cuda:0"]
INTERVAL_BATCH = 2
RESNET_TYPE = "resnet18"

In [15]:
# How many images are generated: per grid size for full images, and for patterns
NBGEN_full_per_size = 5_000_000
NBGEN_patterns = 1_000

# Full images use square grids; one (X, Y) pair of division counts per size
X_DIVISIONS_L = Y_DIVISIONS_L = 15
X_DIVISIONS_S = Y_DIVISIONS_S = 8
X_DIVISIONS_M = Y_DIVISIONS_M = 11
X_DIVISIONS_XS = Y_DIVISIONS_XS = 6

# Patterns live on a 2x2 grid
X_DIVISIONS_PATTERNS = Y_DIVISIONS_PATTERNS = 2

# Image sizes in pixels, for full images and for pattern images
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability that a given grid cell receives a geometrical shape
SHAPE_PROB = 0.5

# Identifiers of the available shapes and colors
SHAPES = ["c", "s", "t"]
COLORS = ["p", "y", "b"]

In [16]:
from xaipatimg.datagen.dbimg import load_db

# Load the patterns database from disk. The cells below iterate it with `.items()`
# and read each value's "cnt" list, so it is used as a mapping from pattern key to
# pattern description.
db_patterns = load_db(db_patterns_dir)

In [17]:
import numpy as np
def _is_3sym_2col_pattern(pattern):
    """Tell whether a pattern qualifies for the experiment.

    A pattern qualifies when it contains exactly 3 symbols, of 3 different shapes
    and 2 different colors, and no diagonal of the 2x2 grid holds the same color
    twice (i.e. the two symbols sharing a color are not diagonal neighbours).
    """
    if len(pattern["cnt"]) != 3:
        return False

    colors_seen = set()
    shapes_seen = set()
    # Color of each cell of the 2x2 pattern grid; "" marks an empty cell
    color_grid = np.full((2, 2), "", dtype="U100")
    for symbol in pattern["cnt"]:
        colors_seen.add(symbol["col"])
        shapes_seen.add(symbol["shp"])
        color_grid[symbol["pos"][0]][symbol["pos"][1]] = symbol["col"]

    main_diagonal_differs = color_grid[0][0] != color_grid[1][1]
    anti_diagonal_differs = color_grid[0][1] != color_grid[1][0]

    return (len(colors_seen) == 2 and len(shapes_seen) == 3
            and main_diagonal_differs and anti_diagonal_differs)


# Keys of the qualifying patterns, in the iteration order of the patterns database
pattern_3sym_2col_keys = [
    pattern_key
    for pattern_key, pattern in db_patterns.items()
    if _is_3sym_2col_pattern(pattern)
]

In [18]:
# Every database stores the datasets of this experiment under the same sub-path
_DATASETS_SUBPATH = ("datasets", "01_expv1")

datasets_path_L = os.path.join(db_L_dir, *_DATASETS_SUBPATH)
datasets_path_S = os.path.join(db_S_dir, *_DATASETS_SUBPATH)
datasets_path_M = os.path.join(db_M_dir, *_DATASETS_SUBPATH)
datasets_path_XS = os.path.join(db_XS_dir, *_DATASETS_SUBPATH)


In [19]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# Rules for the large (L) grid. All of them share the same generator and settings
# and only differ by their name and by the pattern to find.
rules_data_L = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_L,
            "y_division_full": Y_DIVISIONS_L,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 13,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    for rule_name, pattern_idx in (
        ("hard1_find_pattern_rot", 0),
        ("hard3_find_pattern_rot", 1),
    )
]

In [20]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# Rules for the small (S) grid. All of them share the same generator and settings
# and only differ by their name and by the pattern to find.
rules_data_S = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_S,
            "y_division_full": Y_DIVISIONS_S,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 6,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    for rule_name, pattern_idx in (
        ("easy1_find_pattern_rot", 4),
        ("easy3_find_pattern_rot", 5),
    )
]

In [21]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
# Rules for the medium (M) grid. All of them share the same generator and settings
# and only differ by their name and by the pattern to find.
rules_data_M = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_M,
            "y_division_full": Y_DIVISIONS_M,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 13,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    for rule_name, pattern_idx in (
        ("med1_find_pattern_rot", 2),
        ("med3_find_pattern_rot", 3),
        ("med5_find_pattern_rot", 6),
    )
]


In [22]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# Rules for the extra-small (XS) grid. Same generator and settings layout as the
# other sizes, with a single rule.
# NOTE(review): target_acc / shown_acc are 0.5 here whereas every other rule set
# uses 0.85 - confirm this value is intended.
rules_data_XS = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": X_DIVISIONS_XS,
            "y_division_full": Y_DIVISIONS_XS,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": True,
        },
        "question": "Is the pattern or any of its left or right rotations in the image?",
        "target_acc": 0.5,
        "shown_acc": 0.5,
        "samples_interface": 6,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "pattern_id": pattern_3sym_2col_keys[pattern_idx],
    }
    for rule_name, pattern_idx in (
        ("xeasy1_find_pattern_rot", 7),
    )
]

In [23]:
from xaipatimg.ml.learning import train_resnet_model, compute_resnet_model_scores


def _train_model(db_dir, datasets_dir_path, train_dataset_filename, valid_dataset_filename, test_dataset_filename, model_dir, target_accuracy, interval_batch, device):
    """Train a ResNet model for one rule, then compute its scores.

    :param db_dir: root directory of the image database; images are read from its "min/" sub-directory
    :param datasets_dir_path: directory that contains the dataset CSV files
    :param train_dataset_filename: file name of the training dataset
    :param valid_dataset_filename: file name of the validation dataset (only used when computing the scores)
    :param test_dataset_filename: file name of the test dataset
    :param model_dir: directory of the model
    :param target_accuracy: forwarded to train_resnet_model as `target_accuracy`
    :param interval_batch: forwarded to train_resnet_model as `interval_batch`
    :param device: device on which training and scoring are run
    """
    images_dir = os.path.join(db_dir, "min/")

    # NOTE(review): the test dataset, not the validation one, is handed to the
    # training call - confirm against train_resnet_model's signature that this is intended.
    train_resnet_model(
        images_dir,
        datasets_dir_path,
        train_dataset_filename,
        test_dataset_filename,
        model_dir,
        target_accuracy=target_accuracy,
        interval_batch=interval_batch,
        device=device,
        training_epochs=50,
        resnet_type=RESNET_TYPE,
    )

    compute_resnet_model_scores(
        images_dir,
        datasets_dir_path,
        train_dataset_filename,
        test_dataset_filename,
        valid_dataset_filename,
        model_dir,
        device=device,
        resnet_type=RESNET_TYPE,
    )


In [24]:
from tqdm import tqdm
from joblib import Parallel, delayed

def train(rules_data, datasets_dir, db_dir):
    """Train one model per rule, running as many trainings in parallel as there are devices.

    Rules are processed in consecutive batches of `len(devices)` rules. Within a
    batch, the rule in slot i is trained on `devices[i]`; the last batch may be
    shorter than the others.

    :param rules_data: list of rule descriptions; the "name" and "target_acc" entries of each rule are read
    :param datasets_dir: directory that contains the "<name>_train.csv", "<name>_valid.csv" and "<name>_test.csv" files
    :param db_dir: root directory of the image database the rules were generated from
    :raises ValueError: if `devices` is empty
    """
    # The batch size follows the device list so that adding or removing a device
    # neither overruns `devices` nor leaves some of them unused.
    batch_size = len(devices)
    if batch_size == 0:
        raise ValueError("`devices` must contain at least one device")

    for batch_start in tqdm(range(0, len(rules_data), batch_size)):
        batch = rules_data[batch_start:batch_start + batch_size]

        Parallel(n_jobs=batch_size)(delayed(_train_model)(
            db_dir, # db_dir
            datasets_dir,
            rule["name"] + "_train.csv", # train_dataset_filename
            rule["name"] + "_valid.csv", # valid_dataset_filename
            rule["name"] + "_test.csv", # test_dataset_filename
            os.path.join(model_dir_root, rule["name"]), # model_dir
            rule["target_acc"], # target_accuracy
            INTERVAL_BATCH if rule["target_acc"] < 1.0 else 50, # interval_batch (higher if the target accuracy is the best possible performance)
            devices[slot]) for slot, rule in enumerate(batch))

In [None]:
# train(rules_data_S, datasets_path_S, db_S_dir)

In [None]:
# Train the models of the large-grid (L) rules. The equivalent calls for the S, M and
# XS rule sets are kept commented out in the neighbouring cells.
train(rules_data_L, datasets_path_L, db_L_dir)

In [None]:
# train(rules_data_M, datasets_path_M, db_M_dir)

In [25]:
# train(rules_data_XS, datasets_path_XS, db_XS_dir)

100%|██████████| 10/10 [00:08<00:00,  1.16it/s]
  3%|▎         | 13/500 [00:00<00:03, 123.34it/s]

Train dataset statistics : [0.9433884620666504, 0.9393646121025085, 0.9372832179069519] [0.16525255143642426, 0.15669815242290497, 0.17703242599964142]
Loading dataset content for xeasy1_find_pattern_rot_test.csv


100%|██████████| 500/500 [00:04<00:00, 122.60it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 0.6886 valid 4.6180
Accuracy cap hit at Step 2 : 0.5 >= 0.5
Training complete


Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 1/1 [00:37<00:00, 37.15s/it]

{'train': {'accuracy': 0.5, 'precision': 0.5, 'recall': 1.0, 'roc_auc': 0.613952, 'confusion matrix': {'TN': 0, 'FP': 0, 'FN': 500, 'TP': 500}}, 'test': {'accuracy': 0.5, 'precision': 0.5, 'recall': 1.0, 'roc_auc': 0.62136, 'confusion matrix': {'TN': 0, 'FP': 0, 'FN': 250, 'TP': 250}}, 'valid': {'accuracy': 0.5, 'precision': 0.5, 'recall': 1.0, 'roc_auc': 0.595128, 'confusion matrix': {'TN': 0, 'FP': 0, 'FN': 250, 'TP': 250}}}



