In [1]:
import os

# Root of all data and models, read from the DATA environment variable (KeyError if it is unset).
# os.path.join is used instead of "+" so a DATA value without a trailing slash still yields
# valid paths; the trailing "/" on each sub-path is kept so the resulting strings are unchanged
# when DATA already ends with "/".
DATA_ROOT = os.environ["DATA"]

# Image databases, one per grid size (S, L, M), plus the pattern database.
db_S_dir = os.path.join(DATA_ROOT, "PatImgXAI_data/db3.0.0/S/")
db_L_dir = os.path.join(DATA_ROOT, "PatImgXAI_data/db3.0.0/L/")
db_M_dir = os.path.join(DATA_ROOT, "PatImgXAI_data/db3.0.0/M/")
db_patterns_dir = os.path.join(DATA_ROOT, "PatImgXAI_data/db3.0.0/patterns/")

# Trained models are written to one sub-directory per rule below this root.
model_dir_root = os.path.join(DATA_ROOT, "models/db3.0.0/01_protov5/")

# Devices used concurrently for training (one joblib job per device).
devices = ["cuda:0", "cuda:1"]

# Default interval_batch passed to training (presumably batches between validation checks — confirm).
INTERVAL_BATCH = 100

In [2]:
# --- Number of images generated ---
NBGEN_full_per_size = 200_000
NBGEN_patterns = 1_000

# --- Grid divisions (X, Y) of the full images, one setting per database size ---
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15
X_DIVISIONS_S, Y_DIVISIONS_S = 10, 10
X_DIVISIONS_M, Y_DIVISIONS_M = 12, 12

# --- Grid divisions (X, Y) of the pattern images ---
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# --- Image sizes in pixels ---
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability of generating a geometrical shape at each grid position
SHAPE_PROB = 0.5

# --- Available shapes and colors ---
SHAPES = ["circle", "square", "cross"]
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [3]:
from xaipatimg.datagen.dbimg import load_db

# Load the pattern database. Below it is iterated as a mapping from pattern key to an entry whose
# "content" is a list of symbols, each carrying "shape" and "color" fields.
db_patterns = load_db(db_patterns_dir)

In [4]:
# Keys of the patterns made of exactly 3 symbols that use 3 distinct shapes and exactly 2 distinct
# colors. Entries of this list are picked by index as "pattern_id" in the rule definitions below.
pattern_3sym_2col_keys = []
for pattern_key, pattern in db_patterns.items():
    symbols = pattern["content"]
    if len(symbols) != 3:
        continue
    distinct_colors = {symbol["color"] for symbol in symbols}
    distinct_shapes = {symbol["shape"] for symbol in symbols}
    if len(distinct_colors) == 2 and len(distinct_shapes) == 3:
        pattern_3sym_2col_keys.append(pattern_key)

In [5]:
# Each database keeps the CSV splits of this experiment under <db_dir>/datasets/01_protov5.
_DATASETS_SUBPATH = ("datasets", "01_protov5")
datasets_path_L = os.path.join(db_L_dir, *_DATASETS_SUBPATH)
datasets_path_S = os.path.join(db_S_dir, *_DATASETS_SUBPATH)
datasets_path_M = os.path.join(db_M_dir, *_DATASETS_SUBPATH)

In [6]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_shape_color_even

# Rules for the L database (X_DIVISIONS_L x Y_DIVISIONS_L grid). In train(), "name" selects the
# <name>_train/_valid/_test.csv files and the model sub-directory, and "target_acc" is passed as
# the training target accuracy.
# NOTE: a "hard1_find_pattern_rot" rule (same settings as hard3_find_pattern_rot below, but with
# pattern_id=pattern_3sym_2col_keys[0]) is currently disabled.
rules_data_L = [
    dict(
        name="hard2_blue_circle_even",
        gen_fun=generic_rule_shape_color_even,
        gen_kwargs=dict(
            x_division=X_DIVISIONS_L,
            y_division=Y_DIVISIONS_L,
            shape="circle",
            color="#0C90C0",
        ),
        question="Is the number of blue circles an even number?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
    ),
    dict(
        name="hard3_find_pattern_rot",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_L,
            y_division_full=Y_DIVISIONS_L,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=True,
        ),
        question="Is the pattern or any of its left or right rotations in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[1],
    ),
    dict(
        name="hard4_purple_square_even",
        gen_fun=generic_rule_shape_color_even,
        gen_kwargs=dict(
            x_division=X_DIVISIONS_L,
            y_division=Y_DIVISIONS_L,
            shape="square",
            color="#A33E9A",
        ),
        question="Is the number of purple squares an even number?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
    ),
]

