In [1]:
import sys
# sys.path.append('/project/6052161/mattlk/workplace/CompNet')
import os
# os.chdir('CompositionalNets/')
import pdb

import numpy as np
from comet_ml import Experiment
from mlflow import log_metric, log_param, log_artifact
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import CrossEntropyLoss
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from dataset import Chaos2DSegmentationDataset, NormalizeInstance, get_image_pair_filepaths
from models import UNet
from loss import dice as dice_loss
%load_ext autoreload
%autoreload 2 # 0: off, 2: on for all modules

In [2]:
# Location of the CHAOS dataset. Defaults to CompositionalNets/data/chaos relative to the
# working directory; set the CHAOS_DATA_DIR environment variable to point elsewhere
# without editing the notebook.
data_dir = os.environ.get('CHAOS_DATA_DIR', os.path.join('CompositionalNets', 'data', 'chaos'))

In [3]:
# Comet credentials are read from the environment instead of being committed with the
# notebook: export COMET_API_KEY before starting the kernel. If it is unset, api_key=None
# is passed and comet_ml resolves the key from its own configuration.
# NOTE(review): the key that used to be hard-coded here is exposed in version history and
# should be revoked/rotated.
experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"),
                        project_name="chaos-liver-segmentation",
                        workspace="matthew42")

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/matthew42/chaos-liver-segmentation/2b4d2a1e91b14653aba01a0d375ee809



# Train U-Net on CHAOS for Liver Segmentation

In [None]:
%%time
# Run configuration. The whole dict is logged to Comet so runs can be compared later.
parameters = {
    "lr": 0.001,
    "batch_size": 2,
    "split_train_val": 0.8,
    "low_lr_epoch": 80,
    "epochs": 3,
    "use_dice_loss": False,
    "seed": 42,
}
experiment.log_parameters(parameters)

lr = parameters['lr']
batch_size = parameters['batch_size']
split_train_val = parameters['split_train_val']
# NOTE(review): low_lr_epoch is logged but no cell in this notebook reads it (no LR schedule).
low_lr_epoch = parameters['low_lr_epoch']
epochs = parameters['epochs']
use_dice_loss = parameters['use_dice_loss']
seed = parameters['seed']
# NOTE(review): num_samples is not referenced by any other cell in this notebook.
num_samples = 1000

# Seed torch's generators so weight initialisation and train-set shuffling are repeatable.
torch.manual_seed(seed)

is_cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if is_cuda_available else "cpu")
input_images_dtype = torch.double
gt_images_dtype = torch.long
cache_data = True

input_transform = transforms.Compose([
    NormalizeInstance(),
    transforms.ToTensor()
])

# Load data for training and validation. The split is positional: the first
# `split_train_val` fraction of the returned pairs trains, the remainder validates.
image_pair_filepaths = get_image_pair_filepaths(data_dir)
n_train_pairs = int(len(image_pair_filepaths) * split_train_val)
train_filepaths = image_pair_filepaths[:n_train_pairs]
val_filepaths = image_pair_filepaths[n_train_pairs:]

train_dataset = Chaos2DSegmentationDataset(train_filepaths, input_transform=input_transform, cache=cache_data, device=device)
val_dataset = Chaos2DSegmentationDataset(val_filepaths, input_transform=input_transform, cache=cache_data, device=device)
print(f'Number of training images: {len(train_dataset)}\nNumber of validation images: {len(val_dataset)}')

# Shuffle training batches every epoch; validation keeps a fixed order so its per-batch
# metrics are comparable across epochs.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

# Instantiate model, optimizer, and criterion
torch.cuda.empty_cache()
unet = UNet(dice=use_dice_loss)
# On CUDA the model is moved and cast to input_images_dtype; on CPU it keeps its default dtype.
if is_cuda_available:
    unet = unet.to(device, dtype=input_images_dtype)

optimizer = optim.Adam(unet.parameters(), lr=lr)

# Cross-entropy loss: weighting of negative (class 0) vs positive (class 1) pixels.
# The weight is created with the model's own dtype and device. A hard-coded DoubleTensor
# would clash with the float32 model used on CPU.
model_dtype = next(unet.parameters()).dtype
loss_weight = torch.tensor([0.01, 0.99], dtype=model_dtype, device=device)
criterion = dice_loss if use_dice_loss else CrossEntropyLoss(weight=loss_weight)

In [None]:
%%time
# Train the U-Net for `epochs` epochs. After every epoch, evaluate it on the validation
# split: pixel accuracy and Dice score of the thresholded foreground mask. Training
# metrics are logged under Comet's "train" context and validation metrics under its
# "validate" context.


