In [7]:
import os

# Root data directory, read from the DATA environment variable.
# A KeyError is raised if the variable is not set.
data_root = os.environ["DATA"]

# Build the paths with os.path.join so that they are valid whether or not DATA
# ends with a path separator (plain string concatenation silently produced
# ".../dataPatImgXAI_data/..." when the trailing slash was missing).
# The final "" component keeps the trailing separator of the original values.
db_dir = os.path.join(data_root, "PatImgXAI_data", "db2.0.0", "")
model_dir_root = os.path.join(data_root, "models", "db2.0.0", "")

# Devices available for training.
devices = ["cuda:0", "cuda:1"]

# Default value for the `interval_batch` training parameter.
INTERVAL_BATCH = 1

In [8]:
# How many images are generated in total
NBGEN = 10 ** 6

# Number of grid divisions along each axis of an image
X_DIVISIONS, Y_DIVISIONS = 6, 6

# Image dimensions, in pixels
img_size = (700,) * 2

# Chance that a geometrical shape is drawn at any given grid position
SHAPE_PROB = 1 / 2

# Shapes and colors (hex codes) that can be drawn
SHAPES = "circle square triangle".split()
COLORS = [
    "#A33E9A",
    "#E0B000",
    "#0C90C0",
]

In [9]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_exist_row_with_only_color_and_col_with_only_shape, \
    generic_rule_shape_in_every_row

