In [1]:
import os

# Dataset and model locations are rooted at the DATA environment variable
# (a KeyError is raised if it is not set). os.path.join is used instead of
# string concatenation so that the paths are correct whether or not DATA ends
# with a path separator; the trailing "/" of each path is preserved.
db_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data/db2.0.0/")
model_dir_root = os.path.join(os.environ["DATA"], "models/db2.0.0/")

# Devices used for training; the training cell assigns one device per parallel job.
devices = ["cuda:0", "cuda:1"]

# Forwarded to train_resnet18_model as target_accuracy
# (presumably the accuracy at which training stops — TODO confirm).
TARGET_ACCURACY = 0.8

# Forwarded to train_resnet18_model as interval_batch
# (NOTE(review): exact meaning is defined by the library — confirm).
INTERVAL_BATCH = 1

In [3]:
# How many images are generated
NBGEN = 1_000_000

# Every image is divided into a grid of X_DIVISIONS by Y_DIVISIONS cells
X_DIVISIONS, Y_DIVISIONS = 6, 6

# Image size, expressed in pixels
img_size = (700, 700)

# For each grid position, probability that a geometrical shape is generated
SHAPE_PROB = 0.5

# Shapes that can be drawn
SHAPES = ['circle', 'square', 'triangle']
# Colors that can be used: purple, yellow, blue (naming taken from the rules that use them)
COLORS = ["#A33E9A", "#E0B000", "#0C90C0"]

In [4]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_exist_row_with_only_color_and_col_with_only_shape

# Rules for which a model is trained. Each entry holds:
#   "name"       - rule identifier; the training cell derives the dataset CSV
#                  filenames and the model directory from it,
#   "gen_fun"    - the generator function implementing the rule,
#   "gen_kwargs" - keyword arguments associated with that generator.
# Entries that are commented out are currently disabled.
rules_data = [
    {
        "name": "easy_1_6_blue",
        "gen_fun": generic_rule_N_times_color_exactly,
        "gen_kwargs": {
            "color": "#0C90C0",
            "N": 6,
            "x_division": X_DIVISIONS,
            "y_division": Y_DIVISIONS,
        },
    },
    {
        "name": "easy_2_row_circle",
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {
            "shape": "circle",
            "y_division": Y_DIVISIONS,
        },
    },
    # {"name": "easy_3_7_purple", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#A33E9A", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}},
    # {"name": "easy_4_row_triangle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS}},
    # NOTE(review): the name below says "7" but N is 5 — confirm which is intended before re-enabling.
    # {"name": "easy_5_7_yellow", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#E0B000", "N": 5, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}},
    # {"name": "easy_6_row_square", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS}},
    #
    # {"name": "hard_1_blue_square_plus_circle_8", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#0C90C0", "shape1": "square", "shape2": "circle", "N": 8, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}},
    # {"name": "hard_2_row_purple_col_triangle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#A33E9A", "shape": "triangle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}},
    # {"name": "hard_3_yellow_circle_plus_triangle_9", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#E0B000", "shape1": "circle", "shape2": "triangle", "N": 9, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}},
    # {"name": "hard_4_row_yellow_col_circle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#E0B000", "shape": "circle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}},
    # {"name": "hard_5_purple_triangle_plus_square_7", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#A33E9A", "shape1": "triangle", "shape2": "square", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}},
    # {"name": "hard_6_row_blue_col_square", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#0C90C0", "shape": "square" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}},
]

In [5]:
from xaipatimg.ml.learning import train_resnet18_model, compute_resnet18_model_scores


def _train_model(db_dir, train_dataset_filename, valid_dataset_filename, test_dataset_filename, model_dir,
                 target_accuracy, interval_batch, device):
    """Run train_resnet18_model, then compute_resnet18_model_scores, for one rule.

    :param db_dir: database directory, forwarded to both calls.
    :param train_dataset_filename: training dataset CSV filename, forwarded to both calls.
    :param valid_dataset_filename: validation dataset CSV filename, forwarded to both calls.
    :param test_dataset_filename: test dataset CSV filename, forwarded to the scoring call only.
    :param model_dir: model directory, forwarded to both calls.
    :param target_accuracy: forwarded to the training call as ``target_accuracy``.
    :param interval_batch: forwarded to the training call as ``interval_batch``.
    :param device: forwarded to both calls as ``device``.
    :return: None.
    """
    training_options = {
        "target_accuracy": target_accuracy,
        "interval_batch": interval_batch,
        "device": device,
    }
    train_resnet18_model(db_dir, train_dataset_filename, valid_dataset_filename, model_dir, **training_options)

    # NOTE(review): the test filename is passed before the validation filename here,
    # unlike the training call above — confirm this matches the library signature.
    compute_resnet18_model_scores(
        db_dir,
        train_dataset_filename,
        test_dataset_filename,
        valid_dataset_filename,
        model_dir,
        device=device,
    )