In [7]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_shape_color_even

# Rules for the S database (X_DIVISIONS_S x Y_DIVISIONS_S grid). In train(), "name" selects the
# <name>_train/_valid/_test.csv files and the model sub-directory, and "target_acc" is passed as
# the training target accuracy.
rules_data_S = [
    dict(
        name="easy1_find_pattern_rot",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_S,
            y_division_full=Y_DIVISIONS_S,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=True,
        ),
        question="Is the pattern or any of its left or right rotations in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[2],
    ),
    dict(
        name="easy2_yellow_cross_even",
        gen_fun=generic_rule_shape_color_even,
        gen_kwargs=dict(
            x_division=X_DIVISIONS_S,
            y_division=Y_DIVISIONS_S,
            shape="cross",
            color="#E0B000",
        ),
        question="Is the number of yellow crosses an even number?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
    ),
    dict(
        name="easy3_find_pattern_rot",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_S,
            y_division_full=Y_DIVISIONS_S,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=True,
        ),
        question="Is the pattern or any of its left or right rotations in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[3],
    ),
    dict(
        name="easy4_purple_circle_even",
        gen_fun=generic_rule_shape_color_even,
        gen_kwargs=dict(
            x_division=X_DIVISIONS_S,
            y_division=Y_DIVISIONS_S,
            shape="circle",
            color="#A33E9A",
        ),
        question="Is the number of purple circles an even number?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
    ),
]

In [8]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_shape_color_even