def _batch_accuracy_and_dice(pred_masks, gt_masks, eps=1e-5):
    """Return ``(pixel_accuracy, dice_score)`` for one batch of binary masks.

    Args:
        pred_masks: numpy array of predicted 0/1 labels.
        gt_masks: numpy array of ground-truth 0/1 labels, same shape as ``pred_masks``.
        eps: smoothing term so an empty prediction on an empty ground truth scores 1.0
            instead of dividing by zero.

    Returns:
        The fraction of matching pixels, and Dice = (2|P∩G| + eps) / (|P| + |G| + eps).
    """
    accuracy = (pred_masks == gt_masks).sum() / float(pred_masks.size)
    intersect = (pred_masks + gt_masks == 2).sum()
    union = np.sum(pred_masks) + np.sum(gt_masks)
    dice = (2 * intersect + eps) / (union + eps)
    return accuracy, dice


num_train_batches = len(train_dataloader)
# Cast inputs to the dtype of the model's weights (input_images_dtype when the setup cell
# cast the model), so the network never receives a mismatched float32/float64 tensor.
model_dtype = next(unet.parameters()).dtype
loss_name = 'Dice' if use_dice_loss else 'Cross-Entropy'

for epoch in tqdm(range(epochs), desc=f'Training {epochs} epochs'):

    # ---- training pass ----
    unet.train()
    running_loss = 0.0
    with experiment.train():
        for i, (input_images, gt_images) in enumerate(train_dataloader):
            input_images = input_images.to(device, dtype=model_dtype)
            gt_images = gt_images.to(device, dtype=gt_images_dtype)

            outputs = unet(input_images)
            if use_dice_loss:
                # Dice loss is computed on the foreground channel only, kept 4-D as (N, 1, H, W).
                outputs = outputs[:, 1, :, :].unsqueeze(dim=1)
            else:
                # CrossEntropyLoss takes class-index targets without the channel dimension.
                gt_images = gt_images.squeeze(dim=1)
            loss = criterion(outputs, gt_images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Use a global step: the batch index `i` restarts at 0 every epoch, so on its own
            # it would make points from different epochs collide.
            experiment.log_metric('Loss', loss.item(), step=epoch * num_train_batches + i, epoch=epoch)

        # max(..., 1) guards against an empty training loader.
        mean_train_loss = running_loss / max(num_train_batches, 1)
        print(f'[Epoch {epoch+1:03d}] Training {loss_name} Loss: {mean_train_loss:.4f}')
        # Log the per-batch mean (the same value that is printed) rather than the raw sum.
        experiment.log_metric(f"{loss_name} Running Loss", mean_train_loss, epoch=epoch)

    # ---- validation pass ----
    unet.eval()
    all_accuracy = []
    all_dice = []
    all_outputs = []  # per-sample binary prediction masks (numpy), kept for inspection after the loop

    # no_grad: evaluation needs no autograd graph, which would only waste memory.
    with experiment.validate(), torch.no_grad():
        for input_images, gt_images in val_dataloader:
            input_images = input_images.to(device, dtype=model_dtype)
            gt_images = gt_images.to(device, dtype=gt_images_dtype)
            outputs = unet(input_images)

            # NOTE(review): assumes the cross-entropy UNet returns log-probabilities, so exp()
            # gives probabilities. Confirm against models.UNet.
            if not use_dice_loss:
                outputs = outputs.exp()

            # Round the foreground channel (assumed to hold probabilities in [0, 1]) to a 0/1
            # mask of shape (N, 1, H, W), then move both masks to numpy.
            pred_masks = outputs[:, 1, :, :].unsqueeze(dim=1).round().cpu().numpy()
            gt_masks = gt_images.cpu().numpy()

            accuracy, dice = _batch_accuracy_and_dice(pred_masks, gt_masks)
            all_accuracy.append(accuracy)
            all_dice.append(dice)
            # Iterating the numpy batch yields one (1, H, W) mask per sample. numpy arrays have
            # no .permute(), so the masks are stored as-is.
            all_outputs.extend(pred_masks)

        print(f'[Epoch {epoch+1:03d}] Validation Accuracy: {np.mean(all_accuracy)}. Validation Dice Score: {np.mean(all_dice)}')

        experiment.log_metrics({
            'Validation Accuracy': np.mean(all_accuracy),
            'Validation Dice Score': np.mean(all_dice)
        }, epoch=epoch)
experiment.close()