In [2]:
import os

db_dir = os.environ["DATA"] + "PatImgXAI_data/db3.0.0/"
db_patterns_dir = os.environ["DATA"] + "PatImgXAI_data/db3.0.0/patterns/"

model_dir_root = os.environ["DATA"] + "models/db3.0.0/01_protov5/"
datasets_dir = os.environ["DATA"] + "PatImgXAI_data/db3.0.0/datasets/01_protov5/"
devices = ["cuda:0"]
INTERVAL_BATCH = 1

In [3]:
# ---- Image generation settings ----

# How many images are generated: full images per grid size, and pattern images.
NBGEN_full_per_size = 250_000
NBGEN_patterns = 100

# Grid divisions of the full images: the *_L grid (15x15) is used by the "hard" rules,
# the *_S grid (10x10) by the "easy" rules.
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15
X_DIVISIONS_S, Y_DIVISIONS_S = 10, 10

# Grid divisions of the pattern images.
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# Image sizes, in pixels.
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability that a geometrical shape is drawn in any given grid cell.
SHAPE_PROB = 0.5

# Available shapes and colours.
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [4]:
from xaipatimg.datagen.dbimg import load_db

# Load the pattern-image database stored in db_patterns_dir.
# NOTE(review): assumed to return a dict keyed by pattern id whose values hold a "content"
# list of symbols (that is how the next cell reads it) — confirm against load_db.
db_patterns = load_db(db_patterns_dir)

In [5]:
# Split the pattern ids by how many symbols the pattern contains. The iteration order of
# db_patterns is preserved, because rules below select patterns by position in these lists.
pattern_2sym_keys = [key for key, pattern in db_patterns.items() if len(pattern["content"]) == 2]
pattern_3sym_keys = [key for key, pattern in db_patterns.items() if len(pattern["content"]) == 3]


In [6]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more

def _make_pattern_rule(name, x_division_full, y_division_full, pattern_id, consider_rotations=False):
    """Build the configuration dict of one "pattern present exactly once" rule.

    :param name: rule name, also used as the prefix of its dataset CSV files and model directory.
    :param x_division_full: horizontal grid divisions of the full images.
    :param y_division_full: vertical grid divisions of the full images.
    :param pattern_id: key of the pattern (in db_patterns) that the rule looks for.
    :param consider_rotations: whether left/right rotations of the pattern also count.
    :return: the rule dict consumed by the dataset generation / training cells.
    """
    if consider_rotations:
        question = "Is the pattern or its left/right rotations in the image?"
    else:
        question = "Is the pattern in the image?"

    return {
        "name": name,
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {
            "x_division_full": x_division_full,
            "y_division_full": y_division_full,
            "x_division_pattern": X_DIVISIONS_PATTERNS,
            "y_division_pattern": Y_DIVISIONS_PATTERNS,
            "consider_rotations": consider_rotations,
        },
        "question": question,
        "target_acc": 1.0,
        "shown_acc": 1.0,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [x_division_full, y_division_full],
        "pattern_id": pattern_id,
    }


# Rules to process; uncomment an entry to re-enable it.
rules_data = [
    # _make_pattern_rule("easy1_2sym", X_DIVISIONS_S, Y_DIVISIONS_S, pattern_2sym_keys[0]),
    # _make_pattern_rule("easy2_3sym", X_DIVISIONS_S, Y_DIVISIONS_S, pattern_3sym_keys[0]),
    _make_pattern_rule("easy3_2sym_rot", X_DIVISIONS_S, Y_DIVISIONS_S, pattern_2sym_keys[4],
                       consider_rotations=True),
    # _make_pattern_rule("hard1_2sym", X_DIVISIONS_L, Y_DIVISIONS_L, pattern_2sym_keys[2]),
    # _make_pattern_rule("hard2_3sym", X_DIVISIONS_L, Y_DIVISIONS_L, pattern_3sym_keys[1]),
    # _make_pattern_rule("hard3_2sym_rot", X_DIVISIONS_L, Y_DIVISIONS_L, pattern_2sym_keys[3],
    #                    consider_rotations=True),
]

In [7]:
from xaipatimg.ml.learning import train_resnet18_model, compute_resnet18_model_scores


def _train_model(db_dir, datasets_dir_path, train_dataset_filename, valid_dataset_filename, test_dataset_filename, model_dir, target_accuracy, interval_batch, device):
    train_resnet18_model(db_dir, datasets_dir_path, train_dataset_filename, valid_dataset_filename, model_dir, target_accuracy=0.01,
                         interval_batch=interval_batch, device=device)
    compute_resnet18_model_scores(db_dir, datasets_dir_path, train_dataset_filename, test_dataset_filename, valid_dataset_filename, model_dir, device=device)




In [8]:
from tqdm import tqdm
from joblib import Parallel, delayed

# Train one model per rule. Rules are processed in batches of len(devices): within a batch,
# the k-th rule is trained on devices[k] in its own joblib worker. With a single device this
# trains one rule at a time on that device.
for rule_idx in tqdm(range(0, len(rules_data), len(devices))):
    rules_batch = rules_data[rule_idx:rule_idx + len(devices)]

    Parallel(n_jobs=len(rules_batch))(
        delayed(_train_model)(
            db_dir,  # db_dir
            datasets_dir,  # datasets_dir_path
            rule["name"] + "_train.csv",  # train_dataset_filename
            rule["name"] + "_valid.csv",  # valid_dataset_filename
            rule["name"] + "_test.csv",  # test_dataset_filename
            os.path.join(model_dir_root, rule["name"]),  # model_dir
            rule["target_acc"],  # target_accuracy
            # interval_batch: higher when the target accuracy is the best possible performance
            INTERVAL_BATCH if rule["target_acc"] < 1.0 else 50,
            device,  # device
        )
        for rule, device in zip(rules_batch, devices)
    )

  0%|          | 0/6 [00:00<?, ?it/s]