# Rules for the M database. In train(), "name" selects the <name>_train/_valid/_test.csv files and
# the model sub-directory, and "target_acc" is passed as the training target accuracy.
# Fix: these rules previously used X_DIVISIONS_L / Y_DIVISIONS_L (15x15), a copy-paste from the
# L rules; the M database is configured with X_DIVISIONS_M / Y_DIVISIONS_M (12x12).
# NOTE(review): any M datasets generated from the old 15x15 gen_kwargs should be regenerated.
rules_data_M = [

    {"name": "med1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
     "gen_kwargs": {"x_division_full": X_DIVISIONS_M,
                    "y_division_full": Y_DIVISIONS_M,
                    "x_division_pattern": X_DIVISIONS_PATTERNS,
                    "y_division_pattern": Y_DIVISIONS_PATTERNS,
                    "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[4]},

    {"name": "med2_yellow_square_even", "gen_fun": generic_rule_shape_color_even,
     "gen_kwargs": {"x_division": X_DIVISIONS_M,
                    "y_division": Y_DIVISIONS_M,
                    "shape": "square",
                    "color": "#E0B000"},
     "question": "Is the number of yellow squares an even number?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    {"name": "med3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
     "gen_kwargs": {"x_division_full": X_DIVISIONS_M,
                    "y_division_full": Y_DIVISIONS_M,
                    "x_division_pattern": X_DIVISIONS_PATTERNS,
                    "y_division_pattern": Y_DIVISIONS_PATTERNS,
                    "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[5]},

    {"name": "med4_blue_cross_even", "gen_fun": generic_rule_shape_color_even,
     "gen_kwargs": {"x_division": X_DIVISIONS_M,
                    "y_division": Y_DIVISIONS_M,
                    "shape": "cross",
                    "color": "#0C90C0"},
     "question": "Is the number of blue crosses an even number?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
]

In [9]:
from xaipatimg.ml.learning import train_resnet18_model, compute_resnet18_model_scores


def _train_model(db_dir, datasets_dir_path, train_dataset_filename, valid_dataset_filename, test_dataset_filename, model_dir, target_accuracy, interval_batch, device):
    """Train a ResNet-18 for one rule, then compute its scores.

    Args:
        db_dir: image database directory.
        datasets_dir_path: directory holding the dataset CSV files.
        train_dataset_filename / valid_dataset_filename / test_dataset_filename: split file names.
        model_dir: directory the model is trained into and scored from.
        target_accuracy: forwarded to train_resnet18_model as ``target_accuracy``.
        interval_batch: forwarded to train_resnet18_model as ``interval_batch``.
        device: torch device string (e.g. "cuda:0") used for both phases.
    """
    data_location = (db_dir, datasets_dir_path)

    # Phase 1: training, using the train and valid splits.
    train_resnet18_model(
        *data_location,
        train_dataset_filename,
        valid_dataset_filename,
        model_dir,
        target_accuracy=target_accuracy,
        interval_batch=interval_batch,
        device=device,
    )

    # Phase 2: scoring. Splits are passed in (train, test, valid) order here —
    # NOTE(review): confirm this matches compute_resnet18_model_scores's signature.
    compute_resnet18_model_scores(
        *data_location,
        train_dataset_filename,
        test_dataset_filename,
        valid_dataset_filename,
        model_dir,
        device=device,
    )




In [10]:
from tqdm import tqdm
from joblib import Parallel, delayed

def train(rules_data, datasets_dir, db_dir):
    """Train and score one ResNet-18 model per rule, spreading the rules over ``devices``.

    Rules are processed in consecutive chunks of ``len(devices)`` rules. Within a chunk, the i-th
    rule runs on ``devices[i]`` as its own joblib job. The next chunk starts only after
    ``Parallel`` has returned for the current one.

    Args:
        rules_data: list of rule dicts; each needs at least "name" and "target_acc".
        datasets_dir: directory containing "<name>_train.csv", "<name>_valid.csv" and
            "<name>_test.csv" for every rule.
        db_dir: image database directory forwarded to training and scoring.

    Raises:
        ValueError: if ``devices`` is empty.
    """
    n_devices = len(devices)
    if n_devices == 0:
        raise ValueError("devices must contain at least one device")

    # The chunk size follows len(devices) instead of a hard-coded 2. This way every listed device
    # is used, and a single-device setup no longer indexes devices[1].
    for chunk_start in tqdm(range(0, len(rules_data), n_devices)):
        chunk = rules_data[chunk_start:chunk_start + n_devices]
        Parallel(n_jobs=n_devices)(
            delayed(_train_model)(
                db_dir,
                datasets_dir,
                rule["name"] + "_train.csv",  # train_dataset_filename
                rule["name"] + "_valid.csv",  # valid_dataset_filename
                rule["name"] + "_test.csv",  # test_dataset_filename
                os.path.join(model_dir_root, rule["name"]),  # model_dir
                rule["target_acc"],  # target_accuracy
                # interval_batch: INTERVAL_BATCH normally, a smaller interval (50) when the
                # target accuracy is 1.0 (the best possible performance).
                INTERVAL_BATCH if rule["target_acc"] < 1.0 else 50,
                device,
            )
            for rule, device in zip(chunk, devices)
        )

In [None]:
# Train every rule defined for the S database (10x10 grid).
train(rules_data=rules_data_S, datasets_dir=datasets_path_S, db_dir=db_S_dir)

  0%|          | 0/2 [00:00<?, ?it/s]

Loading dataset content for easy1_find_pattern_rot_train.csv
Loading dataset content for easy2_yellow_cross_even_train.csv


100%|██████████| 8000/8000 [01:07<00:00, 119.36it/s]
100%|██████████| 8000/8000 [01:07<00:00, 119.16it/s]


Train dataset statistics : [0.9257279634475708, 0.9236022233963013, 0.919021487236023] [0.17885220050811768, 0.16124631464481354, 0.18928486108779907]
Loading dataset content for easy2_yellow_cross_even_train.csv


  0%|          | 12/8000 [00:00<01:08, 115.96it/s]]

Train dataset statistics : [0.9257009625434875, 0.9227545261383057, 0.9181486964225769] [0.17811931669712067, 0.1626540720462799, 0.1905139833688736]
Loading dataset content for easy1_find_pattern_rot_train.csv


100%|██████████| 8000/8000 [01:06<00:00, 121.14it/s]
  1%|▏         | 13/1000 [00:00<00:08, 121.50it/s]s]

Loading dataset content for easy2_yellow_cross_even_valid.csv


100%|██████████| 8000/8000 [01:07<00:00, 119.02it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]

Loading dataset content for easy1_find_pattern_rot_valid.csv


100%|██████████| 1000/1000 [00:08<00:00, 122.35it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 86%|████████▌ | 857/1000 [00:07<00:01, 120.59it/s]

EPOCH 1:


100%|██████████| 1000/1000 [00:08<00:00, 120.56it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 2.2556 valid 8.7079
Accuracy cap hit at Step 100 : 0.5 >= 0.1
Training complete
Loading dataset content for easy2_yellow_cross_even_train.csv


  3%|▎         | 221/8000 [00:01<01:02, 124.61it/s]

LOSS train 2.6317 valid 0.7388


  3%|▎         | 247/8000 [00:02<01:02, 124.26it/s]

Accuracy cap hit at Step 100 : 0.506 >= 0.1


  0%|          | 0/8000 [00:00<?, ?it/s]123.62it/s]

Training complete
Loading dataset content for easy1_find_pattern_rot_train.csv


100%|██████████| 8000/8000 [01:05<00:00, 121.94it/s]
  1%|▏         | 13/1000 [00:00<00:08, 122.30it/s]s]

Loading dataset content for easy2_yellow_cross_even_test.csv


100%|██████████| 8000/8000 [01:06<00:00, 121.14it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]

Loading dataset content for easy1_find_pattern_rot_test.csv


100%|██████████| 1000/1000 [00:08<00:00, 122.62it/s]
 68%|██████▊   | 676/1000 [00:05<00:02, 123.48it/s]

Loading dataset content for easy2_yellow_cross_even_valid.csv


100%|██████████| 1000/1000 [00:08<00:00, 123.32it/s]
  1%|▏         | 13/1000 [00:00<00:07, 123.75it/s]]

Loading dataset content for easy1_find_pattern_rot_valid.csv


100%|██████████| 1000/1000 [00:08<00:00, 123.10it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 1000/1000 [00:08<00:00, 122.97it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'train': {'accuracy': 0.5, 'precision': 0.0, 'recall': 0.0, 'roc_auc': 0.501807625, 'confusion matrix': {'TN': 4000, 'FP': 4000, 'FN': 0, 'TP': 0}}, 'test': {'accuracy': 0.5, 'precision': 0.0, 'recall': 0.0, 'roc_auc': 0.516904, 'confusion matrix': {'TN': 500, 'FP': 500, 'FN': 0, 'TP': 0}}, 'valid': {'accuracy': 0.5, 'precision': 0.0, 'recall': 0.0, 'roc_auc': 0.495904, 'confusion matrix': {'TN': 500, 'FP': 500, 'FN': 0, 'TP': 0}}}


 50%|█████     | 1/2 [06:57<06:57, 417.41s/it]

{'train': {'accuracy': 0.498875, 'precision': 0.4994275537463427, 'recall': 0.9815, 'roc_auc': 0.53727515625, 'confusion matrix': {'TN': 65, 'FP': 74, 'FN': 3935, 'TP': 3926}}, 'test': {'accuracy': 0.506, 'precision': 0.503061224489796, 'recall': 0.986, 'roc_auc': 0.5552680000000001, 'confusion matrix': {'TN': 13, 'FP': 7, 'FN': 487, 'TP': 493}}, 'valid': {'accuracy': 0.506, 'precision': 0.5030549898167006, 'recall': 0.988, 'roc_auc': 0.54588, 'confusion matrix': {'TN': 12, 'FP': 6, 'FN': 488, 'TP': 494}}}
Loading dataset content for easy3_find_pattern_rot_train.csv
Loading dataset content for easy4_purple_circle_even_train.csv


 22%|██▏       | 1778/8000 [00:14<00:50, 123.88it/s]


In [None]:
# Train every rule defined for the L database (15x15 grid).
train(rules_data=rules_data_L, datasets_dir=datasets_path_L, db_dir=db_L_dir)

In [None]:
# Train every rule defined for the M database.
train(rules_data=rules_data_M, datasets_dir=datasets_path_M, db_dir=db_M_dir)