In [1]:
# Project helpers first: the explicit imports below must take precedence over anything the star imports re-export.
from exp.utils import *
from exp.models import *
from exp.losses import *

import random
from multiprocessing import Pool

import numpy as np
import torch
import torch.nn as NN
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

In [2]:
# Run identifier; `model_name` is handed to `fit`, which uses it when saving the best checkpoint.
# NOTE(review): the name says MNASNET1, but the model cell builds `pretrained_mnasnet0_5()` — confirm
# which width is intended before comparing checkpoints by name.
architecture = "MNASNET1_preclinic_v1"
model_name = architecture

In [3]:
# Training configuration, consumed by the cells below.
image_size = (224, 224)  # input resolution handed to get_transforms and CRX8_Data
bs = 128                 # batch size for the train/val/test DataLoaders
lr = 1e-4                # initial lr of the Adam base optimizer wrapped by SAM (also passed to fit)
epochs = 50              # upper bound; fit can stop earlier once the lr has decayed to ~0
patience = 1             # passed to fit; the log shows lr /10 after each epoch without a new best val loss
rho = 0.05               # SAM `rho` argument (size of the sharpness-aware perturbation)
with_reset = False       # passed through to fit; semantics defined in exp.*
device = get_device()

# The run is stochastic (DataLoader shuffling, random augmentation, init of any non-pretrained layers).
# Seed every RNG it may draw from so that Restart & Run All reproduces the results.
# Full GPU determinism would additionally require cuDNN determinism settings.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

Using the GPU!


In [4]:
# Target column for this binary task (PATHOLOGIC comes from the exp.* star imports).
label = PATHOLOGIC
# Every label column to load; passed to get_dataframes as include_labels.
labels = get_labels()

In [5]:
# Load the three data splits, keeping the label columns selected above.
train_df, valid_df, test_df = get_dataframes(
    include_labels=labels,
    small=False,  # presumably the full dataset rather than a reduced one — confirm in exp.utils
)

In [6]:
# Derive the preclinic variant of each split (get_preclinic_df is defined in exp.*).
# NOTE: the names are reassigned in place, so re-running this cell applies get_preclinic_df a second time.
train_df, valid_df, test_df = (
    get_preclinic_df(split_df) for split_df in (train_df, valid_df, test_df)
)

In [7]:
# Class weights for the weighted loss, computed from the training labels only.
# `[[label]]` keeps a one-column DataFrame, so train_label is a 2-D (n_samples, 1) array.
train_label = train_df[[label]].values
# NOTE(review): the unpack order assumes compute_class_freqs returns (neg, pos) in that order;
# the printout shows neg=0.5862, pos=0.4138 — confirm against exp.utils.
neg_weights, pos_weights = (torch.Tensor(freq) for freq in compute_class_freqs(train_label))
print(neg_weights, pos_weights)

tensor([0.5862]) tensor([0.4138])


In [8]:
# One call yields both pipelines: train_tfs for the training set, test_tfs shared by validation and test.
(train_tfs, test_tfs) = get_transforms(image_size=image_size)

In [9]:
# Wrap each split in a CRX8_Data dataset; only the training split gets the training transforms.
train_ds, valid_ds, test_ds = (
    CRX8_Data(frame, get_image_path(), label, image_size=image_size, transforms=tfs)
    for frame, tfs in (
        (train_df, train_tfs),
        (valid_df, test_tfs),
        (test_df, test_tfs),
    )
)

In [10]:
# One DataLoader per split, keyed the way fit expects ("train", "val", "test").
# Only the training loader shuffles; evaluation order stays fixed.
dataloaders = {
    split: DataLoader(split_ds, batch_size=bs, shuffle=(split == "train"))
    for split, split_ds in (("train", train_ds), ("val", valid_ds), ("test", test_ds))
}
train_dl = dataloaders["train"]
valid_dl = dataloaders["val"]
test_dl = dataloaders["test"]

In [11]:
# Build the pretrained MNASNet (width 0.5) and move it to the selected device (nn.Module.to returns the module).
# NOTE(review): `architecture` is "MNASNET1_preclinic_v1" — the checkpoint name does not match this width.
model = pretrained_mnasnet0_5().to(device)

In [12]:
# Weighted loss on raw logits; note the argument order is (pos_weights, neg_weights).
pos_w_dev = pos_weights.to(device)
neg_w_dev = neg_weights.to(device)
criterion = get_weighted_loss_with_logits(pos_w_dev, neg_w_dev)

# SAM wraps an Adam base optimizer built from `lr`; `rho` is SAM's perturbation size.
sam_optimizer = SAM(model.parameters(), torch.optim.Adam, lr=lr, rho=rho)

In [13]:
# Train with SAM, tracking validation loss for checkpointing and lr decay (fit is defined in exp.*).
# NOTE(review): in the log below, the train loss keeps falling at roughly its earlier rate after the lr is
# reported as <=1e-7 (epochs 23-29). That is unexpected if the decay actually reached the Adam optimizer
# that SAM steps — verify how fit lowers the lr on a SAM optimizer.
model, history = fit(
    model,
    criterion,
    sam_optimizer,
    dataloaders,
    model_name,
    epochs,
    lr,
    sam=True,
    with_reset=with_reset,
    metric="loss",
    patience=patience,
)

