In [1]:
import sys

import os
import pdb
import warnings

from comet_ml import Experiment
import numpy as np
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import jaccard_score, f1_score

import torch
from torch.nn import CrossEntropyLoss
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms

from dataset import Chaos2DSegmentationDataset, NormalizeInstance, get_image_pair_filepaths
from models import UNet
from metrics import dice_loss, dice_score
from utils import create_canvas
from train import train_one_epoch, validate

%load_ext autoreload
%autoreload 2 # 0: off, 2: on for all modules
# os.chdir('CompositionalNets/')
# sys.path.append('/project/6052161/mattlk/workplace/CompNet')

In [2]:
# Change the below directory depending on where the CHAOS dataset is stored
data_dir = os.path.join('CompositionalNets', 'data', 'chaos')

In [3]:
experiment = Experiment(api_key="P5seMqEJjqZ8mDA7QYSuK3yUJ",
                        project_name="chaos-liver-segmentation",
                        workspace="matthew42", auto_metric_logging=False)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/matthew42/chaos-liver-segmentation/8f7699ea2ab6422b8523ed26d373eb8b



# Train U-Net on CHAOS for Liver Segmentation

In [4]:
%%time
params = {
    "lr": 0.0001,
    "batch_size": 16,
    "split_train_val": 0.8,
    "epochs": 125,
    "use_dice_loss": False,
    "cache": True,
    "random_seed": 42,
    "shuffle_data": True,
    "scheduler": "StepLR",
    "step_size": 15,
    "gamma": 0.75,
    "threshold": 0.8
}

is_cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if is_cuda_available else "cpu")
input_images_dtype = torch.double
targets_dtype = torch.long

cache_input_transform = transforms.Compose([
    NormalizeInstance(mean=255.0),
    transforms.Lambda(lambda x: x.astype(np.uint8)),
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
#     transforms.ToTensor()
])

cache_gt_transform = transforms.Compose([
    transforms.Lambda(lambda x: x.astype(np.uint8)),
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
#     transforms.ToTensor()
])

input_transform = transforms.Compose([
#     transforms.RandomAffine(degrees=5, shear=5),
    transforms.ToTensor()
])

gt_transform = transforms.Compose([
#     transforms.RandomAffine(degrees=5, shear=5),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x*255),
    transforms.Lambda(lambda x: x.long()),
])

# Load data for training and validation
image_pair_filepaths = get_image_pair_filepaths(data_dir)[:200]
train_filepaths, val_filepaths = train_test_split(image_pair_filepaths,
                                                  train_size=params['split_train_val'],
                                                  random_state=params['random_seed'],
                                                  shuffle=params["shuffle_data"])
# train_filepaths, val_filepaths = image_pair_filepaths, image_pair_filepaths

train_dataset = Chaos2DSegmentationDataset(train_filepaths, input_transform=input_transform,
                                           gt_transform=gt_transform, cache=params['cache'],
                                           cache_input_transform=cache_input_transform,
                                           cache_gt_transform=cache_gt_transform,
                                           device=device)

val_dataset = Chaos2DSegmentationDataset(val_filepaths, input_transform=input_transform,
                                         gt_transform=gt_transform, cache=params['cache'],
                                         cache_input_transform=cache_input_transform,
                                         cache_gt_transform=cache_gt_transform,
                                         device=device)

num_train, num_val = len(train_dataset), len(val_dataset)
params['num_samples'] = num_train + num_val

train_dataloader = DataLoader(train_dataset, batch_size=params['batch_size'])
val_dataloader = DataLoader(val_dataset, batch_size=params['batch_size'])

# Instantiate model, optimizer, and criterion
torch.cuda.empty_cache()
unet = UNet(dice=params['use_dice_loss'])
if is_cuda_available: unet = unet.to(device, dtype=input_images_dtype)

optimizer = optim.Adam(unet.parameters(), lr=params['lr'])
if params['scheduler'] == 'StepLR': 
    scheduler = optim.lr_scheduler.StepLR(optimizer, 
                                          step_size=params['step_size'], gamma=params['gamma'])
