In [None]:
from exp.utils import *
from exp.models import *
from exp.losses import *
from tqdm.notebook import tqdm
from multiprocessing import Pool

import torch
import torch.nn as NN
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

In [None]:
# Seed once, before any data loading or model construction, so later random
# steps (sub-sampling, DataLoader shuffling, random flips, weight init) are
# repeatable. Assumes `seed_everything` (exp.utils) seeds python/numpy/torch
# RNGs — TODO confirm.
seed: int = 92
seed_everything(seed)

In [None]:
# ---- Experiment configuration ---------------------------------------------
# Target finding for this single-label (binary) classifier.
label = "Cardiomegaly"
# Run / checkpoint name derived from the target label.
model_name = f"chexnet_way_training_v1_{label}"
model_type = "densenet"

# Optimisation hyper-parameters: batch size, learning rate, max epochs.
bs, lr, epochs = 16, 1e-3, 50
# Spatial size passed to the datasets and to transforms.Resize.
image_size = (224, 224)

# Runtime context provided by the exp package.
device = get_device()
labels = get_labels()  # passed as `include_labels` when loading the dataframes

In [None]:
# Load train / validation / test splits restricted to `labels`.
# `small=True` with `small_fraction=0.125` keeps only part of the data for quick
# experiments — presumably a random subsample; see exp.utils.get_dataframes.
train_df, valid_df, test_df = get_dataframes(
    include_labels=labels,
    small=True,
    small_fraction=0.125,
)
tuple(split.shape for split in (train_df, valid_df, test_df))

In [None]:
# Reduce every split to a frame targeting only `label` (presumably a binary
# present/absent column — see exp.utils.get_binary_df). Splits are processed
# in train, valid, test order.
train_df, valid_df, test_df = [
    get_binary_df(label, split_df) for split_df in (train_df, valid_df, test_df)
]

In [None]:
# Class-balancing weights, computed from the training targets only and
# converted to float tensors for the loss.
# NOTE(review): the (neg, pos) unpacking order depends on what
# compute_class_freqs returns — confirm against exp.utils.
train_label = train_df[[label]].values
neg_weights, pos_weights = map(torch.Tensor, compute_class_freqs(train_label))
neg_weights, pos_weights

In [None]:
# Training-time augmentation pipeline.
# Geometric ops (resize, flip) run first and the per-channel ImageNet
# normalisation runs last. Normalisation is a per-channel affine map, so it
# commutes with resize interpolation and flipping: the output is numerically
# equivalent to normalising first, and the RNG draw order is unchanged. Doing it
# last means it only touches `image_size` pixels when the dataset hands over
# larger images, and it matches the conventional torchvision ordering.
# NOTE(review): there is no ToTensor here, so this assumes CRX8_Data already
# yields float tensors (Normalize rejects PIL images) — confirm.
train_tfs = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation / test pipeline comes from the shared helper; its train half is unused.
_, test_tfs = get_transforms(image_size=image_size)

In [None]:
def _build_split(df, tfs):
    """Wrap one split dataframe in a CRX8_Data dataset for the current `label`."""
    return CRX8_Data(df, get_image_path(), label, image_size=image_size, transforms=tfs)


# Augmented pipeline for training; helper-provided pipeline for evaluation.
train_ds = _build_split(train_df, train_tfs)
valid_ds = _build_split(valid_df, test_tfs)
test_ds = _build_split(test_df, test_tfs)
del _build_split  # keep the notebook namespace clean

In [None]:
# One loader per split, keyed the way `fit` receives them; only training shuffles.
dataloaders = {
    split: DataLoader(ds, batch_size=bs, shuffle=(split == "train"))
    for split, ds in (("train", train_ds), ("val", valid_ds), ("test", test_ds))
}
train_dl, valid_dl, test_dl = (dataloaders[key] for key in ("train", "val", "test"))

In [None]:
# Class-weighted loss on raw logits (exp.losses). The weight tensors are moved
# to `device`, presumably so they live alongside the model outputs.
criterion = get_weighted_loss_with_logits(
    *(weight.to(device) for weight in (pos_weights, neg_weights))
)

In [None]:
# Pretrained DenseNet-121 from exp.models, moved to `device` (Module.to returns self).
model = pretrained_densenet121().to(device)

In [None]:
# fine_tune_setup returns the (possibly modified) model together with the
# optimizer used to train it at learning rate `lr` — see exp.models/utils.
[model, fine_optimizer] = fine_tune_setup(model, lr)

In [None]:
# No learning-rate schedule for this run: EmptyScheduler (from the exp package)
# presumably acts as a no-op placeholder for fit's scheduler argument — confirm.
scheduler = EmptyScheduler()

In [None]:
# Train the model; `fit` returns the trained model and the training history.
model, history = fit(
    model,
    criterion,
    fine_optimizer,
    scheduler,
    dataloaders,
    model_name,
    epochs,
    lr,
    patience=3,      # presumably epochs without improvement before stopping
    metric="loss",   # presumably the quantity monitored for early stopping / checkpoints
)