Loading dataset content for easy1_2sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 17/1800 [00:00<00:11, 161.81it/s][A
  2%|▏         | 34/1800 [00:00<00:11, 154.43it/s][A
  3%|▎         | 50/1800 [00:00<00:11, 154.58it/s][A
  4%|▎         | 66/1800 [00:00<00:17, 101.59it/s][A
  5%|▍         | 82/1800 [00:00<00:14, 115.69it/s][A
  5%|▌         | 96/1800 [00:00<00:19, 88.38it/s] [A
  6%|▌         | 112/1800 [00:01<00:16, 103.17it/s][A
  7%|▋         | 128/1800 [00:01<00:14, 115.64it/s][A
  8%|▊         | 143/1800 [00:01<00:13, 123.82it/s][A
  9%|▉         | 159/1800 [00:01<00:12, 131.40it/s][A
 10%|▉         | 175/1800 [00:01<00:11, 137.22it/s][A
 11%|█         | 191/1800 [00:01<00:11, 141.90it/s][A
 12%|█▏        | 207/1800 [00:01<00:11, 144.55it/s][A
 12%|█▏        | 222/1800 [00:01<00:10, 145.93it/s][A
 13%|█▎        | 238/1800 [00:01<00:10, 147.65it/s][A
 14%|█▍        | 253/1800 [00:02<00:14, 105.01it/s][A
 15%|█▍        | 267/1800 [00:02<00:13, 112.33it/s][A
 16%|█▌        | 280/1800 

Train dataset statistics : [0.9334152340888977, 0.9315114617347717, 0.9279568791389465] [0.16690947115421295, 0.151852548122406, 0.1761833131313324]
Loading dataset content for easy1_2sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 19/1800 [00:00<00:09, 187.78it/s][A
  2%|▏         | 38/1800 [00:00<00:13, 133.65it/s][A
  3%|▎         | 53/1800 [00:00<00:14, 123.31it/s][A
  4%|▎         | 66/1800 [00:00<00:14, 117.46it/s][A
  4%|▍         | 78/1800 [00:00<00:23, 73.42it/s] [A
  5%|▍         | 88/1800 [00:00<00:23, 71.40it/s][A
  6%|▌         | 104/1800 [00:01<00:18, 89.94it/s][A
  6%|▋         | 115/1800 [00:01<00:19, 88.63it/s][A
  7%|▋         | 131/1800 [00:01<00:15, 105.23it/s][A
  8%|▊         | 145/1800 [00:01<00:14, 113.69it/s][A
  9%|▉         | 160/1800 [00:01<00:13, 121.72it/s][A
 10%|▉         | 174/1800 [00:01<00:13, 123.08it/s][A
 11%|█         | 190/1800 [00:01<00:12, 129.90it/s][A
 11%|█▏        | 204/1800 [00:01<00:15, 102.77it/s][A
 12%|█▏        | 219/1800 [00:02<00:13, 113.45it/s][A
 13%|█▎        | 236/1800 [00:02<00:12, 125.99it/s][A
 14%|█▍        | 250/1800 [00:02<00:16, 95.13it/s] [A
 15%|█▍        | 266/1800 [00

Loading dataset content for easy1_2sym_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  4%|▍         | 4/100 [00:00<00:02, 39.88it/s][A
  8%|▊         | 8/100 [00:00<00:02, 36.26it/s][A
 12%|█▏        | 12/100 [00:00<00:02, 29.38it/s][A
 17%|█▋        | 17/100 [00:00<00:02, 35.14it/s][A
 21%|██        | 21/100 [00:00<00:02, 36.53it/s][A
 36%|███▌      | 36/100 [00:00<00:00, 70.98it/s][A
 52%|█████▏    | 52/100 [00:00<00:00, 97.01it/s][A
 68%|██████▊   | 68/100 [00:00<00:00, 114.72it/s][A
 84%|████████▍ | 84/100 [00:01<00:00, 126.49it/s][A
100%|██████████| 100/100 [00:01<00:00, 89.83it/s] [A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 3.3021 valid 0.8452
Accuracy cap hit at Step 50 : 0.58 >= 0.01
Training complete
Loading dataset content for easy1_2sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 17/1800 [00:00<00:10, 163.99it/s][A
  2%|▏         | 34/1800 [00:00<00:10, 166.06it/s][A
  3%|▎         | 51/1800 [00:00<00:10, 166.70it/s][A
  4%|▍         | 68/1800 [00:00<00:10, 165.59it/s][A
  5%|▍         | 85/1800 [00:00<00:13, 123.26it/s][A
  6%|▌         | 99/1800 [00:00<00:20, 83.67it/s] [A
  6%|▋         | 116/1800 [00:01<00:16, 100.86it/s][A
  7%|▋         | 133/1800 [00:01<00:14, 115.81it/s][A
  8%|▊         | 150/1800 [00:01<00:12, 128.15it/s][A
  9%|▉         | 167/1800 [00:01<00:11, 137.87it/s][A
 10%|█         | 183/1800 [00:01<00:18, 85.29it/s] [A
 11%|█         | 200/1800 [00:01<00:15, 100.57it/s][A
 12%|█▏        | 217/1800 [00:01<00:13, 114.32it/s][A
 13%|█▎        | 234/1800 [00:01<00:12, 125.41it/s][A
 14%|█▍        | 251/1800 [00:02<00:11, 135.71it/s][A
 15%|█▍        | 268/1800 [00:02<00:10, 144.17it/s][A
 16%|█▌        | 285/1800 [00:02<00:10, 150.66it/s][A
 17%|█▋        | 302/1800 

Loading dataset content for easy1_2sym_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 18%|█▊        | 18/100 [00:00<00:00, 177.28it/s][A
 36%|███▌      | 36/100 [00:00<00:00, 175.14it/s][A
 55%|█████▌    | 55/100 [00:00<00:00, 179.31it/s][A
 73%|███████▎  | 73/100 [00:00<00:00, 179.49it/s][A
100%|██████████| 100/100 [00:00<00:00, 179.14it/s][A


Loading dataset content for easy1_2sym_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 10%|█         | 10/100 [00:00<00:00, 95.57it/s][A
 20%|██        | 20/100 [00:00<00:00, 86.92it/s][A
 29%|██▉       | 29/100 [00:00<00:00, 87.25it/s][A
 47%|████▋     | 47/100 [00:00<00:00, 119.50it/s][A
 66%|██████▌   | 66/100 [00:00<00:00, 141.28it/s][A
 81%|████████  | 81/100 [00:00<00:00, 129.99it/s][A
100%|██████████| 100/100 [00:00<00:00, 127.99it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 17%|█▋        | 1/6 [01:00<05:00, 60.06s/it]

{'train': {'accuracy': 0.5366666666666666, 'precision': 0.5217965653896962, 'recall': 0.8777777777777778, 'roc_auc': 0.5736395061728394, 'confusion matrix': {'TN': 176, 'FP': 110, 'FN': 724, 'TP': 790}}, 'test': {'accuracy': 0.57, 'precision': 0.5432098765432098, 'recall': 0.88, 'roc_auc': 0.6168, 'confusion matrix': {'TN': 13, 'FP': 6, 'FN': 37, 'TP': 44}}, 'valid': {'accuracy': 0.58, 'precision': 0.5476190476190477, 'recall': 0.92, 'roc_auc': 0.5980000000000001, 'confusion matrix': {'TN': 12, 'FP': 4, 'FN': 38, 'TP': 46}}}
Loading dataset content for easy2_3sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 17/1800 [00:00<00:10, 167.38it/s][A
  2%|▏         | 34/1800 [00:00<00:13, 128.01it/s][A
  3%|▎         | 48/1800 [00:00<00:17, 101.53it/s][A
  3%|▎         | 59/1800 [00:00<00:20, 84.98it/s] [A
  4%|▍         | 76/1800 [00:00<00:16, 106.10it/s][A
  5%|▌         | 93/1800 [00:00<00:14, 121.90it/s][A
  6%|▋         | 114/1800 [00:00<00:11, 145.44it/s][A
  8%|▊         | 135/1800 [00:01<00:10, 162.20it/s][A
  9%|▊         | 157/1800 [00:01<00:09, 177.62it/s][A
 10%|▉         | 178/1800 [00:01<00:08, 186.01it/s][A
 11%|█         | 198/1800 [00:01<00:12, 126.92it/s][A
 12%|█▏        | 214/1800 [00:01<00:16, 96.46it/s] [A
 13%|█▎        | 233/1800 [00:01<00:13, 113.06it/s][A
 14%|█▍        | 252/1800 [00:01<00:12, 128.68it/s][A
 15%|█▌        | 272/1800 [00:02<00:10, 143.46it/s][A
 16%|█▌        | 289/1800 [00:02<00:13, 110.77it/s][A
 17%|█▋        | 307/1800 [00:02<00:12, 123.48it/s][A
 18%|█▊        | 324/1800 

Train dataset statistics : [0.9325084686279297, 0.931618869304657, 0.9276787042617798] [0.16944441199302673, 0.15097908675670624, 0.17676040530204773]
Loading dataset content for easy2_3sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 17/1800 [00:00<00:10, 163.00it/s][A
  2%|▏         | 34/1800 [00:00<00:13, 126.35it/s][A
  3%|▎         | 51/1800 [00:00<00:12, 140.44it/s][A
  4%|▎         | 67/1800 [00:00<00:11, 146.85it/s][A
  5%|▍         | 83/1800 [00:00<00:11, 149.95it/s][A
  6%|▌         | 100/1800 [00:00<00:11, 154.21it/s][A
  6%|▋         | 116/1800 [00:00<00:15, 105.32it/s][A
  7%|▋         | 129/1800 [00:01<00:18, 88.09it/s] [A
  8%|▊         | 148/1800 [00:01<00:15, 108.71it/s][A
  9%|▉         | 167/1800 [00:01<00:12, 126.04it/s][A
 10%|█         | 182/1800 [00:01<00:12, 131.76it/s][A
 11%|█         | 197/1800 [00:01<00:12, 125.32it/s][A
 12%|█▏        | 212/1800 [00:01<00:12, 128.92it/s][A
 13%|█▎        | 226/1800 [00:01<00:15, 101.15it/s][A
 13%|█▎        | 238/1800 [00:02<00:17, 91.50it/s] [A
 14%|█▍        | 254/1800 [00:02<00:14, 105.83it/s][A
 15%|█▍        | 269/1800 [00:02<00:13, 115.97it/s][A
 16%|█▌        | 282/1800

Loading dataset content for easy2_3sym_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 15%|█▌        | 15/100 [00:00<00:00, 142.97it/s][A
 30%|███       | 30/100 [00:00<00:01, 47.35it/s] [A
 45%|████▌     | 45/100 [00:00<00:00, 63.40it/s][A
 55%|█████▌    | 55/100 [00:00<00:00, 66.66it/s][A
 71%|███████   | 71/100 [00:00<00:00, 87.43it/s][A
100%|██████████| 100/100 [00:01<00:00, 87.89it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 2.7223 valid 0.7142
Accuracy cap hit at Step 50 : 0.5 >= 0.01
Training complete
Loading dataset content for easy2_3sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 20/1800 [00:00<00:09, 191.79it/s][A
  2%|▏         | 40/1800 [00:00<00:11, 151.21it/s][A
  3%|▎         | 57/1800 [00:00<00:11, 157.55it/s][A
  4%|▍         | 74/1800 [00:00<00:12, 133.48it/s][A
  5%|▌         | 90/1800 [00:00<00:12, 137.60it/s][A
  6%|▌         | 105/1800 [00:00<00:13, 122.29it/s][A
  7%|▋         | 124/1800 [00:00<00:12, 139.49it/s][A
  8%|▊         | 144/1800 [00:00<00:10, 155.68it/s][A
  9%|▉         | 161/1800 [00:01<00:11, 147.67it/s][A
 10%|█         | 182/1800 [00:01<00:09, 162.42it/s][A
 11%|█▏        | 203/1800 [00:01<00:09, 173.60it/s][A
 12%|█▏        | 223/1800 [00:01<00:08, 180.44it/s][A
 13%|█▎        | 242/1800 [00:01<00:08, 176.22it/s][A
 14%|█▍        | 260/1800 [00:01<00:08, 173.16it/s][A
 15%|█▌        | 278/1800 [00:01<00:08, 171.83it/s][A
 16%|█▋        | 296/1800 [00:01<00:08, 170.65it/s][A
 17%|█▋        | 314/1800 [00:01<00:08, 168.57it/s][A
 18%|█▊        | 332/1800

Loading dataset content for easy2_3sym_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 14%|█▍        | 14/100 [00:00<00:00, 138.89it/s][A
 29%|██▉       | 29/100 [00:00<00:00, 144.37it/s][A
 44%|████▍     | 44/100 [00:00<00:00, 146.73it/s][A
 59%|█████▉    | 59/100 [00:00<00:00, 107.48it/s][A
 75%|███████▌  | 75/100 [00:00<00:00, 121.17it/s][A
100%|██████████| 100/100 [00:00<00:00, 112.68it/s][A


Loading dataset content for easy2_3sym_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  7%|▋         | 7/100 [00:00<00:01, 60.21it/s][A
 18%|█▊        | 18/100 [00:00<00:00, 87.25it/s][A
 34%|███▍      | 34/100 [00:00<00:00, 117.89it/s][A
 47%|████▋     | 47/100 [00:00<00:00, 121.14it/s][A
 66%|██████▌   | 66/100 [00:00<00:00, 143.93it/s][A
100%|██████████| 100/100 [00:00<00:00, 140.24it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 33%|███▎      | 2/6 [01:55<03:48, 57.22s/it]

{'train': {'accuracy': 0.5038888888888889, 'precision': 0.501952035694367, 'recall': 1.0, 'roc_auc': 0.48005, 'confusion matrix': {'TN': 7, 'FP': 0, 'FN': 893, 'TP': 900}}, 'test': {'accuracy': 0.5, 'precision': 0.5, 'recall': 1.0, 'roc_auc': 0.5528000000000001, 'confusion matrix': {'TN': 0, 'FP': 0, 'FN': 50, 'TP': 50}}, 'valid': {'accuracy': 0.5, 'precision': 0.5, 'recall': 1.0, 'roc_auc': 0.5544, 'confusion matrix': {'TN': 0, 'FP': 0, 'FN': 50, 'TP': 50}}}
Loading dataset content for easy3_2sym_rot_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 18/1800 [00:00<00:10, 172.70it/s][A
  2%|▏         | 36/1800 [00:00<00:10, 169.18it/s][A
  3%|▎         | 53/1800 [00:00<00:10, 166.13it/s][A
  4%|▍         | 70/1800 [00:00<00:10, 164.41it/s][A
  5%|▍         | 87/1800 [00:00<00:12, 142.32it/s][A
  6%|▌         | 105/1800 [00:00<00:11, 150.89it/s][A
  7%|▋         | 121/1800 [00:00<00:11, 146.01it/s][A
  8%|▊         | 137/1800 [00:00<00:11, 148.77it/s][A
  9%|▊         | 154/1800 [00:01<00:10, 153.10it/s][A
  9%|▉         | 170/1800 [00:01<00:11, 137.71it/s][A
 10%|█         | 189/1800 [00:01<00:10, 150.70it/s][A
 11%|█▏        | 206/1800 [00:01<00:10, 154.82it/s][A
 12%|█▏        | 223/1800 [00:01<00:09, 158.69it/s][A
 13%|█▎        | 240/1800 [00:01<00:09, 161.06it/s][A
 14%|█▍        | 257/1800 [00:01<00:09, 162.85it/s][A
 15%|█▌        | 274/1800 [00:01<00:09, 162.99it/s][A
 16%|█▌        | 292/1800 [00:01<00:09, 167.19it/s][A
 17%|█▋        | 309/1800

Train dataset statistics : [0.9333571195602417, 0.9318423271179199, 0.9282010197639465] [0.1673785001039505, 0.15118613839149475, 0.17585933208465576]
Loading dataset content for easy3_2sym_rot_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 17/1800 [00:00<00:10, 164.12it/s][A
  2%|▏         | 34/1800 [00:00<00:10, 165.88it/s][A
  3%|▎         | 51/1800 [00:00<00:10, 165.12it/s][A
  4%|▍         | 68/1800 [00:00<00:10, 165.32it/s][A
  5%|▍         | 85/1800 [00:00<00:10, 161.15it/s][A
  6%|▌         | 102/1800 [00:00<00:12, 134.63it/s][A
  7%|▋         | 118/1800 [00:00<00:11, 141.13it/s][A
  7%|▋         | 133/1800 [00:00<00:14, 118.20it/s][A
  8%|▊         | 149/1800 [00:01<00:12, 128.24it/s][A
  9%|▉         | 165/1800 [00:01<00:11, 136.27it/s][A
 10%|█         | 181/1800 [00:01<00:11, 142.35it/s][A
 11%|█         | 196/1800 [00:01<00:12, 126.16it/s][A
 12%|█▏        | 214/1800 [00:01<00:11, 138.44it/s][A
 13%|█▎        | 233/1800 [00:01<00:10, 152.14it/s][A
 14%|█▍        | 253/1800 [00:01<00:09, 163.12it/s][A
 15%|█▌        | 272/1800 [00:01<00:08, 170.22it/s][A
 16%|█▌        | 292/1800 [00:01<00:08, 176.16it/s][A
 17%|█▋        | 311/1800

Loading dataset content for easy3_2sym_rot_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  6%|▌         | 6/100 [00:00<00:01, 51.05it/s][A
 22%|██▏       | 22/100 [00:00<00:00, 109.41it/s][A
 38%|███▊      | 38/100 [00:00<00:00, 129.03it/s][A
 54%|█████▍    | 54/100 [00:00<00:00, 138.43it/s][A
 68%|██████▊   | 68/100 [00:00<00:00, 109.26it/s][A
 83%|████████▎ | 83/100 [00:00<00:00, 119.83it/s][A
100%|██████████| 100/100 [00:00<00:00, 121.83it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 3.9735 valid 0.7886
Accuracy cap hit at Step 50 : 0.51 >= 0.01
Training complete
Loading dataset content for easy3_2sym_rot_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 20/1800 [00:00<00:09, 194.12it/s][A
  2%|▏         | 40/1800 [00:00<00:09, 177.28it/s][A
  3%|▎         | 58/1800 [00:00<00:26, 66.12it/s] [A
  4%|▍         | 70/1800 [00:01<00:31, 54.88it/s][A
  5%|▍         | 87/1800 [00:01<00:23, 72.77it/s][A
  6%|▌         | 104/1800 [00:01<00:18, 89.93it/s][A
  7%|▋         | 121/1800 [00:01<00:15, 106.10it/s][A
  8%|▊         | 138/1800 [00:01<00:13, 119.96it/s][A
  9%|▊         | 155/1800 [00:01<00:12, 131.41it/s][A
 10%|▉         | 172/1800 [00:01<00:11, 140.47it/s][A
 10%|█         | 189/1800 [00:01<00:11, 146.45it/s][A
 11%|█▏        | 205/1800 [00:02<00:16, 96.90it/s] [A
 12%|█▏        | 222/1800 [00:02<00:14, 111.30it/s][A
 13%|█▎        | 238/1800 [00:02<00:12, 121.84it/s][A
 14%|█▍        | 255/1800 [00:02<00:11, 131.97it/s][A
 15%|█▌        | 272/1800 [00:02<00:10, 139.52it/s][A
 16%|█▌        | 288/1800 [00:02<00:10, 144.79it/s][A
 17%|█▋        | 304/1800 [0

Loading dataset content for easy3_2sym_rot_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 14%|█▍        | 14/100 [00:00<00:00, 137.25it/s][A
 30%|███       | 30/100 [00:00<00:00, 145.46it/s][A
 46%|████▌     | 46/100 [00:00<00:00, 147.98it/s][A
 62%|██████▏   | 62/100 [00:00<00:00, 149.81it/s][A
 78%|███████▊  | 78/100 [00:00<00:00, 151.15it/s][A
100%|██████████| 100/100 [00:00<00:00, 149.60it/s][A


Loading dataset content for easy3_2sym_rot_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  9%|▉         | 9/100 [00:00<00:01, 84.35it/s][A
 25%|██▌       | 25/100 [00:00<00:00, 126.70it/s][A
 38%|███▊      | 38/100 [00:00<00:00, 104.84it/s][A
 49%|████▉     | 49/100 [00:00<00:00, 97.26it/s] [A
 60%|██████    | 60/100 [00:00<00:00, 100.31it/s][A
 80%|████████  | 80/100 [00:00<00:00, 129.11it/s][A
100%|██████████| 100/100 [00:00<00:00, 121.65it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 50%|█████     | 3/6 [02:49<02:48, 56.04s/it]

{'train': {'accuracy': 0.5005555555555555, 'precision': 0.5185185185185185, 'recall': 0.015555555555555555, 'roc_auc': 0.5093839506172839, 'confusion matrix': {'TN': 887, 'FP': 886, 'FN': 13, 'TP': 14}}, 'test': {'accuracy': 0.51, 'precision': 1.0, 'recall': 0.02, 'roc_auc': 0.48760000000000003, 'confusion matrix': {'TN': 50, 'FP': 49, 'FN': 0, 'TP': 1}}, 'valid': {'accuracy': 0.51, 'precision': 1.0, 'recall': 0.02, 'roc_auc': 0.49839999999999995, 'confusion matrix': {'TN': 50, 'FP': 49, 'FN': 0, 'TP': 1}}}
Loading dataset content for hard1_2sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 18/1800 [00:00<00:10, 170.02it/s][A
  2%|▏         | 36/1800 [00:00<00:12, 139.94it/s][A
  3%|▎         | 51/1800 [00:00<00:12, 136.90it/s][A
  4%|▍         | 68/1800 [00:00<00:11, 147.13it/s][A
  5%|▍         | 85/1800 [00:00<00:11, 153.73it/s][A
  6%|▌         | 101/1800 [00:00<00:10, 155.66it/s][A
  7%|▋         | 120/1800 [00:00<00:10, 166.30it/s][A
  8%|▊         | 140/1800 [00:00<00:09, 176.57it/s][A
  9%|▉         | 160/1800 [00:00<00:08, 183.05it/s][A
 10%|█         | 180/1800 [00:01<00:08, 187.92it/s][A
 11%|█         | 199/1800 [00:01<00:10, 153.59it/s][A
 12%|█▏        | 217/1800 [00:01<00:09, 160.42it/s][A
 13%|█▎        | 234/1800 [00:01<00:09, 158.66it/s][A
 14%|█▍        | 251/1800 [00:01<00:09, 159.02it/s][A
 15%|█▍        | 268/1800 [00:01<00:09, 156.10it/s][A
 16%|█▌        | 284/1800 [00:01<00:11, 126.44it/s][A
 17%|█▋        | 301/1800 [00:01<00:10, 136.27it/s][A
 18%|█▊        | 317/1800

Train dataset statistics : [0.9180304408073425, 0.9161907434463501, 0.9121778607368469] [0.17285595834255219, 0.15731018781661987, 0.1815142184495926]
Loading dataset content for hard1_2sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 10/1800 [00:00<00:19, 93.35it/s][A
  1%|          | 20/1800 [00:00<00:35, 50.61it/s][A
  2%|▏         | 38/1800 [00:00<00:19, 89.18it/s][A
  3%|▎         | 58/1800 [00:00<00:14, 121.37it/s][A
  4%|▍         | 78/1800 [00:00<00:11, 143.68it/s][A
  5%|▌         | 98/1800 [00:00<00:10, 159.30it/s][A
  6%|▋         | 117/1800 [00:00<00:10, 167.90it/s][A
  8%|▊         | 135/1800 [00:00<00:09, 171.32it/s][A
  8%|▊         | 153/1800 [00:01<00:10, 160.32it/s][A
  9%|▉         | 170/1800 [00:01<00:10, 148.69it/s][A
 10%|█         | 186/1800 [00:01<00:13, 116.10it/s][A
 11%|█         | 200/1800 [00:01<00:17, 92.78it/s] [A
 12%|█▏        | 214/1800 [00:01<00:15, 102.09it/s][A
 13%|█▎        | 230/1800 [00:01<00:13, 113.69it/s][A
 14%|█▎        | 243/1800 [00:02<00:13, 114.97it/s][A
 14%|█▍        | 256/1800 [00:02<00:13, 113.36it/s][A
 15%|█▍        | 269/1800 [00:02<00:13, 111.72it/s][A
 16%|█▌        | 281/1800 [00

Loading dataset content for hard1_2sym_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 13%|█▎        | 13/100 [00:00<00:00, 126.66it/s][A
 26%|██▌       | 26/100 [00:00<00:00, 99.99it/s] [A
 37%|███▋      | 37/100 [00:00<00:00, 75.14it/s][A
 52%|█████▏    | 52/100 [00:00<00:00, 95.32it/s][A
 63%|██████▎   | 63/100 [00:00<00:00, 89.84it/s][A
 76%|███████▌  | 76/100 [00:00<00:00, 98.99it/s][A
100%|██████████| 100/100 [00:00<00:00, 104.15it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 2.3045 valid 0.7292
Accuracy cap hit at Step 50 : 0.5 >= 0.01
Training complete
Loading dataset content for hard1_2sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 12/1800 [00:00<00:15, 118.26it/s][A
  1%|▏         | 25/1800 [00:00<00:14, 123.78it/s][A
  2%|▏         | 42/1800 [00:00<00:12, 142.21it/s][A
  3%|▎         | 58/1800 [00:00<00:11, 146.26it/s][A
  4%|▍         | 75/1800 [00:00<00:11, 152.25it/s][A
  5%|▌         | 91/1800 [00:00<00:12, 137.95it/s][A
  6%|▌         | 106/1800 [00:00<00:17, 97.80it/s][A
  7%|▋         | 123/1800 [00:01<00:14, 113.35it/s][A
  8%|▊         | 140/1800 [00:01<00:13, 125.91it/s][A
  9%|▊         | 157/1800 [00:01<00:12, 135.57it/s][A
 10%|▉         | 174/1800 [00:01<00:11, 143.69it/s][A
 11%|█         | 190/1800 [00:01<00:11, 139.66it/s][A
 11%|█▏        | 206/1800 [00:01<00:11, 143.59it/s][A
 12%|█▏        | 222/1800 [00:01<00:10, 148.02it/s][A
 13%|█▎        | 239/1800 [00:01<00:10, 151.75it/s][A
 14%|█▍        | 255/1800 [00:01<00:10, 152.71it/s][A
 15%|█▌        | 271/1800 [00:01<00:09, 153.43it/s][A
 16%|█▌        | 287/1800 [

Loading dataset content for hard1_2sym_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 12%|█▏        | 12/100 [00:00<00:00, 115.10it/s][A
 24%|██▍       | 24/100 [00:00<00:00, 116.01it/s][A
 36%|███▌      | 36/100 [00:00<00:00, 87.28it/s] [A
 46%|████▌     | 46/100 [00:00<00:00, 73.27it/s][A
 60%|██████    | 60/100 [00:00<00:00, 89.76it/s][A
 73%|███████▎  | 73/100 [00:00<00:00, 100.16it/s][A
100%|██████████| 100/100 [00:00<00:00, 103.56it/s][A


Loading dataset content for hard1_2sym_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 15%|█▌        | 15/100 [00:00<00:00, 148.15it/s][A
 31%|███       | 31/100 [00:00<00:00, 149.78it/s][A
 47%|████▋     | 47/100 [00:00<00:00, 151.14it/s][A
 63%|██████▎   | 63/100 [00:00<00:00, 101.63it/s][A
 75%|███████▌  | 75/100 [00:00<00:00, 90.96it/s] [A
100%|██████████| 100/100 [00:00<00:00, 102.29it/s]A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 67%|██████▋   | 4/6 [03:49<01:54, 57.38s/it]

{'train': {'accuracy': 0.5, 'precision': 0.5, 'recall': 0.9966666666666667, 'roc_auc': 0.5254592592592593, 'confusion matrix': {'TN': 3, 'FP': 3, 'FN': 897, 'TP': 897}}, 'test': {'accuracy': 0.5, 'precision': 0.5, 'recall': 0.98, 'roc_auc': 0.5116, 'confusion matrix': {'TN': 1, 'FP': 1, 'FN': 49, 'TP': 49}}, 'valid': {'accuracy': 0.5, 'precision': 0.5, 'recall': 1.0, 'roc_auc': 0.5932, 'confusion matrix': {'TN': 0, 'FP': 0, 'FN': 50, 'TP': 50}}}
Loading dataset content for hard2_3sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 18/1800 [00:00<00:10, 172.47it/s][A
  2%|▏         | 36/1800 [00:00<00:10, 166.84it/s][A
  3%|▎         | 53/1800 [00:00<00:10, 166.51it/s][A
  4%|▍         | 70/1800 [00:00<00:10, 164.22it/s][A
  5%|▍         | 87/1800 [00:00<00:10, 163.97it/s][A
  6%|▌         | 104/1800 [00:00<00:15, 110.15it/s][A
  7%|▋         | 119/1800 [00:00<00:14, 118.93it/s][A
  8%|▊         | 135/1800 [00:00<00:13, 127.90it/s][A
  8%|▊         | 151/1800 [00:01<00:12, 135.34it/s][A
  9%|▉         | 167/1800 [00:01<00:11, 140.60it/s][A
 10%|█         | 182/1800 [00:01<00:14, 114.82it/s][A
 11%|█         | 198/1800 [00:01<00:12, 124.34it/s][A
 12%|█▏        | 212/1800 [00:01<00:14, 112.71it/s][A
 13%|█▎        | 228/1800 [00:01<00:12, 123.27it/s][A
 13%|█▎        | 242/1800 [00:01<00:12, 125.16it/s][A
 14%|█▍        | 257/1800 [00:01<00:11, 129.95it/s][A
 15%|█▌        | 271/1800 [00:02<00:12, 120.69it/s][A
 16%|█▌        | 286/1800

Train dataset statistics : [0.917290449142456, 0.9159271121025085, 0.9122052192687988] [0.17439891397953033, 0.15750162303447723, 0.18112781643867493]
Loading dataset content for hard2_3sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 18/1800 [00:00<00:10, 177.10it/s][A
  2%|▏         | 36/1800 [00:00<00:27, 64.93it/s] [A
  3%|▎         | 48/1800 [00:00<00:22, 77.63it/s][A
  3%|▎         | 59/1800 [00:00<00:22, 76.31it/s][A
  4%|▍         | 77/1800 [00:00<00:17, 101.11it/s][A
  5%|▌         | 96/1800 [00:00<00:13, 122.61it/s][A
  6%|▌         | 112/1800 [00:01<00:12, 130.99it/s][A
  7%|▋         | 127/1800 [00:01<00:12, 135.55it/s][A
  8%|▊         | 142/1800 [00:01<00:12, 133.12it/s][A
  9%|▊         | 157/1800 [00:01<00:14, 113.18it/s][A
 10%|▉         | 172/1800 [00:01<00:13, 121.60it/s][A
 10%|█         | 186/1800 [00:01<00:13, 117.17it/s][A
 11%|█         | 199/1800 [00:01<00:16, 99.42it/s] [A
 12%|█▏        | 210/1800 [00:01<00:16, 97.27it/s][A
 12%|█▏        | 221/1800 [00:02<00:17, 88.94it/s][A
 13%|█▎        | 236/1800 [00:02<00:15, 102.77it/s][A
 14%|█▍        | 251/1800 [00:02<00:13, 112.01it/s][A
 15%|█▍        | 266/1800 [00:

Loading dataset content for hard2_3sym_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 11%|█         | 11/100 [00:00<00:00, 108.32it/s][A
 27%|██▋       | 27/100 [00:00<00:00, 134.02it/s][A
 41%|████      | 41/100 [00:00<00:00, 134.68it/s][A
 55%|█████▌    | 55/100 [00:00<00:00, 136.58it/s][A
 70%|███████   | 70/100 [00:00<00:00, 138.96it/s][A
 84%|████████▍ | 84/100 [00:00<00:00, 101.38it/s][A
100%|██████████| 100/100 [00:00<00:00, 112.58it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 3.4064 valid 1.2895
Accuracy cap hit at Step 50 : 0.5 >= 0.01
Training complete
Loading dataset content for hard2_3sym_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 17/1800 [00:00<00:11, 161.73it/s][A
  2%|▏         | 34/1800 [00:00<00:13, 131.90it/s][A
  3%|▎         | 51/1800 [00:00<00:12, 144.96it/s][A
  4%|▎         | 66/1800 [00:00<00:11, 145.39it/s][A
  5%|▍         | 83/1800 [00:00<00:11, 153.21it/s][A
  6%|▌         | 99/1800 [00:00<00:12, 135.82it/s][A
  6%|▋         | 114/1800 [00:00<00:15, 108.76it/s][A
  7%|▋         | 126/1800 [00:01<00:16, 100.67it/s][A
  8%|▊         | 137/1800 [00:01<00:17, 92.59it/s] [A
  9%|▊         | 154/1800 [00:01<00:15, 109.24it/s][A
  9%|▉         | 166/1800 [00:01<00:17, 92.28it/s] [A
 10%|█         | 182/1800 [00:01<00:15, 107.38it/s][A
 11%|█         | 199/1800 [00:01<00:13, 121.35it/s][A
 12%|█▏        | 216/1800 [00:01<00:11, 132.42it/s][A
 13%|█▎        | 233/1800 [00:01<00:11, 141.19it/s][A
 14%|█▍        | 249/1800 [00:01<00:10, 146.10it/s][A
 15%|█▍        | 266/1800 [00:02<00:10, 150.77it/s][A
 16%|█▌        | 282/1800 

Loading dataset content for hard2_3sym_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 14%|█▍        | 14/100 [00:00<00:00, 135.56it/s][A
 30%|███       | 30/100 [00:00<00:00, 144.10it/s][A
 45%|████▌     | 45/100 [00:00<00:00, 112.43it/s][A
 61%|██████    | 61/100 [00:00<00:00, 124.15it/s][A
 76%|███████▌  | 76/100 [00:00<00:00, 132.19it/s][A
100%|██████████| 100/100 [00:00<00:00, 133.12it/s][A


Loading dataset content for hard2_3sym_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  9%|▉         | 9/100 [00:00<00:01, 69.30it/s][A
 25%|██▌       | 25/100 [00:00<00:00, 113.76it/s][A
 37%|███▋      | 37/100 [00:00<00:00, 82.17it/s] [A
 47%|████▋     | 47/100 [00:00<00:00, 84.21it/s][A
 62%|██████▏   | 62/100 [00:00<00:00, 102.72it/s][A
 78%|███████▊  | 78/100 [00:00<00:00, 117.44it/s][A
100%|██████████| 100/100 [00:00<00:00, 110.97it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 83%|████████▎ | 5/6 [04:56<01:00, 60.98s/it]

{'train': {'accuracy': 0.5, 'precision': 0.0, 'recall': 0.0, 'roc_auc': 0.49684938271604934, 'confusion matrix': {'TN': 900, 'FP': 900, 'FN': 0, 'TP': 0}}, 'test': {'accuracy': 0.5, 'precision': 0.0, 'recall': 0.0, 'roc_auc': 0.46519999999999995, 'confusion matrix': {'TN': 50, 'FP': 50, 'FN': 0, 'TP': 0}}, 'valid': {'accuracy': 0.5, 'precision': 0.0, 'recall': 0.0, 'roc_auc': 0.5556, 'confusion matrix': {'TN': 50, 'FP': 50, 'FN': 0, 'TP': 0}}}
Loading dataset content for hard3_2sym_rot_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 20/1800 [00:00<00:09, 196.98it/s][A
  2%|▏         | 41/1800 [00:00<00:08, 200.10it/s][A
  3%|▎         | 62/1800 [00:00<00:09, 181.21it/s][A
  4%|▍         | 81/1800 [00:00<00:09, 176.21it/s][A
  6%|▌         | 99/1800 [00:00<00:09, 171.91it/s][A
  6%|▋         | 117/1800 [00:00<00:11, 145.19it/s][A
  7%|▋         | 133/1800 [00:00<00:11, 142.59it/s][A
  8%|▊         | 148/1800 [00:01<00:14, 116.27it/s][A
  9%|▉         | 161/1800 [00:01<00:15, 104.72it/s][A
 10%|▉         | 178/1800 [00:01<00:13, 118.40it/s][A
 11%|█         | 191/1800 [00:01<00:14, 110.50it/s][A
 12%|█▏        | 208/1800 [00:01<00:12, 124.78it/s][A
 12%|█▎        | 225/1800 [00:01<00:11, 135.77it/s][A
 13%|█▎        | 240/1800 [00:01<00:14, 105.74it/s][A
 14%|█▍        | 253/1800 [00:02<00:18, 85.15it/s] [A
 15%|█▍        | 266/1800 [00:02<00:16, 93.64it/s][A
 16%|█▌        | 280/1800 [00:02<00:15, 97.99it/s][A
 16%|█▌        | 291/1800 [

Train dataset statistics : [0.9186922907829285, 0.9168760180473328, 0.9123098850250244] [0.17197488248348236, 0.15628032386302948, 0.18190833926200867]
Loading dataset content for hard3_2sym_rot_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 11/1800 [00:00<00:18, 97.25it/s][A
  1%|          | 21/1800 [00:00<00:30, 58.84it/s][A
  2%|▏         | 28/1800 [00:00<00:36, 48.33it/s][A
  2%|▏         | 37/1800 [00:00<00:29, 58.97it/s][A
  3%|▎         | 53/1800 [00:00<00:20, 86.10it/s][A
  4%|▍         | 69/1800 [00:00<00:16, 105.66it/s][A
  5%|▍         | 85/1800 [00:00<00:14, 119.71it/s][A
  6%|▌         | 100/1800 [00:01<00:13, 127.91it/s][A
  6%|▋         | 114/1800 [00:01<00:17, 94.65it/s] [A
  7%|▋         | 130/1800 [00:01<00:15, 108.93it/s][A
  8%|▊         | 146/1800 [00:01<00:13, 120.52it/s][A
  9%|▉         | 162/1800 [00:01<00:12, 128.86it/s][A
 10%|▉         | 176/1800 [00:01<00:15, 107.86it/s][A
 10%|█         | 189/1800 [00:01<00:17, 93.20it/s] [A
 11%|█         | 200/1800 [00:02<00:17, 92.35it/s][A
 12%|█▏        | 214/1800 [00:02<00:15, 101.29it/s][A
 12%|█▎        | 225/1800 [00:02<00:16, 94.88it/s] [A
 13%|█▎        | 236/1800 [00:02<

Loading dataset content for hard3_2sym_rot_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 10%|█         | 10/100 [00:00<00:00, 97.70it/s][A
 25%|██▌       | 25/100 [00:00<00:00, 126.09it/s][A
 40%|████      | 40/100 [00:00<00:00, 135.68it/s][A
 55%|█████▌    | 55/100 [00:00<00:00, 140.39it/s][A
 70%|███████   | 70/100 [00:00<00:00, 127.43it/s][A
 83%|████████▎ | 83/100 [00:00<00:00, 95.49it/s] [A
100%|██████████| 100/100 [00:00<00:00, 102.77it/s]A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


EPOCH 1:
LOSS train 3.1331 valid 0.7232
Accuracy cap hit at Step 50 : 0.52 >= 0.01
Training complete
Loading dataset content for hard3_2sym_rot_train.csv



  0%|          | 0/1800 [00:00<?, ?it/s][A
  1%|          | 16/1800 [00:00<00:11, 154.67it/s][A
  2%|▏         | 32/1800 [00:00<00:20, 87.42it/s] [A
  3%|▎         | 47/1800 [00:00<00:16, 106.75it/s][A
  4%|▎         | 63/1800 [00:00<00:14, 121.64it/s][A
  4%|▍         | 79/1800 [00:00<00:12, 132.80it/s][A
  5%|▌         | 94/1800 [00:00<00:16, 104.25it/s][A
  6%|▌         | 106/1800 [00:01<00:21, 79.42it/s][A
  7%|▋         | 122/1800 [00:01<00:17, 95.57it/s][A
  7%|▋         | 134/1800 [00:01<00:19, 87.60it/s][A
  8%|▊         | 145/1800 [00:01<00:25, 65.27it/s][A
  9%|▊         | 154/1800 [00:01<00:25, 63.43it/s][A
  9%|▉         | 169/1800 [00:01<00:20, 79.12it/s][A
 10%|▉         | 179/1800 [00:02<00:19, 81.67it/s][A
 10%|█         | 189/1800 [00:02<00:20, 79.66it/s][A
 11%|█         | 202/1800 [00:02<00:17, 90.15it/s][A
 12%|█▏        | 212/1800 [00:02<00:17, 90.85it/s][A
 13%|█▎        | 226/1800 [00:02<00:15, 101.92it/s][A
 13%|█▎        | 241/1800 [00:02<00:

Loading dataset content for hard3_2sym_rot_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  6%|▌         | 6/100 [00:00<00:01, 50.04it/s][A
 12%|█▏        | 12/100 [00:00<00:01, 46.98it/s][A
 21%|██        | 21/100 [00:00<00:01, 60.27it/s][A
 28%|██▊       | 28/100 [00:00<00:01, 61.63it/s][A
 35%|███▌      | 35/100 [00:00<00:01, 54.67it/s][A
 44%|████▍     | 44/100 [00:00<00:00, 59.16it/s][A
 51%|█████     | 51/100 [00:00<00:01, 48.26it/s][A
 58%|█████▊    | 58/100 [00:01<00:00, 50.30it/s][A
 65%|██████▌   | 65/100 [00:01<00:00, 54.14it/s][A
 80%|████████  | 80/100 [00:01<00:00, 77.76it/s][A
100%|██████████| 100/100 [00:01<00:00, 64.69it/s][A


Loading dataset content for hard3_2sym_rot_valid.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 15%|█▌        | 15/100 [00:00<00:00, 147.91it/s][A
 30%|███       | 30/100 [00:00<00:00, 83.99it/s] [A
 45%|████▌     | 45/100 [00:00<00:00, 104.94it/s][A
 58%|█████▊    | 58/100 [00:00<00:00, 97.07it/s] [A
 69%|██████▉   | 69/100 [00:00<00:00, 98.55it/s][A
 85%|████████▌ | 85/100 [00:00<00:00, 114.38it/s][A
100%|██████████| 100/100 [00:00<00:00, 106.02it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 6/6 [06:00<00:00, 60.09s/it]

{'train': {'accuracy': 0.5316666666666666, 'precision': 0.5288753799392097, 'recall': 0.58, 'roc_auc': 0.5315049382716049, 'confusion matrix': {'TN': 435, 'FP': 378, 'FN': 465, 'TP': 522}}, 'test': {'accuracy': 0.6, 'precision': 0.6, 'recall': 0.6, 'roc_auc': 0.6168, 'confusion matrix': {'TN': 30, 'FP': 20, 'FN': 20, 'TP': 30}}, 'valid': {'accuracy': 0.52, 'precision': 0.5185185185185185, 'recall': 0.56, 'roc_auc': 0.4696, 'confusion matrix': {'TN': 24, 'FP': 22, 'FN': 26, 'TP': 28}}}