Epoch 1:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.322, Acc: 0.622, AUROC: 0.657


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.404, Acc: 0.654, AUROC: 0.695
Saved model with loss 0.4043
Epoch 2:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.301, Acc: 0.664, AUROC: 0.717


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.368, Acc: 0.671, AUROC: 0.719
Saved model with loss 0.3679
Epoch 3:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.296, Acc: 0.674, AUROC: 0.729


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.334, Acc: 0.683, AUROC: 0.728
Saved model with loss 0.3342
Epoch 4:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.293, Acc: 0.679, AUROC: 0.736


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.316, Acc: 0.683, AUROC: 0.731
Saved model with loss 0.3160
Epoch 5:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.292, Acc: 0.683, AUROC: 0.740


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.308, Acc: 0.685, AUROC: 0.733
Saved model with loss 0.3077
Epoch 6:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.290, Acc: 0.685, AUROC: 0.743


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.301, Acc: 0.690, AUROC: 0.733
Saved model with loss 0.3010
Epoch 7:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.289, Acc: 0.687, AUROC: 0.747


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.299, Acc: 0.689, AUROC: 0.734
Saved model with loss 0.2988
Epoch 8:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.288, Acc: 0.688, AUROC: 0.748


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.300, Acc: 0.691, AUROC: 0.735
Lowered lr to 1e-05
Epoch 9:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.287, Acc: 0.690, AUROC: 0.751


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.298, Acc: 0.692, AUROC: 0.736
Saved model with loss 0.2983
Epoch 10:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.286, Acc: 0.691, AUROC: 0.754


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.297, Acc: 0.695, AUROC: 0.738
Saved model with loss 0.2966
Epoch 11:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.285, Acc: 0.693, AUROC: 0.756


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.297, Acc: 0.697, AUROC: 0.738
Saved model with loss 0.2966
Epoch 12:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.284, Acc: 0.696, AUROC: 0.758


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.296, Acc: 0.683, AUROC: 0.740
Saved model with loss 0.2956
Epoch 13:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.283, Acc: 0.695, AUROC: 0.760


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.294, Acc: 0.690, AUROC: 0.743
Saved model with loss 0.2942
Epoch 14:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.282, Acc: 0.695, AUROC: 0.761


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.294, Acc: 0.697, AUROC: 0.745
Saved model with loss 0.2939
Epoch 15:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.282, Acc: 0.697, AUROC: 0.763


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.294, Acc: 0.693, AUROC: 0.746
Saved model with loss 0.2935
Epoch 16:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.281, Acc: 0.699, AUROC: 0.765


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.292, Acc: 0.692, AUROC: 0.748
Saved model with loss 0.2918
Epoch 17:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.279, Acc: 0.699, AUROC: 0.767


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.292, Acc: 0.700, AUROC: 0.750
Lowered lr to 1.0000000000000002e-06
Epoch 18:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.279, Acc: 0.701, AUROC: 0.769


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.291, Acc: 0.704, AUROC: 0.751
Saved model with loss 0.2910
Epoch 19:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.278, Acc: 0.702, AUROC: 0.771


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.291, Acc: 0.703, AUROC: 0.751
Saved model with loss 0.2908
Epoch 20:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.277, Acc: 0.704, AUROC: 0.772


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.290, Acc: 0.706, AUROC: 0.751
Saved model with loss 0.2900
Epoch 21:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.276, Acc: 0.705, AUROC: 0.775


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.290, Acc: 0.701, AUROC: 0.751
Saved model with loss 0.2898
Epoch 22:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.275, Acc: 0.705, AUROC: 0.776


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.291, Acc: 0.702, AUROC: 0.751
Lowered lr to 1.0000000000000002e-07
Epoch 23:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.274, Acc: 0.707, AUROC: 0.779


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.290, Acc: 0.703, AUROC: 0.751
Lowered lr to 1.0000000000000002e-08
Epoch 24:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.274, Acc: 0.707, AUROC: 0.780


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.290, Acc: 0.702, AUROC: 0.751
Lowered lr to 1.0000000000000003e-09
Epoch 25:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.272, Acc: 0.710, AUROC: 0.782


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.290, Acc: 0.707, AUROC: 0.751
Lowered lr to 1.0000000000000003e-10
Epoch 26:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.272, Acc: 0.709, AUROC: 0.783


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.291, Acc: 0.705, AUROC: 0.752
Lowered lr to 1.0000000000000003e-11
Epoch 27:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.271, Acc: 0.712, AUROC: 0.785


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.291, Acc: 0.707, AUROC: 0.751
Lowered lr to 1.0000000000000002e-12
Epoch 28:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.270, Acc: 0.712, AUROC: 0.786


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.291, Acc: 0.706, AUROC: 0.751
Lowered lr to 1.0000000000000002e-13
Epoch 29:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=541.0), HTML(value='')))


Train: Loss: 0.269, Acc: 0.713, AUROC: 0.788


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Val: Loss: 0.292, Acc: 0.703, AUROC: 0.751
Lowered lr to 1.0000000000000002e-14
Learning rate is basically zero. Stopping training.


In [14]:
# Completion marker ("fertig" = German for "done"); FERTIG comes from the exp.* star imports and prints "FERTIG! :D".
FERTIG()

FERTIG! :D
