In [1]:
import os
import pdb

from comet_ml import Experiment
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.nn import CrossEntropyLoss
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchvision
from torchvision import transforms
import torchvision.transforms.functional as TF

from src.dataset import Chaos2DSegmentationDataset, NormalizeInstance, get_image_pair_filepaths
from src.models import UNet
from src.metrics import dice_loss, dice_score
from src.utils import create_canvas
from src.train import train_one_epoch, validate
from src.config import directories

%load_ext autoreload
%autoreload 2 # 0: off, 2: on for all modules

In [2]:
# Root directory of the dataset used below. Edit the matching entry in
# src.config.directories to point at wherever the CHAOS data lives locally.
# NOTE(review): the lookup key is 'rsna' although this notebook talks about
# CHAOS -- confirm the key refers to the intended directory.
data_dir = directories["rsna"]

In [3]:
class Resize(object):
    """Resize a tensor to a fixed spatial size using ``F.interpolate``.

    The input is given a temporary leading dimension, interpolated to
    ``size`` and then has that leading dimension removed again, so the
    output has the same number of dimensions as the input.

    The resize logic lives in a regular method instead of lambdas stored on
    the instance, so instances can be pickled (e.g. for multi-process data
    loading).

    Args:
        size: Target spatial size forwarded to ``F.interpolate(size=...)``,
            e.g. ``(224, 224)``.
        mode: Interpolation mode forwarded to ``F.interpolate``. Defaults to
            ``"nearest"``, which is also what ``F.interpolate`` uses when no
            mode is given, so existing ``Resize(size)`` calls behave as before.
    """

    def __init__(self, size, mode="nearest"):
        self.size = size
        self.mode = mode

    def transform(self, image):
        """Return ``image`` resized to ``self.size``.

        Args:
            image: Tensor to resize. In this notebook it is applied right
                after ``transforms.ToTensor()``; assumed to be (C, H, W) --
                TODO confirm for other callers.

        Returns:
            The resized tensor, with the temporary leading dimension removed.
        """
        # F.interpolate expects a leading batch dimension, so add one for the
        # call and drop it afterwards.
        batched = image.unsqueeze(0)
        resized = F.interpolate(batched, size=self.size, mode=self.mode)
        return resized.squeeze(0)

    def __call__(self, image):
        """Make the object usable as a transform inside ``transforms.Compose``."""
        return self.transform(image)

    def __repr__(self):
        return "{}(size={!r}, mode={!r})".format(
            self.__class__.__name__, self.size, self.mode
        )

In [8]:
%%time
params = {
    "lr": 0.0001,
    "batch_size": 16,
    "split_train_val": 0.8,
    "epochs": 125,
    "use_dice_loss": False,
    "cache": True,
    "random_seed": 42,
    "shuffle_data": True,
    "scheduler": "StepLR",
    "step_size": 15,
    "gamma": 0.75,
    "threshold": 0.75
}

is_cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if is_cuda_available else "cpu")
input_images_dtype = torch.double
targets_dtype = torch.long

cache_input_transform = transforms.Compose([
    NormalizeInstance(mean=1.0),
    transforms.Lambda(lambda x: x.astype(np.float32)),
    transforms.ToTensor(),
    Resize((224, 224)),
])

cache_gt_transform = transforms.Compose([
    transforms.ToTensor(),
    Resize((224, 224)),
])

input_transform = transforms.Compose([
#     transforms.RandomAffine(degrees=5, shear=5),
#     transforms.ToTensor()
])

gt_transform = transforms.Compose([
#     transforms.RandomAffine(degrees=5, shear=5),
#     transforms.ToTensor(),
    transforms.Lambda(lambda x: x*255),
    transforms.Lambda(lambda x: x.long()),
])

# Load data for training and validation
image_pair_filepaths = get_image_pair_filepaths(data_dir)[:200]
train_filepaths, val_filepaths = train_test_split(image_pair_filepaths,
                                                  train_size=params['split_train_val'],
                                                  random_state=params['random_seed'],
                                                  shuffle=params["shuffle_data"])
# train_filepaths, val_filepaths = image_pair_filepaths, image_pair_filepaths

train_dataset = Chaos2DSegmentationDataset(train_filepaths, input_transform=input_transform,
                                           gt_transform=gt_transform, cache=params['cache'],
                                           cache_input_transform=cache_input_transform,
                                           cache_gt_transform=cache_gt_transform,
                                           device=device)

val_dataset = Chaos2DSegmentationDataset(val_filepaths, input_transform=input_transform,
                                         gt_transform=gt_transform, cache=params['cache'],
                                         cache_input_transform=cache_input_transform,
                                         cache_gt_transform=cache_gt_transform,
                                         device=device)

num_train, num_val = len(train_dataset), len(val_dataset)
params['num_samples'] = num_train + num_val

train_dataloader = DataLoader(train_dataset, batch_size=params['batch_size'])
val_dataloader = DataLoader(val_dataset, batch_size=params['batch_size'])

# Instantiate model, optimizer, and criterion
torch.cuda.empty_cache()
unet = UNet(dice=params['use_dice_loss'])
if is_cuda_available: unet = unet.to(device, dtype=input_images_dtype)

optimizer = optim.Adam(unet.parameters(), lr=params['lr'])
if params['scheduler'] == 'StepLR': 
    scheduler = optim.lr_scheduler.StepLR(optimizer, 
                                          step_size=params['step_size'], gamma=params['gamma'])
elif params['scheduler'] == 'ReduceLROnPlateau':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# cross-entropy loss: weighting of negative vs positive pixels
loss_weight = torch.DoubleTensor([0.01, 0.99])
if is_cuda_available: loss_weight = loss_weight.to(device)
criterion = dice_loss if params['use_dice_loss'] else CrossEntropyLoss(weight=loss_weight,
                                                                       reduction='mean')

experiment.log_parameters(params)

unet.unet_model.UNet