In [1]:
from exp.utils import *
from exp.models import *
from exp.losses import *
from tqdm.notebook import tqdm
from multiprocessing import Pool

import torch
import torch.nn as NN
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [2]:
# Every label this experiment is concerned with.
picked_labels = ["Atelectasis", "Cardiomegaly", "Pneumonia"]

# Labels that the training loop loads via load_model() instead of starting from
# pretrained_densenet121(), paired position-by-position with the learning rate
# used for them.
# NOTE(review): zip() silently truncates if the two lists ever differ in length.
already_pretrained_labels = ["Cardiomegaly", "Pneumonia"]
already_pretrained_lrs = [1e-10, 1e-10]
already_pretrained = dict(zip(already_pretrained_labels, already_pretrained_lrs))

# Labels excluded from training; whatever remains of picked_labels (in its
# original order) is what the training loop iterates over.
already_trained = ["Atelectasis"]
train_labels = [
    candidate
    for candidate in picked_labels
    if candidate not in already_trained
]

In [3]:
# Train one binary classifier per entry of train_labels, in alphabetical order.
for label in sorted(train_labels):
    print(f"Training model to classify '{label}'")

    # Seed: re-applied at the start of every label's run.
    seed = 92
    seed_everything(seed)

    # Initial setup
    model_name = f"DenseNet121_v1_{label}"  # alternative name: f"sam_densenet_v1_{label}"
    model_type = "densenet"
    bs = 16
    lr = 1e-3  # default; replaced below when the label has an entry in already_pretrained
    epochs = 50
    image_size = (224, 224)
    device = get_device()
    labels = get_labels()

    # Load data
    train_df, valid_df, test_df = get_dataframes(include_labels=labels,
                                                 small=False)
    print(train_df.shape, valid_df.shape, test_df.shape)
    train_df = get_binary_df(label, train_df)
    valid_df = get_binary_df(label, valid_df)
    test_df = get_binary_df(label, test_df)

    # Compute label weights from the training split only.
    # NOTE(review): the names follow the unpacking order of compute_class_freqs;
    # in the recorded run the first value is the small one (e.g. 0.0198) and the
    # second the large one (0.9802) -- confirm this matches the helper's return order.
    train_label = train_df[[label]].values
    neg_weights, pos_weights = compute_class_freqs(train_label)
    neg_weights, pos_weights = torch.Tensor(neg_weights), torch.Tensor(pos_weights)
    print(neg_weights, pos_weights)

    # Get transforms (train_tfs for training; test_tfs for validation and test)
    train_tfs, test_tfs = get_transforms(image_size=image_size)

    # Create datasets
    train_ds = CRX8_Data(train_df, get_image_path(), label, image_size=image_size, transforms=train_tfs)
    valid_ds = CRX8_Data(valid_df, get_image_path(), label, image_size=image_size, transforms=test_tfs)
    test_ds = CRX8_Data(test_df, get_image_path(), label, image_size=image_size, transforms=test_tfs)

    # Create dataloaders; only the training loader shuffles.
    # NOTE(review): the "test" loader is handed to fit() but this cell never
    # evaluates on it directly -- confirm fit() uses it.
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
    valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False)
    test_dl = DataLoader(test_ds, batch_size=bs, shuffle=False)
    dataloaders = {
        "train": train_dl,
        "val": valid_dl,
        "test": test_dl
    }

    # Pick the starting model: labels registered in already_pretrained are loaded
    # by name via load_model() and use their registered learning rate; all other
    # labels start from pretrained_densenet121() (presumably ImageNet weights, per
    # the original comment -- confirm in exp.models).
    # Membership is tested on the dict itself; no temporary key list is needed.
    if label in already_pretrained:
        model = load_model(model_name)
        # NOTE(review): the registered rates are 1e-10 and the recorded output
        # shows essentially unchanged metrics across epochs -- confirm that such
        # a small rate is intended.
        lr = already_pretrained[label]
    else:
        model = pretrained_densenet121()
    model = model.to(device)

    # Get criterion and optimizer (weights are moved to the model's device)
    criterion = get_weighted_loss_with_logits(pos_weights.to(device),
                                              neg_weights.to(device))
    sam_optimizer = SAM(model.parameters(), torch.optim.Adam, lr=lr)

    # Train model
    model, history = fit(model, criterion, sam_optimizer,
                         dataloaders, model_name, epochs,
                         lr, sam=True, metric="loss", patience=1)

# Called once after all labels have been processed.
FERTIG()

Training model to classify 'Cardiomegaly'
Using the GPU!




