In [21]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import ml_collections 
import deepchest
import os

In [22]:
import copy
import time

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

In [23]:
config = ml_collections.ConfigDict()

config.batch_size = 32
config.num_steps = 300

# See preprocessing.py, if you replace with ";" no preprocessing is done
config.preprocessing_train_eval = "independent_dropout(.2);"

config.use_validation_split = False

# If validation split is false, then train will have 4/5 of data and test 1/5
# If validation split is true, then train will have 3/5 of data, test 1/5 and val 1/5
config.num_folds = 5

# gpu workers
config.num_workers = 0

# dataset
config.images_directory = "dataset/images.ds1/"
config.labels_file = "dataset/labels.ds1/diagnostic.csv"

# Fold seed
config.random_state = 0

# Where the indices are saved
config.save_dir = "model_saved/"
config.export_folds_indices_file = "indices.csv"

# Don't modify these (should not have been in the config)
config.test_fold_index = 0
config.delta_from_test_index_to_validation_index = 1

In [None]:
model = Net()
checkpoint = torch.load(os.path.join(
            config.save_dir,
            f"best_model_smaller_resnet_sigmoid_epoch{0}_test_fold_index{config.test_fold_index}.ds1",
        ))

In [4]:
import glob
for ds in [1, 2]:
    if ds == 1:
        lb = [1, 2]
    else:
        lb = [2]
    for l in lb:
        for epoch in range(4):
            checkpoint_path = os.path.join(
                config.save_dir,
                f"best_model_smaller_resnet_sigmoid_epoch{epoch}_test_fold_index*{'.ds'+str(ds)+f'-{l}'}"
                # f"best_model_resnet18_sigmoid_epoch{epoch}_test_fold_index*{'.ds'+str(ds)+'-1'}"
                # {'.ds'+str(ds)+'-1' if ds == 1 else ''}",
            )
            list_checkpoints = glob.glob(checkpoint_path)
            accs = []
            for file in list_checkpoints:
                checkpoint = torch.load(file)
                accs.append(checkpoint['acc'])
            min_acc = min([(a.data, i) for i, a in enumerate(accs)])[0]
            min_idx = min([(a.data, i) for i, a in enumerate(accs)])[1]
            max_acc = max([(a.data, i) for i, a in enumerate(accs)])[0]
            max_idx = max([(a.data, i) for i, a in enumerate(accs)])[1]
            print(f'dataset {ds} label {l} epoch {epoch}: {torch.mean(torch.Tensor(accs))}\t(min {min_acc} index {min_idx}, max {max_acc} index {max_idx})')
        print('+++++++++++++')

dataset 1 label 1 epoch 0: 0.7395833134651184	(min 0.5625 index 2, max 0.84375 index 1)
dataset 1 label 1 epoch 1: 0.78125	(min 0.65625 index 0, max 0.84375 index 2)
dataset 1 label 1 epoch 2: 0.78125	(min 0.65625 index 0, max 0.90625 index 1)
dataset 1 label 1 epoch 3: 0.78125	(min 0.65625 index 2, max 0.90625 index 1)
+++++++++++++
dataset 1 label 2 epoch 0: 0.5	(min 0.46875 index 0, max 0.5625 index 3)
dataset 1 label 2 epoch 1: 0.5	(min 0.46875 index 0, max 0.5625 index 1)
dataset 1 label 2 epoch 2: 0.48750001192092896	(min 0.46875 index 2, max 0.53125 index 0)
dataset 1 label 2 epoch 3: 0.48124998807907104	(min 0.4375 index 3, max 0.53125 index 4)
+++++++++++++
dataset 2 label 2 epoch 0: 0.5625	(min 0.5 index 1, max 0.625 index 2)
dataset 2 label 2 epoch 1: 0.5520833134651184	(min 0.5 index 0, max 0.59375 index 1)
dataset 2 label 2 epoch 2: 0.5416666865348816	(min 0.5 index 1, max 0.5625 index 2)
dataset 2 label 2 epoch 3: 0.5416666865348816	(min 0.5 index 0, max 0.5625 index 2)
+

