In [1]:
import os
import pdb

from comet_ml import Experiment
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.nn import CrossEntropyLoss
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF

from src.dataset import Chaos2DSegmentationDataset, NormalizeInstance, get_image_pair_filepaths
from src.models import UNet
from src.metrics import dice_loss, dice_score
from src.utils import create_canvas
from src.train import train_one_epoch, validate
import src.config

%load_ext autoreload
%autoreload 2 # 0: off, 2: on for all modules
# os.chdir('CompositionalNets/')
# sys.path.append('/project/6052161/mattlk/workplace/CompNet')

In [2]:
# Location of the CHAOS dataset, read from the project's config module.
# Change the 'chaos' entry of src.config.directories depending on where the
# CHAOS dataset is stored on this machine.
data_dir = src.config.directories['chaos']

In [3]:
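# NOTE: never commit API keys. Read the key from the environment instead; the
# key that was previously hardcoded here should be treated as leaked and rotated.
# experiment = Experiment(api_key=os.environ["COMET_API_KEY"],
#                         project_name="chaos-liver-segmentation",
#                         workspace="matthew42", auto_metric_logging=False)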

# Train U-Net on CHAOS for Liver Segmentation

In [6]:
%%time
# Hyperparameters and run options. The whole mapping is also forwarded as
# keyword arguments (**params) to train_one_epoch and validate.
params = dict(
    lr=0.0001,               # learning rate given to the Adam optimizer
    batch_size=8,            # mini-batch size for both DataLoaders
    split_train_val=0.8,     # train_size passed to train_test_split
    epochs=125,              # number of passes of the training loop
    use_dice_loss=False,     # True: dice_loss criterion; False: weighted cross-entropy
    cache=True,              # forwarded to Chaos2DSegmentationDataset(cache=...)
    random_seed=42,          # random_state passed to train_test_split
    shuffle_data=True,       # shuffle flag passed to train_test_split
    scheduler="StepLR",      # "StepLR" or "ReduceLROnPlateau"
    step_size=15,            # StepLR step_size
    gamma=0.75,              # StepLR gamma
    threshold=0.75,          # not read in this notebook; consumed downstream via **params
    weight_decay=4e-3,       # weight decay given to the Adam optimizer
)

# Use the first CUDA device when one is visible to torch; otherwise fall back to the CPU.
is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# dtype the model is cast to, and the integer dtype intended for target masks
# (targets_dtype is not read elsewhere in this notebook view).
input_images_dtype, targets_dtype = torch.double, torch.long

# Spatial size every image and mask is resized to.
IMAGE_SIZE = (256, 256)

# Transforms handed to the dataset as `cache_input_transform` / `cache_gt_transform`
# (presumably applied once, when samples are cached -- confirm in src.dataset).
# `Resize` was never imported in this notebook, so on a fresh kernel the bare name
# raised NameError; the torchvision transform is referenced explicitly instead.
# NOTE(review): Resize runs after ToTensor, i.e. on tensors, which needs a
# torchvision version whose Resize accepts tensors -- confirm the pinned version.
cache_input_transform = transforms.Compose([
    NormalizeInstance(mean=1.0),
    transforms.Lambda(lambda x: x.astype(np.float32)),
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SIZE),
])

# NOTE(review): Resize's default interpolation is not nearest-neighbour, so label
# values at mask boundaries may be blended before the integer cast in
# `gt_transform`; consider nearest interpolation for masks.
cache_gt_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SIZE),
])

# Transforms handed to the dataset as `input_transform`; currently the identity
# because augmentation is disabled.
input_transform = transforms.Compose([
#     transforms.RandomAffine(degrees=5, shear=5),
#     transforms.ToTensor()
])

# Transforms handed to the dataset as `gt_transform`: scale by 255, then cast to
# int64 (presumably undoing ToTensor's scaling so masks hold integer class
# indices -- confirm against the mask encoding).
gt_transform = transforms.Compose([
#     transforms.RandomAffine(degrees=5, shear=5),
#     transforms.ToTensor(),
    transforms.Lambda(lambda x: x*255),
    transforms.Lambda(lambda x: x.long()),
])

# Load data for training and validation.
# Cap on how many image pairs are used. 4 keeps this cell a quick smoke test;
# set to None to use every pair returned by get_image_pair_filepaths.
MAX_IMAGE_PAIRS = 4
image_pair_filepaths = get_image_pair_filepaths(data_dir)[:MAX_IMAGE_PAIRS]
train_filepaths, val_filepaths = train_test_split(image_pair_filepaths,
                                                  train_size=params['split_train_val'],
                                                  random_state=params['random_seed'],
                                                  shuffle=params["shuffle_data"])
# train_filepaths, val_filepaths = image_pair_filepaths, image_pair_filepaths

# Both splits are built with identical dataset options; only the file lists differ.
dataset_kwargs = dict(input_transform=input_transform,
                      gt_transform=gt_transform, cache=params['cache'],
                      cache_input_transform=cache_input_transform,
                      cache_gt_transform=cache_gt_transform,
                      device=device)
train_dataset = Chaos2DSegmentationDataset(train_filepaths, **dataset_kwargs)
val_dataset = Chaos2DSegmentationDataset(val_filepaths, **dataset_kwargs)

num_train, num_val = len(train_dataset), len(val_dataset)
# Recorded in params so it is logged/forwarded together with the hyperparameters.
params['num_samples'] = num_train + num_val

# NOTE(review): the training loader iterates in a fixed order (no shuffle=True);
# params['shuffle_data'] only affects the train/val split above -- confirm intended.
train_dataloader = DataLoader(train_dataset, batch_size=params['batch_size'])
val_dataloader = DataLoader(val_dataset, batch_size=params['batch_size'])

# Instantiate model, optimizer, and criterion
# Seed torch so that weight initialisation is repeatable across kernel restarts;
# previously params['random_seed'] only controlled the train/val split.
torch.manual_seed(params['random_seed'])
torch.cuda.empty_cache()
unet = UNet(dice=params['use_dice_loss'])
# Move and cast unconditionally: .to() works for the CPU device too. Previously the
# cast was skipped without CUDA, leaving a float32 model next to float64 loss weights.
unet = unet.to(device, dtype=input_images_dtype)

optimizer = optim.Adam(unet.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
if params['scheduler'] == 'StepLR':
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=params['step_size'], gamma=params['gamma'])
elif params['scheduler'] == 'ReduceLROnPlateau':
    # The training loop steps this scheduler with the validation F1 score, where
    # higher is better, so the plateau must be detected in 'max' mode
    # (the default 'min' would cut the learning rate whenever F1 improves).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
else:
    # Fail here rather than with a NameError on `scheduler` inside the training loop.
    raise ValueError(f"Unsupported scheduler: {params['scheduler']!r}")

# cross-entropy loss: weighting of negative vs positive pixels
# (class 0 gets 0.01, class 1 gets 0.99). Created directly on the target device
# and with the model's dtype so the two always match.
loss_weight = torch.tensor([0.01, 0.99], dtype=input_images_dtype, device=device)
criterion = dice_loss if params['use_dice_loss'] else CrossEntropyLoss(weight=loss_weight,
                                                                       reduction='mean')

# experiment.log_parameters(params)

CPU times: user 2.55 s, sys: 2.2 s, total: 4.75 s
Wall time: 5.44 s


In [8]:
%%time
# with experiment.train():
# Number of mini-batches forwarded to train_one_epoch as `num_accumulated_steps`
# (presumably the gradient-accumulation window, giving an effective batch of about
# EFFECTIVE_BATCH_SIZE samples -- confirm in src.train). Clamped to at least 1 so
# that a batch_size above EFFECTIVE_BATCH_SIZE cannot produce a window of 0.
EFFECTIVE_BATCH_SIZE = 128
num_accumulated_steps = max(1, EFFECTIVE_BATCH_SIZE // params['batch_size'])

print(f'Number of training images:\t{num_train}\nNumber of validation images:\t{num_val}')
for epoch in range(params['epochs']):

    # One pass over the training loader; returns the (possibly updated) model
    # and the loss value reported below.
    unet, running_loss = train_one_epoch(unet, train_dataloader, optimizer,
                                         criterion,
                                         num_accumulated_steps=num_accumulated_steps,
                                         **params)

    if params['use_dice_loss']:
        print(f'[Epoch {epoch+1:03d} Training]\tDice Loss:\t\t{running_loss:.4f}')
    else:
        print(f'[Epoch {epoch+1:03d} Training]\tCross-Entropy Loss:\t{running_loss:.4f}')
#     experiment.log_metric("Running Loss", running_loss, epoch=epoch, step=epoch, include_context=False)

    # Evaluate on the validation loader. To log to Comet, pass
    # experiment=experiment instead of experiment=None.
    f1_mean, jaccard_mean = validate(unet, val_dataloader, epoch, device,
                                     experiment=None, batch_freq=25,
                                     epoch_freq=25, **params)

    # ReduceLROnPlateau needs the monitored metric; other schedulers step unconditionally.
    if params['scheduler'] == 'ReduceLROnPlateau':
        scheduler.step(f1_mean)
    else:
        scheduler.step()
    print(f'[Epoch {epoch+1:03d} Validation]\tAverage F1 Score:\t{f1_mean:.4f}\tAverage Jaccard/IoU:\t{jaccard_mean:.4f}\n')

#     experiment.log_metric('Validation Average F1 Score', f1_mean,
#                           epoch=epoch, include_context=False)
#     experiment.log_metric('Validation Average Jaccard/IoU', jaccard_mean,
#                           epoch=epoch, include_context=False)

# torch.save(unet.state_dict(), 'unet.pth')
# experiment.log_asset('unet.pth', copy_to_tmp=False)
# experiment.end()

Number of training images:	3
Number of validation images:	1
[Epoch 001 Training]	Cross-Entropy Loss:	0.0427
[Epoch 001 Validation]	Average F1 Score:	1.0000	Average Jaccard/IoU:	1.0000

[Epoch 002 Training]	Cross-Entropy Loss:	0.0394
[Epoch 002 Validation]	Average F1 Score:	1.0000	Average Jaccard/IoU:	1.0000

[Epoch 003 Training]	Cross-Entropy Loss:	0.0365
[Epoch 003 Validation]	Average F1 Score:	1.0000	Average Jaccard/IoU:	1.0000

[Epoch 004 Training]	Cross-Entropy Loss:	0.0339
[Epoch 004 Validation]	Average F1 Score:	1.0000	Average Jaccard/IoU:	1.0000

[Epoch 005 Training]	Cross-Entropy Loss:	0.0316
[Epoch 005 Validation]	Average F1 Score:	1.0000	Average Jaccard/IoU:	1.0000

[Epoch 006 Training]	Cross-Entropy Loss:	0.0295
[Epoch 006 Validation]	Average F1 Score:	1.0000	Average Jaccard/IoU:	1.0000

[Epoch 007 Training]	Cross-Entropy Loss:	0.0277
[Epoch 007 Validation]	Average F1 Score:	1.0000	Average Jaccard/IoU:	1.0000

[Epoch 008 Training]	Cross-Entropy Loss:	0.0260
[Epoch 008 Valida

KeyboardInterrupt: 