In [11]:
import os
import pdb

from comet_ml import Experiment
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import h5py

import torch
from torch.nn import CrossEntropyLoss
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
import kornia.augmentation as K

from src.dataset import ClassificationDataset, NormalizeInstance, get_image_pair_filepaths
from src.models import UNet
from src.metrics import dice_loss, dice_score
from src.utils import create_canvas
from src.train import train_one_epoch, validate
import src.config

%load_ext autoreload
%autoreload 2 # 0: off, 2: on for all modules
# os.chdir('CompositionalNets/')
# sys.path.append('/project/6052161/mattlk/workplace/CompNet')

# Change the below directory depending on where the CHAOS dataset is stored
data_dir = src.config.directories['chaos']

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Comet.ml experiment tracking is currently disabled. Never hardcode the API
# key in the notebook: read it from the environment instead. The key that was
# previously written here should be treated as leaked and rotated.
# experiment = Experiment(api_key=os.environ["COMET_API_KEY"],
#                         project_name="chaos-liver-segmentation",
#                         workspace="matthew42", auto_metric_logging=False)

# Train U-Net on CHAOS for Liver Segmentation

In [93]:
%%time
params = {
    "lr": 0.0001,
    "batch_size": 16,
    "split_train_val": 0.8,
    "epochs": 45,
    "use_dice_loss": False,
    "cache": True,
    "random_seed": 42,
    "shuffle_data": True,
    "scheduler": "StepLR",
    "step_size": 15,
    "gamma": 0.75,
    "threshold": 0.5,
    "pretrained": True,
}

# Debug-only knobs, kept out of `params` because `params` is forwarded as
# **kwargs to the training/validation helpers.
# NUM_IMAGES limits how many slices are read from the HDF5 file (None reads
# all of them).
# OVERFIT_SANITY_CHECK makes the validation set identical to the training
# set. That is only meaningful as a "can the model overfit a tiny set?" check.
NUM_IMAGES = 8
OVERFIT_SANITY_CHECK = True

is_cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if is_cuda_available else "cpu")
input_images_dtype = torch.double
targets_dtype = torch.long
if is_cuda_available:
    torch.cuda.empty_cache()

# Seed torch so any random weight initialisation is reproducible across runs.
torch.manual_seed(params['random_seed'])

# Label masks must be resampled with nearest-neighbour interpolation: Resize's
# default (bilinear) blends neighbouring labels along object boundaries.
# Newer torchvision exposes InterpolationMode; older releases take the PIL
# integer constant instead, where 0 == PIL.Image.NEAREST.
mask_interpolation = (transforms.InterpolationMode.NEAREST
                      if hasattr(transforms, 'InterpolationMode') else 0)

# Applied once when the data is cached: 3-channel grayscale, resize, centre crop.
cache_input_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(3),
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
])

cache_target_transform = transforms.Compose([
    transforms.Lambda(lambda x: x.astype(np.uint8)),
    transforms.ToPILImage(),
    transforms.Resize((256, 256), interpolation=mask_interpolation),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
])

# On-the-fly augmentation (e.g. kornia RandomAffine / RandomHorizontalFlip) is
# currently disabled for the training inputs and targets.
input_transform = None
target_transform = None

# Validation targets: ToTensor() scaled the uint8 mask down by 255, so scale it
# back up and cast to integer class labels.
val_target_transform = transforms.Compose([
    transforms.Lambda(lambda x: x * 255),
    transforms.Lambda(lambda x: x.long()),
])
# NOTE(review): training targets get no such rescaling (target_transform is
# None), so the two splits see different label encodings unless
# train_one_epoch rescales internally. Confirm this.

hdf5_path = os.path.join(data_dir, 'train.hdf5')
with h5py.File(hdf5_path, 'r') as hf:
    images, targets = hf['images'][:NUM_IMAGES], hf['masks'][:NUM_IMAGES]

images = [cache_input_transform(im) for im in images]
targets = [cache_target_transform(t) for t in targets]

X_train, X_test, y_train, y_test = train_test_split(
    images, targets,
    train_size=params['split_train_val'],
    random_state=params['random_seed'],
    shuffle=params["shuffle_data"])
if OVERFIT_SANITY_CHECK:
    X_test, y_test = X_train, y_train

train_dataset = ClassificationDataset(X_train, y_train,
                                      input_transform, target_transform)
val_dataset = ClassificationDataset(X_test, y_test,
                                    target_transform=val_target_transform)

num_train, num_val = len(train_dataset), len(val_dataset)
params['num_samples'] = num_train + num_val
params['target_transform'] = str(target_transform)
params['input_transform'] = str(input_transform)

train_dataloader = DataLoader(train_dataset, batch_size=params['batch_size'])
val_dataloader = DataLoader(val_dataset, batch_size=params['batch_size'])

# Instantiate model, optimizer, and criterion. The model is always cast to the
# input dtype. Inputs are cast to double on every device (including CPU), so a
# float32 model would fail with a dtype mismatch.
unet = UNet(dice=params['use_dice_loss'], pretrained=params['pretrained'])
unet = unet.to(device, dtype=input_images_dtype)

optimizer = optim.Adam(unet.parameters(), lr=params['lr'])
if params['scheduler'] == 'StepLR':
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=params['step_size'],
                                          gamma=params['gamma'])