In [None]:
from model import model
checkpoint = torch.load(os.path.join(
            config.save_dir,
            f"best_model_smaller_resnet_sigmoid_epoch3_test_fold_index3.ds1-2",
        ))

net = model.get_smaller_resnet()
net.load_state_dict(checkpoint['model_state_dict'])
_ = net.eval()

In [24]:
from model import model
import glob
import re
device = torch.device("cpu")

results = []
for idx in range(config.num_folds):
    print('idx', idx)
    config.test_fold_index = int(idx)
    (
        train_loader,
        test_loader,
        _,
    ) = deepchest.dataset.get_data_loaders(config=config)
    
    checkpoint_path = os.path.join(
        config.save_dir,
        f"best_model_smaller_resnet_sigmoid_epoch*_test_fold_index{idx}*"
    )
    
    list_checkpoints = glob.glob(checkpoint_path)
    for file in list_checkpoints:
        checkpoint = torch.load(file)

        net = model.get_smaller_resnet()
        net.load_state_dict(checkpoint['model_state_dict'])
        _ = net.eval()

        label_names = deepchest.utils.get_label_names(config.labels_file)
        scores, labels = deepchest.utils.model_evaluation(net, test_loader, device)
        train_metrics = deepchest.utils.compute_metrics(labels, scores, label_names)
        results.append((file, train_metrics))
        print(file, '\n', train_metrics['roc_auc'],)
        print("===================")
    print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
    torch.save(results, 'results-ds1-1')
    del train_loader
    del test_loader

idx 0
Split infos:


model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index0.ds1-2 
 0.9686274509803923
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index0.ds2-2 
 0.9529411764705883
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index0.ds2-2 
 0.9372549019607843
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index0.ds2-2 
 0.9529411764705883
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index0.ds2-2 
 0.9450980392156864
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index0.ds1-1 
 0.9803921568627452
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index0.ds1-1 
 0.9568627450980393
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index0.ds1-2 
 0.9450980392156862
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index0.ds1-1 
 0.9490196078431373
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index0.ds1-1 
 0.9725490196078432
model_saved/best_model_smaller

model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index1.ds1-1 
 0.8745098039215686
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index1.ds1-2 
 0.8627450980392156
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index1.ds1-1 
 0.8941176470588237
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index1.ds1-2 
 0.8823529411764706
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index1.ds1-2 
 0.8509803921568627
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index1.ds1-2 
 0.8745098039215686
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index1.ds1-1 
 0.8627450980392157
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index1.ds1-1 
 0.8862745098039215
+++++++++++++++++++++++++++++++++++++++++++++++++++
idx 2
Split infos:


model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index2.ds1-2 
 0.6985294117647058
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index2.ds1-1 
 0.6654411764705883
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index2.ds1-2 
 0.6470588235294117
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index2.ds1-2 
 0.7316176470588235
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index2.ds1-2 
 0.6985294117647058
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index2.ds1-1 
 0.6911764705882353
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index2.ds1-1 
 0.7169117647058824
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index2.ds1-1 
 0.6470588235294117
+++++++++++++++++++++++++++++++++++++++++++++++++++
idx 3
Split infos:


model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index3.ds1-2 
 0.8639705882352942
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index3.ds2-2 
 0.8602941176470589
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index3.ds1-2 
 0.8455882352941176
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index3.ds2-2 
 0.8676470588235294
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index3.ds2-2 
 0.8455882352941176
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index3.ds1-2 
 0.8345588235294117
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index3.ds1-2 
 0.8345588235294118
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index3.ds2-2 
 0.8676470588235294
+++++++++++++++++++++++++++++++++++++++++++++++++++
idx 4
Split infos:


model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index4.ds1-2 
 0.8515625
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index4.ds1-2 
 0.85546875
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index4.ds2-2 
 0.8828125
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index4.ds2-2 
 0.91015625
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index4.ds1-2 
 0.8671875
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index4.ds2-2 
 0.8515625
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index4.ds2-2 
 0.87109375
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index4.ds1-2 
 0.859375
+++++++++++++++++++++++++++++++++++++++++++++++++++