(69219, 24) (17305, 24) (25596, 24)
tensor([0.0198]) tensor([0.9802])
Epoch 1:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.017, Acc: 0.697, AUROC: 0.875


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.017, Acc: 0.683, AUROC: 0.871
Saved model with loss 0.0170
Epoch 2:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.017, Acc: 0.695, AUROC: 0.876


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.017, Acc: 0.656, AUROC: 0.872
Saved model with loss 0.0169
Epoch 3:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.017, Acc: 0.695, AUROC: 0.876


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.017, Acc: 0.693, AUROC: 0.872
Lowered lr to 1.0000000000000001e-11
Epoch 4:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.017, Acc: 0.696, AUROC: 0.875


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.017, Acc: 0.649, AUROC: 0.871
Lowered lr to 1.0000000000000002e-12
Epoch 5:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.017, Acc: 0.696, AUROC: 0.875


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.017, Acc: 0.687, AUROC: 0.871
Lowered lr to 1.0000000000000002e-13
Epoch 6:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.017, Acc: 0.696, AUROC: 0.876


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.017, Acc: 0.701, AUROC: 0.871
Lowered lr to 1.0000000000000002e-14
Learning rate is basically zero. Stopping training.
Training model to classify 'Pneumonia'
Using the GPU!
(69219, 24) (17305, 24) (25596, 24)
tensor([0.0099]) tensor([0.9901])
Epoch 1:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.014, Acc: 0.612, AUROC: 0.588


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.023, Acc: 0.753, AUROC: 0.576
Saved model with loss 0.0234
Epoch 2:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.014, Acc: 0.612, AUROC: 0.591


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.019, Acc: 0.772, AUROC: 0.582
Saved model with loss 0.0192
Epoch 3:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.014, Acc: 0.613, AUROC: 0.593


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.018, Acc: 0.784, AUROC: 0.585
Saved model with loss 0.0176
Epoch 4:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.014, Acc: 0.613, AUROC: 0.586


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.021, Acc: 0.761, AUROC: 0.573
Lowered lr to 1.0000000000000001e-11
Epoch 5:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.014, Acc: 0.613, AUROC: 0.588


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.022, Acc: 0.779, AUROC: 0.576
Lowered lr to 1.0000000000000002e-12
Epoch 6:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.014, Acc: 0.613, AUROC: 0.582


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.020, Acc: 0.781, AUROC: 0.578
Lowered lr to 1.0000000000000002e-13
Epoch 7:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4327.0), HTML(value='')))


Train: Loss: 0.014, Acc: 0.614, AUROC: 0.591


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1082.0), HTML(value='')))


Val: Loss: 0.021, Acc: 0.753, AUROC: 0.574
Lowered lr to 1.0000000000000002e-14
Learning rate is basically zero. Stopping training.
FERTIG! :D


In [None]:
# Deliberate stop: this always raises AssertionError, so executing the notebook
# top to bottom halts here instead of continuing into the cells below.
# NOTE(review): presumably the following cells are a single-label, small-data
# scratch version of the loop above -- confirm they are still needed.
assert False

In [None]:
# Configuration for the single-label, cell-by-cell run.
label = "Cardiomegaly"
model_name = f"sam_densenet_v1_{label}"
model_type = "densenet"

# Hyperparameters: batch size, initial learning rate, epoch budget.
bs, lr, epochs = 16, 1e-3, 50
image_size = (224, 224)

# Project helpers: the compute device and the label list handed to get_dataframes().
device = get_device()
labels = get_labels()

In [None]:
# Fetch the three splits with small=True / small_fraction=0.05
# (presumably a 5% subsample for quick experiments -- confirm in get_dataframes).
train_df, valid_df, test_df = get_dataframes(
    include_labels=labels,
    small=True,
    small_fraction=0.05,
)
# Final expression of the cell: the three shapes become the cell's output.
(train_df.shape, valid_df.shape, test_df.shape)

In [None]:
# Pass every split through get_binary_df() for the chosen label, in the order
# train -> valid -> test, and rebind the same three names to the results.
train_df, valid_df, test_df = [
    get_binary_df(label, split_df)
    for split_df in (train_df, valid_df, test_df)
]

In [None]:
# Label column of the training split, kept two-dimensional by the [[label]] selection.
train_label = train_df[[label]].values

# Unpack the two results of compute_class_freqs() and wrap each in a torch.Tensor.
# NOTE(review): which of the two values is the positive-class frequency depends
# on the helper's return order -- confirm.
freq_first, freq_second = compute_class_freqs(train_label)
neg_weights = torch.Tensor(freq_first)
pos_weights = torch.Tensor(freq_second)

# Final expression of the cell: display both weight tensors.
(neg_weights, pos_weights)

In [None]:
# Two transform pipelines for the configured image size: the first is used for
# the training dataset, the second for the validation and test datasets.
(train_tfs, test_tfs) = get_transforms(
    image_size=image_size,
)

In [None]:
# Build one CRX8_Data dataset per split. Only the training split gets train_tfs;
# validation and test both use test_tfs. get_image_path() is called once per
# dataset, as in the sequential form.
train_ds, valid_ds, test_ds = [
    CRX8_Data(split_df, get_image_path(), label, image_size=image_size, transforms=split_tfs)
    for split_df, split_tfs in (
        (train_df, train_tfs),
        (valid_df, test_tfs),
        (test_df, test_tfs),
    )
]

In [None]:
# Loaders keyed by phase name; only the training loader shuffles its samples.
dataloaders = {
    "train": DataLoader(train_ds, batch_size=bs, shuffle=True),
    "val": DataLoader(valid_ds, batch_size=bs, shuffle=False),
    "test": DataLoader(test_ds, batch_size=bs, shuffle=False),
}

# Keep the individual loader names available as well.
train_dl = dataloaders["train"]
valid_dl = dataloaders["val"]
test_dl = dataloaders["test"]

In [None]:
# Instantiate the DenseNet-121 via the project helper and move it to the
# selected device in one step.
model = pretrained_densenet121().to(device)

In [None]:
# Weighted loss: the weight tensors are moved to the model's device first and
# passed in the order (pos_weights, neg_weights).
criterion = get_weighted_loss_with_logits(
    pos_weights.to(device),
    neg_weights.to(device),
)

# SAM wrapper around Adam as the base optimizer class, using the configured lr.
sam_optimizer = SAM(
    model.parameters(),
    torch.optim.Adam,
    lr=lr,
)

In [None]:
# Run training; fit() hands back the model together with a history object.
# No patience argument is given here, so fit()'s own default applies.
model, history = fit(
    model,
    criterion,
    sam_optimizer,
    dataloaders,
    model_name,
    epochs,
    lr,
    sam=True,
    metric="loss",
)