# Description of every rule. Each entry holds:
#   - "name": identifier of the rule (also used as prefix of the dataset file names and as model directory name),
#   - "gen_fun" / "gen_kwargs": rule function and the keyword arguments associated with it,
#   - "question": question text associated with the rule,
#   - "target_acc": target accuracy for the model trained on the rule,
#   - "samples_interface": number of samples for the interface (NOTE(review): exact use not visible here).
rules_data = [
    # "disc" rules: a given shape must be present in every row (target accuracy 1.0)
    {"name": "disc_1_triangle_all", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS}, "question": "In the image, is there a triangle in every row (1, ..., 6)?", "target_acc" : 1.0, "samples_interface": 5},
    {"name": "disc_1_square_all", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS}, "question": "In the image, is there a square in every row (1, ..., 6)?", "target_acc" : 1.0, "samples_interface": 5},
    {"name": "disc_1_circle_all", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {"shape": "circle", "y_division": Y_DIVISIONS}, "question": "In the image, is there a circle in every row (1, ..., 6)?", "target_acc" : 1.0, "samples_interface": 5},

    # "easy" rules: exact count of a color, or existence of a row containing a single shape
    {"name": "easy_1_6_blue", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#0C90C0", "N": 6, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 6 blue symbols?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "easy_2_row_circle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "circle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only circles?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "easy_3_7_purple", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#A33E9A", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 7 purple symbols?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "easy_4_row_triangle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only triangles?", "target_acc": 0.8, "samples_interface": 10},
    # N was 5 here while both the rule name and the question refer to 7 yellow symbols; N is aligned on 7.
    # NOTE(review): if the dataset on disk was generated with N=5, fix the name and the question instead.
    {"name": "easy_5_7_yellow", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#E0B000", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 7 yellow symbols?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "easy_6_row_square", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only squares?", "target_acc": 0.8, "samples_interface": 10},

    # "hard" rules: sum of two counts equal to N, or a single-color row combined with a single-shape column
    {"name": "hard_1_blue_square_plus_circle_8", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#0C90C0", "shape1": "square", "shape2": "circle", "N": 8, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS,},
     "question": "In the image, does the number of blue squares plus (+) the number of circles equal to 8?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "hard_2_row_purple_col_triangle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#A33E9A", "shape": "triangle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only purple symbols, and one column (A, ..., F) containing only triangles?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "hard_3_yellow_circle_plus_triangle_9", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#E0B000", "shape1": "circle", "shape2": "triangle", "N": 9, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of yellow circles plus (+) the number of triangles equal to 9?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "hard_4_row_yellow_col_circle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#E0B000", "shape": "circle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only yellow symbols, and one column (A, ..., F) containing only circles?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "hard_5_purple_triangle_plus_square_7", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#A33E9A", "shape1": "triangle", "shape2": "square", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of purple triangles plus (+) the number of squares equal to 7?", "target_acc": 0.8, "samples_interface": 10},
    {"name": "hard_6_row_blue_col_square", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#0C90C0", "shape": "square" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only blue symbols, and one column (A, ..., F) containing only squares?", "target_acc": 0.8, "samples_interface": 10},
]

In [10]:
from xaipatimg.ml.learning import train_resnet18_model, compute_resnet18_model_scores


def _train_model(
    db_dir,
    train_dataset_filename,
    valid_dataset_filename,
    test_dataset_filename,
    model_dir,
    target_accuracy,
    interval_batch,
    device,
):
    """
    Train the model of a single rule, then compute its scores.

    Calls train_resnet18_model followed by compute_resnet18_model_scores with the same database directory,
    model directory and device.

    :param db_dir: path of the database directory
    :param train_dataset_filename: file name of the training dataset
    :param valid_dataset_filename: file name of the validation dataset
    :param test_dataset_filename: file name of the test dataset (only used when computing the scores)
    :param model_dir: directory associated with the model
    :param target_accuracy: forwarded to the training as `target_accuracy`
    :param interval_batch: forwarded to the training as `interval_batch`
    :param device: device on which both the training and the scoring are run
    :return: None
    """
    training_options = dict(target_accuracy=target_accuracy, interval_batch=interval_batch, device=device)
    train_resnet18_model(db_dir, train_dataset_filename, valid_dataset_filename, model_dir, **training_options)

    # The scoring function receives the datasets in train / test / valid order, which differs from
    # the parameter order of this wrapper (NOTE(review): confirm against its signature).
    compute_resnet18_model_scores(
        db_dir,
        train_dataset_filename,
        test_dataset_filename,
        valid_dataset_filename,
        model_dir,
        device=device,
    )




In [11]:
from tqdm import tqdm
from joblib import Parallel, delayed

if not devices:
    raise ValueError("`devices` must contain at least one device to train the models on.")

# Number of rules trained at the same time: one per available device.
# (The step and the offsets were previously hard-coded for exactly two devices.)
n_parallel = len(devices)

# interval_batch used for rules whose target accuracy is the best possible performance (1.0)
PERFECT_TARGET_INTERVAL_BATCH = 50

for rule_idx in tqdm(range(0, len(rules_data), n_parallel)):

    # Rules trained during this round; the last round may contain fewer rules than devices.
    round_rules = rules_data[rule_idx:rule_idx + n_parallel]

    # zip() pairs each rule of the round with its own device and stops at the shorter sequence.
    Parallel(n_jobs=n_parallel)(delayed(_train_model)(
        db_dir, # db_dir
        rule["name"] + "_train.csv", # train_dataset_filename
        rule["name"] + "_valid.csv", # valid_dataset_filename
        rule["name"] + "_test.csv", # test_dataset_filename
        os.path.join(model_dir_root, rule["name"]), # model_dir
        rule["target_acc"], # target_accuracy,
        INTERVAL_BATCH if rule["target_acc"] < 1.0 else PERFECT_TARGET_INTERVAL_BATCH, # interval_batch (higher if the target accuracy is the best possible performance)
        device) for rule, device in zip(round_rules, devices))

  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset content for disc_1_circle_all_train.csv


100%|██████████| 8000/8000 [02:18<00:00, 57.91it/s]


Train dataset statistics : [0.9429671168327332, 0.9412433505058289, 0.9376640915870667] [0.16828131675720215, 0.1518343985080719, 0.17763590812683105]
Loading dataset content for disc_1_circle_all_train.csv


100%|██████████| 8000/8000 [01:04<00:00, 124.29it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]

Loading dataset content for disc_1_circle_all_valid.csv


100%|██████████| 1000/1000 [00:13<00:00, 74.29it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 2.3982 valid 0.9306
Accuracy cap NOT hit at Step 50 : 0.5 < 1.0
LOSS train 0.7090 valid 0.6971
Accuracy cap NOT hit at Step 100 : 0.5 < 1.0
LOSS train 0.7049 valid 0.7165
Accuracy cap NOT hit at Step 150 : 0.5 < 1.0
LOSS train 0.6973 valid 0.6984
Accuracy cap NOT hit at Step 200 : 0.504 < 1.0
LOSS train 0.6962 valid 0.6901
Accuracy cap NOT hit at Step 250 : 0.532 < 1.0
EPOCH 2:
LOSS train 0.6892 valid 0.6821
Accuracy cap NOT hit at Step 300 : 0.533 < 1.0
LOSS train 0.6561 valid 0.8925
Accuracy cap NOT hit at Step 350 : 0.519 < 1.0
LOSS train 0.5256 valid 2.0231
Accuracy cap NOT hit at Step 400 : 0.534 < 1.0
LOSS train 0.4288 valid 1.2687
Accuracy cap NOT hit at Step 450 : 0.595 < 1.0
LOSS train 0.4250 valid 1.3303
Accuracy cap NOT hit at Step 500 : 0.592 < 1.0
EPOCH 3:
LOSS train 0.3092 valid 0.8464
Accuracy cap NOT hit at Step 550 : 0.753 < 1.0
LOSS train 0.1918 valid 0.4410
Accuracy cap NOT hit at Step 600 : 0.809 < 1.0
LOSS train 0.1318 valid 0.7812
Accuracy cap 

100%|██████████| 8000/8000 [01:02<00:00, 127.22it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]

Loading dataset content for disc_1_circle_all_test.csv


100%|██████████| 1000/1000 [00:12<00:00, 77.97it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]

Loading dataset content for disc_1_circle_all_valid.csv


100%|██████████| 1000/1000 [00:07<00:00, 128.27it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 1/1 [06:59<00:00, 419.92s/it]

{'train': {'accuracy': 0.99975, 'precision': 0.9995002498750625, 'recall': 1.0, 'roc_auc': 1.0, 'confusion matrix': {'TN': 3998, 'FP': 0, 'FN': 2, 'TP': 4000}}, 'test': {'accuracy': 0.999, 'precision': 0.998003992015968, 'recall': 1.0, 'roc_auc': 1.0, 'confusion matrix': {'TN': 499, 'FP': 0, 'FN': 1, 'TP': 500}}, 'valid': {'accuracy': 1.0, 'precision': 1.0, 'recall': 1.0, 'roc_auc': 1.0, 'confusion matrix': {'TN': 500, 'FP': 0, 'FN': 0, 'TP': 500}}}