elif params['scheduler'] == 'ReduceLROnPlateau':
    # Stepped with the validation F1 score, which should increase, hence 'max'.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
else:
    raise ValueError(f"Unknown scheduler: {params['scheduler']!r}")

# Cross-entropy loss: weighting of negative vs positive pixels.
loss_weight = torch.tensor([0.01, 0.99], dtype=input_images_dtype, device=device)
criterion = (dice_loss if params['use_dice_loss']
             else CrossEntropyLoss(weight=loss_weight, reduction='mean'))

# Show one training pair as a sanity check of the preprocessing.
image, target = train_dataset[0]
image = image.clone().permute(1, 2, 0).numpy()
target = target.clone().numpy()
img = create_canvas(image, target, show=False,
                   title1='Example Input', title2='Example Target')
img

KeyboardInterrupt: 

In [94]:
def validate(net, dataloader, epoch, device=None, input_dtype=torch.double,
    target_dtype=torch.long, use_dice_loss=False, experiment=None,
    batch_freq=50, epoch_freq=25, threshold=0.5, **kwargs):
    """Gather validation metrics (Dice, Jaccard) for a segmentation network.

    NOTE(review): this redefines, and therefore shadows, the ``validate``
    imported from ``src.train`` at the top of the notebook.

    Args:
        net: model to evaluate. It is switched to eval mode and left there.
        dataloader: iterable yielding ``(input_images, targets)`` batches.
        epoch: unused; kept for call-site compatibility.
        device: when truthy, inputs and targets are moved to it and cast to
            ``input_dtype`` / ``target_dtype``.
        input_dtype: dtype for the input images when ``device`` is set.
        target_dtype: dtype for the targets when ``device`` is set.
        use_dice_loss: if True, the network outputs are log-softmaxed over
            dim 1 and scored as-is. Otherwise the class-1 softmax probability
            is binarised with ``threshold``.
        experiment, batch_freq, epoch_freq, **kwargs: unused; accepted so the
            function can be called with ``**params``.
        threshold: a pixel counts as foreground when its class-1 probability
            is strictly greater than this value.

    Returns:
        Tuple ``(dice_mean, jaccard_mean)`` of Python floats: the mean of the
        per-batch scores (0.0 for an empty dataloader).
    """
    from src.metrics import jaccard_score, dice_score

    net.eval()
    dice_mean = torch.zeros((1), device=device)
    jaccard_mean = torch.zeros((1), device=device)

    # no_grad() restores the caller's grad mode on exit. A bare
    # torch.set_grad_enabled(False) would leave gradients disabled for the
    # training epoch that follows.
    with torch.no_grad():
        for i, (input_images, targets) in enumerate(dataloader):
            if device:
                input_images = input_images.to(device, input_dtype)
                targets = targets.to(device, target_dtype)

            outputs = net(input_images)

            if use_dice_loss:
                outputs = F.log_softmax(outputs, dim=1)
            else:
                # Class-1 probability, keeping a singleton channel dimension.
                foreground = F.softmax(outputs, dim=1)[:, 1, :, :].unsqueeze(dim=1)
                # Compare directly so every threshold is honoured. Zeroing
                # values below the threshold and then rounding would silently
                # ignore thresholds below 0.5.
                outputs = (foreground > threshold).to(foreground.dtype)

            # Incremental (running) mean over batches.
            score = dice_score(outputs, targets)
            dice_mean = dice_mean + (score - dice_mean) / (i + 1)
            score = jaccard_score(outputs, targets)
            jaccard_mean = jaccard_mean + (score - jaccard_mean) / (i + 1)

    return dice_mean.item(), jaccard_mean.item()

In [94]:
%%time
# with experiment.train():
# Epoch loop: train, report the loss, validate, then advance the LR schedule.
num_accumulated_steps = 1
use_dice = params['use_dice_loss']
steps_on_metric = params['scheduler'] == 'ReduceLROnPlateau'

print(f'Number of training images:\t{num_train}\nNumber of validation images:\t{num_val}')
for epoch in range(params['epochs']):
    unet, running_loss = train_one_epoch(
        unet,
        train_dataloader,
        optimizer,
        criterion,
        device=device,
        num_accumulated_steps=num_accumulated_steps,
        **params
    )
    if not use_dice:
        print(f'[Epoch {epoch+1:03d} Training]\tCross-Entropy Loss:\t{running_loss:.4f}')
    else:
        print(f'[Epoch {epoch+1:03d} Training]\tDice Loss:\t\t{running_loss:.4f}')

    # Comet logging is disabled, hence experiment=None.
    f1_mean, jaccard_mean = validate(
        unet,
        val_dataloader,
        epoch,
        device,
        experiment=None,
        batch_freq=25,
        epoch_freq=25,
        **params
    )

    # ReduceLROnPlateau monitors the validation F1; other schedulers step once per epoch.
    if steps_on_metric:
        scheduler.step(f1_mean)
    else:
        scheduler.step()
    print(f'[Epoch {epoch+1:03d} Validation]\tAverage F1 Score:\t{f1_mean:.4f}\tAverage Jaccard/IoU:\t{jaccard_mean:.4f}\n')

# To keep the trained weights: torch.save(unet.state_dict(), 'unet.pth')

KeyboardInterrupt: 