elif params['scheduler'] == 'ReduceLROnPlateau':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# cross-entropy loss: weighting of negative vs positive pixels
loss_weight = torch.DoubleTensor([0.01, 0.99])
if is_cuda_available: loss_weight = loss_weight.to(device)
criterion = dice_loss if params['use_dice_loss'] else CrossEntropyLoss(weight=loss_weight,
                                                                       reduction='mean')

experiment.log_parameters(params)
num_accumulated_steps = 128 // params['batch_size']

CPU times: user 4.58 s, sys: 1.46 s, total: 6.04 s
Wall time: 14.3 s


In [5]:
%%time
with experiment.train():

    print(f'Number of training images:\t{num_train}\nNumber of validation images:\t{num_val}')
    for epoch in range(params['epochs']):

        unet, running_loss = train_one_epoch(unet, train_dataloader, optimizer,
                                             criterion, 
                                             num_accumulated_steps=num_accumulated_steps, 
                                             **params)

        if params['use_dice_loss']:
            print(f'[Epoch {epoch:03d} Training]\tDice Loss:\t\t{running_loss:.4f}')
        else:
            print(f'[Epoch {epoch:03d} Training]\tCross-Entropy Loss:\t{running_loss:.4f}')
        experiment.log_metric("Running Loss", running_loss, epoch=epoch, step=epoch, include_context=False)
        
        f1_mean, jaccard_mean = validate(unet, val_dataloader, epoch, device,
                                         experiment=experiment, batch_freq=25,
                                         epoch_freq=25, **params)

        if params['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(f1_mean)
        else:
            scheduler.step()
        print(f'[Epoch {epoch:03d} Validation]\tAverage F1 Score:\t{f1_mean:.4f}\tAverage Jaccard/IoU:\t{jaccard_mean:.4f}\n')

        experiment.log_metric('Validation Average F1 Score', f1_mean,
                              epoch=epoch, include_context=False)
        experiment.log_metric('Validation Average Jaccard/IoU', jaccard_mean,
                              epoch=epoch, include_context=False)

torch.save(unet.state_dict(), 'unet.pth')
experiment.log_asset('unet.pth', copy_to_tmp=False)
experiment.end()

Number of training images:	160
Number of validation images:	40
[Epoch 001 Training]	Cross-Entropy Loss:	7.0775
[Epoch 001 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 002 Training]	Cross-Entropy Loss:	7.0222
[Epoch 002 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 003 Training]	Cross-Entropy Loss:	6.9631
[Epoch 003 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 004 Training]	Cross-Entropy Loss:	6.9029
[Epoch 004 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 005 Training]	Cross-Entropy Loss:	6.8381
[Epoch 005 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 006 Training]	Cross-Entropy Loss:	6.7644
[Epoch 006 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 007 Training]	Cross-Entropy Loss:	6.6751
[Epoch 007 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 008 Training]	Cross-Entropy Loss:	6.5650
[Epoch 008 Val

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/matthew42/chaos-liver-segmentation/8f7699ea2ab6422b8523ed26d373eb8b
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     train_Running Loss [125]                   : (1.1855929383677806, 7.077512405754706)
COMET INFO:     train_Validation Average F1 Score [125]    : (2.05681180420747e-10, 0.436721533536911)
COMET INFO:     train_Validation Average Jaccard/IoU [125] : (2.05681180420747e-10, 0.2183607667684555)
COMET INFO:   Parameters:
COMET INFO:     batch_size      : 16
COMET INFO:     cache           : True
COMET INFO:     epochs          : 125
COMET INFO:     gamma           : 0.75
COMET INFO:     lr              : 0.0001
COMET INFO:     num_samples     : 200
COMET INFO:     random_seed     : 42
COMET INFO:     scheduler       : StepLR
COM

[Epoch 125 Validation]	Average F1 Score:	0.4181	Average Jaccard/IoU:	0.2090



COMET INFO: Uploading stats to Comet before program termination (may take several seconds)


CPU times: user 1h 51min 33s, sys: 57min 54s, total: 2h 49min 27s
Wall time: 2h 53min
