In [1]:
import os

# Root directories are derived from the DATA environment variable (a missing
# variable fails fast with KeyError). os.path.join adds the separator only when
# DATA lacks a trailing slash, so both "/data" and "/data/" give the same path;
# plain string concatenation would silently produce "/dataPatImgXAI_data/...".
db_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data/db2.0.0/")
model_dir_root = os.path.join(os.environ["DATA"], "models/db2.0.0/")
# Torch devices used by the evaluation loop, one worker process per device.
devices = ["cuda:0", "cuda:1"]
# Forwarded to _train_model as target_accuracy / interval_batch.
TARGET_ACCURACY = 0.8
INTERVAL_BATCH = 1

In [2]:
# How many images the generator produces in total.
NBGEN = 1_000_000

# Every image is partitioned into an X_DIVISIONS x Y_DIVISIONS grid.
X_DIVISIONS, Y_DIVISIONS = 6, 6

# Image dimensions, in pixels.
img_size = (700, 700)

# Chance that a given grid position receives a geometric shape.
SHAPE_PROB = 0.5

# Available shapes, and the available colours as hex RGB strings.
SHAPES = ["circle", "square", "triangle"]
COLORS = ["#A33E9A", "#E0B000", "#0C90C0"]

In [3]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape

# One dataset rule per shape, all generated by the same generator with the
# shape swapped in. The "name" is also the prefix of the <name>_train.csv /
# <name>_valid.csv / <name>_test.csv files read by the evaluation loop.
# Presumably a positive image has a grid row made only of `shape` (inferred
# from the generator's name) — TODO confirm against xaipatimg.datagen.gendataset.
rules_data = [
    {
        "name": f"easy_1_row_{shape}s",
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {"shape": shape, "y_division": Y_DIVISIONS},
    }
    for shape in ("circle", "triangle")
]

In [6]:
from xaipatimg.ml.learning import train_resnet18_model, compute_resnet18_model_scores


def _train_model(db_dir, train_dataset_filename, valid_dataset_filename, test_dataset_filename, model_dir,
                 target_accuracy, interval_batch, device, train=False):
    """Optionally train a ResNet-18 for one rule, then compute its scores.

    Parameters
    ----------
    db_dir : str
        Dataset root directory passed to the xaipatimg helpers (presumably
        where the split CSV files live).
    train_dataset_filename, valid_dataset_filename, test_dataset_filename : str
        File names of the train / valid / test split CSVs.
    model_dir : str
        Model directory passed to both helpers.
    target_accuracy, interval_batch
        Forwarded to ``train_resnet18_model``; only used when ``train`` is True.
    device : str
        Torch device string, e.g. "cuda:0".
    train : bool, default False
        When True, call ``train_resnet18_model`` before scoring. Defaults to
        False so existing positional calls keep their evaluate-only behaviour.

    Returns
    -------
    Whatever ``compute_resnet18_model_scores`` returns. Its return value is not
    visible here; the recorded output shows the scores being printed.
    TODO confirm whether they are also returned.

    NOTE(review): in the recorded output the confusion-matrix "FP" and "FN"
    counts look swapped relative to the reported precision/recall. For the
    test split, TP=359 with recall=0.718 implies 500 actual positives, hence
    FN=141, yet the dict reports FN=65 and FP=141. Check how
    ``compute_resnet18_model_scores`` builds the confusion matrix.
    """
    if train:
        train_resnet18_model(db_dir, train_dataset_filename, valid_dataset_filename, model_dir,
                             target_accuracy=target_accuracy, interval_batch=interval_batch, device=device)
    # Argument order is train, test, valid — matching the train/test/valid
    # order of the recorded loading logs and score dict.
    return compute_resnet18_model_scores(db_dir, train_dataset_filename, test_dataset_filename,
                                         valid_dataset_filename, model_dir, device=device)




In [8]:
from tqdm import tqdm
from joblib import Parallel, delayed

# Process the rules in batches of len(devices): each rule in a batch runs in its
# own worker process on its own device. Slicing lets the final batch be shorter
# than len(devices), so any number of rules and any number of devices work.
assert devices, "at least one device is required"
n_devices = len(devices)
for batch_start in tqdm(range(0, len(rules_data), n_devices), desc="rule batches"):
    batch = rules_data[batch_start:batch_start + n_devices]
    Parallel(n_jobs=len(batch))(
        delayed(_train_model)(
            db_dir,  # db_dir
            rule["name"] + "_train.csv",  # train_dataset_filename
            rule["name"] + "_valid.csv",  # valid_dataset_filename
            rule["name"] + "_test.csv",  # test_dataset_filename
            os.path.join(model_dir_root, rule["name"]),  # model_dir
            TARGET_ACCURACY,  # target_accuracy
            INTERVAL_BATCH,  # interval_batch
            device,  # device
        )
        # zip stops at the shorter sequence, i.e. at len(batch) <= n_devices.
        for rule, device in zip(batch, devices)
    )

  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset content for easy_1_row_triangles_train.csv
Loading dataset content for easy_1_row_circles_train.csv


100%|██████████| 8000/8000 [01:04<00:00, 123.95it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]

Loading dataset content for easy_1_row_triangles_test.csv
Loading dataset content for easy_1_row_circles_test.csv


100%|██████████| 8000/8000 [01:04<00:00, 123.79it/s]
100%|██████████| 1000/1000 [00:07<00:00, 126.21it/s]


Loading dataset content for easy_1_row_triangles_valid.csv
Loading dataset content for easy_1_row_circles_valid.csv


100%|██████████| 1000/1000 [00:07<00:00, 126.38it/s]
100%|██████████| 1000/1000 [00:08<00:00, 124.35it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 1000/1000 [00:07<00:00, 125.50it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


{'train': {'accuracy': 0.813125, 'precision': 0.8631487387648594, 'recall': 0.74425, 'roc_auc': 0.9034872500000001, 'confusion matrix': {'TN': 3528, 'FP': 1023, 'FN': 472, 'TP': 2977}}, 'test': {'accuracy': 0.794, 'precision': 0.8466981132075472, 'recall': 0.718, 'roc_auc': 0.897844, 'confusion matrix': {'TN': 435, 'FP': 141, 'FN': 65, 'TP': 359}}, 'valid': {'accuracy': 0.82, 'precision': 0.8587443946188341, 'recall': 0.766, 'roc_auc': 0.9051520000000001, 'confusion matrix': {'TN': 437, 'FP': 117, 'FN': 63, 'TP': 383}}}


100%|██████████| 1/1 [01:39<00:00, 99.27s/it]

{'train': {'accuracy': 0.821, 'precision': 0.8900364520048603, 'recall': 0.7325, 'roc_auc': 0.9274926874999999, 'confusion matrix': {'TN': 3638, 'FP': 1070, 'FN': 362, 'TP': 2930}}, 'test': {'accuracy': 0.82, 'precision': 0.8940886699507389, 'recall': 0.726, 'roc_auc': 0.9300279999999999, 'confusion matrix': {'TN': 457, 'FP': 137, 'FN': 43, 'TP': 363}}, 'valid': {'accuracy': 0.801, 'precision': 0.8679706601466992, 'recall': 0.71, 'roc_auc': 0.915, 'confusion matrix': {'TN': 446, 'FP': 145, 'FN': 54, 'TP': 355}}}