In [4]:
from model import model
import glob
import re
device = torch.device("cpu")

results = torch.load('results')
for idx in range(2, config.num_folds):
    print('idx', idx)
    config.test_fold_index = int(idx)
    (
        train_loader,
        test_loader,
        _,
    ) = deepchest.dataset.get_data_loaders(config=config)
    
    checkpoint_path = os.path.join(
        config.save_dir,
        f"best_model_smaller_resnet_sigmoid_epoch*_test_fold_index{idx}*"
    )
    
    list_checkpoints = glob.glob(checkpoint_path)
    for file in list_checkpoints:
        checkpoint = torch.load(file)

        net = model.get_smaller_resnet()
        net.load_state_dict(checkpoint['model_state_dict'])
        _ = net.eval()

        label_names = deepchest.utils.get_label_names(config.labels_file)
        scores, labels = deepchest.utils.model_evaluation(net, test_loader, device)
        train_metrics = deepchest.utils.compute_metrics(labels, scores, label_names)
        results.append((file, train_metrics))
        print(file, '\n', train_metrics['roc_auc'],)
        print("===================")
    print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
    torch.save(results, 'results')
    del train_loader
    del test_loader

idx 2
Split infos:


model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index2.ds1-2 
 0.5294117647058822
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index2.ds1-1 
 0.5036764705882353
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index2.ds1-2 
 0.5404411764705883
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index2.ds1-2 
 0.5294117647058824
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index2.ds1-2 
 0.5367647058823529
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index2.ds1-1 
 0.5110294117647058
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index2.ds1-1 
 0.5110294117647058
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index2.ds1-1 
 0.5036764705882353
+++++++++++++++++++++++++++++++++++++++++++++++++++
idx 3
Split infos:


model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index3.ds1-2 
 0.8272058823529411
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index3.ds2-2 
 0.8235294117647058
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index3.ds1-2 
 0.8272058823529411
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index3.ds2-2 
 0.8419117647058824
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index3.ds2-2 
 0.8566176470588235
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index3.ds1-2 
 0.8235294117647058
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index3.ds1-2 
 0.8308823529411764
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index3.ds2-2 
 0.8308823529411764
+++++++++++++++++++++++++++++++++++++++++++++++++++
idx 4
Split infos:


model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index4.ds1-2 
 0.86328125
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index4.ds1-2 
 0.87109375
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index4.ds2-2 
 0.85546875
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index4.ds2-2 
 0.859375
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index4.ds1-2 
 0.8671875
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index4.ds2-2 
 0.8671875
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index4.ds2-2 
 0.87109375
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index4.ds1-2 
 0.8671875
+++++++++++++++++++++++++++++++++++++++++++++++++++


In [1]:
import torch 
results = torch.load('results')


In [7]:
for idx in range(5):
    print('idx', idx)
    for r in results:
        if f'index{idx}' in r[0] and 'ds2-2' in r[0]:
            print(r[0], r[1]['balanced_accuracy'])

idx 0
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index0.ds2-2 0.5
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index0.ds2-2 0.5
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index0.ds2-2 0.5
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index0.ds2-2 0.5
idx 1
idx 2
idx 3
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index3.ds2-2 0.49816176470588236
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index3.ds2-2 0.5
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index3.ds2-2 0.5
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index3.ds2-2 0.5
idx 4
model_saved/best_model_smaller_resnet_sigmoid_epoch2_test_fold_index4.ds2-2 0.5
model_saved/best_model_smaller_resnet_sigmoid_epoch3_test_fold_index4.ds2-2 0.5
model_saved/best_model_smaller_resnet_sigmoid_epoch1_test_fold_index4.ds2-2 0.53125
model_saved/best_model_smaller_resnet_sigmoid_epoch0_test_fold_index4.

In [6]:
train_metrics.keys()

dict_keys(['loss', 'ece', 'true_positive', 'true_negative', 'false_negative', 'false_positive', 'balanced_accuracy', 'false_positive_rate', 'true_positive_rate', 'roc_auc', 'roc', 'mean_predicted_value', 'fraction_of_positives', 'ece_curve', 'labels', 'logits'])