In [None]:
from tqdm import tqdm
from joblib import Parallel, delayed

# Train the rules in batches of len(devices), one rule per device, the rules of
# a batch being trained in parallel. The batch size follows the `devices` list
# instead of being hard-coded, and the last batch may contain fewer rules than
# there are devices (zip stops at the shorter sequence), so a number of rules
# that is not a multiple of the number of devices no longer raises IndexError.
n_devices = len(devices)
for rule_idx in tqdm(range(0, len(rules_data), n_devices)):
    rules_batch = rules_data[rule_idx:rule_idx + n_devices]
    Parallel(n_jobs=n_devices)(
        delayed(_train_model)(
            db_dir,  # db_dir
            rule["name"] + "_train.csv",  # train_dataset_filename
            rule["name"] + "_valid.csv",  # valid_dataset_filename
            rule["name"] + "_test.csv",  # test_dataset_filename
            os.path.join(model_dir_root, rule["name"]),  # model_dir
            TARGET_ACCURACY,  # target_accuracy
            INTERVAL_BATCH,  # interval_batch
            device,  # device
        )
        for rule, device in zip(rules_batch, devices)
    )

  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset content for easy_1_6_blue_train.csv
Loading dataset content for easy_2_row_circle_train.csv


100%|██████████| 8000/8000 [01:03<00:00, 126.31it/s]
 95%|█████████▍| 7587/8000 [01:01<00:03, 126.40it/s]

Loading dataset content for easy_1_6_blue_valid.csv


100%|██████████| 8000/8000 [01:04<00:00, 123.89it/s]
  1%|          | 11/1000 [00:00<00:09, 105.17it/s]

Loading dataset content for easy_2_row_circle_valid.csv


100%|██████████| 1000/1000 [00:10<00:00, 97.78it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 74%|███████▍  | 740/1000 [00:07<00:02, 98.93it/s] 

EPOCH 1:


 92%|█████████▏| 915/1000 [00:09<00:00, 94.05it/s] 

LOSS train 0.7611 valid 8.3387
Accuracy cap NOT hit at Step 1 : 0.5 < 0.8


 97%|█████████▋| 974/1000 [00:11<00:00, 56.44it/s]

LOSS train 5.7109 valid 16.7289
Accuracy cap NOT hit at Step 2 : 0.5 < 0.8


100%|██████████| 1000/1000 [00:11<00:00, 84.77it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 10.1425 valid 14.2983
Accuracy cap NOT hit at Step 3 : 0.5 < 0.8
LOSS train 4.2308 valid 79.5724
Accuracy cap NOT hit at Step 4 : 0.5 < 0.8
LOSS train 0.7636 valid 8.5569
Accuracy cap NOT hit at Step 1 : 0.5 < 0.8
LOSS train 12.0602 valid 215.3828
Accuracy cap NOT hit at Step 5 : 0.5 < 0.8
LOSS train 4.5931 valid 5.9198
Accuracy cap NOT hit at Step 2 : 0.5 < 0.8
LOSS train 15.0057 valid 405.9255
Accuracy cap NOT hit at Step 6 : 0.5 < 0.8
LOSS train 2.3994 valid 90.2850
Accuracy cap NOT hit at Step 3 : 0.5 < 0.8
LOSS train 6.9014 valid 3850.1704
Accuracy cap NOT hit at Step 7 : 0.5 < 0.8
LOSS train 22.4140 valid 879.5223
Accuracy cap NOT hit at Step 4 : 0.5 < 0.8
LOSS train 11.0876 valid 8710.8361
Accuracy cap NOT hit at Step 8 : 0.5 < 0.8
LOSS train 23.6601 valid 19176.9807
Accuracy cap NOT hit at Step 5 : 0.5 < 0.8
LOSS train 16.5667 valid 11943.7954
Accuracy cap NOT hit at Step 9 : 0.5 < 0.8
LOSS train 15.5576 valid 55821.6498
Accuracy cap NOT hit at Step 6 : 0.5 