In [1]:
import os
import pdb
import random

from comet_ml import Experiment
import numpy as np
import matplotlib.pyplot as plt
import h5py
from tqdm import tqdm
from datetime import datetime

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
import kornia.augmentation as K

from src.dataset import Resize, Hdf5SegmentationDataset
from src.models import UNet
from src.metrics import dice_loss, dice_score
from src.utils import create_canvas
from src.train import train_one_epoch, validate
from src.config import directories

%load_ext autoreload
%autoreload 2 # 0: off, 2: on for all modules
# os.chdir('CompositionalNets/')
# sys.path.append('/project/6052161/mattlk/workplace/CompNet')

data_dir = directories['pulmonary_cxr_abnormalities']

In [2]:
%%time
# Run configuration. The dict is logged to Comet as-is and also forwarded as
# keyword arguments to the training/validation helpers, so key names and
# insertion order are kept stable.
params = {
    # optimisation
    "lr": 0.0001,
    "batch_size": 8,
    "split_train_val": 0.8,
    "epochs": 50,
    "use_dice_loss": False,
    # data handling / reproducibility
    "cache": True,
    "random_seed": 42,
    "shuffle_data": True,
    # learning-rate schedule
    "scheduler": "StepLR",
    "step_size": 15,
    "gamma": 0.75,
    # remaining options consumed by the model / helpers
    "threshold": 0.5,
    "pretrained": True,
    # "change_loss_weights_at": 50,
}

def _cache_steps():
    """Return fresh ToTensor + 256x256 Resize steps for the cache_* transforms."""
    return [
        transforms.ToTensor(),
        Resize((256, 256)),
    ]


def _augmentation_steps():
    """Return fresh kornia augmentations: +/-5 degree shear, horizontal flip, 224 centre crop."""
    return [
        K.RandomAffine(0, shear=(-5, 5)),
        K.RandomHorizontalFlip(),
        K.CenterCrop(224),
    ]


def _mask_label_steps():
    """Return the steps that turn a mask tensor into integer labels.

    Squeezes singleton dimensions, multiplies by 255 and casts to long.
    NOTE(review): the x255 presumably undoes ToTensor's 1/255 scaling of a 0/1
    uint8 mask -- confirm against the stored mask format.
    """
    return [
        transforms.Lambda(lambda x: x.squeeze()),
        transforms.Lambda(lambda x: x*255),
        transforms.Lambda(lambda x: x.long()),
    ]


# Transforms handed to the dataset's cache_* arguments (tensor conversion + resize).
cache_input_transform = transforms.Compose(_cache_steps())
cache_target_transform = transforms.Compose(_cache_steps())

# Training-time augmentation for images; the final squeeze drops singleton dimensions.
input_transform = transforms.Compose(
    _augmentation_steps() + [transforms.Lambda(lambda x: x.squeeze())]
)
# Training-time augmentation for masks, followed by conversion to integer labels.
target_transform = transforms.Compose(_augmentation_steps() + _mask_label_steps())
# Validation masks are only converted to integer labels (no augmentation).
val_target_transform = transforms.Compose(_mask_label_steps())

# Pick the first GPU when CUDA is usable, otherwise fall back to the CPU.
is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    device = torch.device("cuda:0")
    # Release cached GPU memory left over from earlier runs in this kernel.
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")

# Dtypes used when moving inputs (float64) and targets (int64) to the device.
input_dtype, target_dtype = torch.double, torch.long

# Training split: augmented on the fly; also exposes the target distribution/count
# datasets that are read further down when building the loss weights.
hf_fp = os.path.join(data_dir, 'train.hdf5')

train_dataset = Hdf5SegmentationDataset(hf_fp, 'shenzhen/healthy/cxr', 'shenzhen/healthy/masks',
                                        input_transform=input_transform, target_transform=target_transform,
                                        cache_input_transform=cache_input_transform,
                                        cache_target_transform=cache_target_transform,
                                        cache=params['cache'],
                                        distribution_name='shenzhen/healthy/target_distribution',
                                        target_count_name='shenzhen/healthy/target_count')

# Validation split: no augmentation, masks are only converted to integer labels.
hf_fp = os.path.join(data_dir, 'test.hdf5')
val_dataset = Hdf5SegmentationDataset(hf_fp, 'shenzhen/healthy/cxr', 'shenzhen/healthy/masks',
                                      target_transform=val_target_transform,
                                      cache_input_transform=cache_input_transform,
                                      cache_target_transform=cache_target_transform,
                                      cache=params['cache'])

# Seed every RNG in play so shuffling and random augmentations are repeatable.
random.seed(params['random_seed'])
np.random.seed(params['random_seed'])
torch.manual_seed(params['random_seed'])

# The logged `cache` / `shuffle_data` hyperparameters now actually drive the
# datasets and the training loader instead of being shadowed by literals.
train_dataloader = DataLoader(train_dataset, batch_size=params['batch_size'],
                              pin_memory=is_cuda_available, shuffle=params['shuffle_data'],
                              num_workers=4)
# NOTE(review): the validation loader is shuffled as before; confirm this is
# wanted (e.g. to vary which batches get visualised) rather than an oversight.
val_dataloader = DataLoader(val_dataset, batch_size=params['batch_size'],
                            pin_memory=is_cuda_available, shuffle=True,
                            num_workers=4)

# Ratio of the second to the first entry of the dataset's target counts.
# NOTE(review): `bias` is not used anywhere in this cell -- confirm it is still needed.
bias = train_dataset.target_count[1]/train_dataset.target_count[0]
unet = UNet(dice=params['use_dice_loss'], pretrained=params['pretrained'])

optimizer = torch.optim.Adam(unet.parameters(), lr=params['lr'])

# Build the learning-rate scheduler selected in `params`.
if params['scheduler'] == 'StepLR':
    # Multiply the learning rate by `gamma` every `step_size` epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=params['step_size'],
                                                gamma=params['gamma'])
elif params['scheduler'] == 'ReduceLROnPlateau':
    # The training loop steps this scheduler with the validation F1 score, which
    # should be maximised; the default mode='min' would cut the learning rate
    # whenever the score improves.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
else:
    # Fail now rather than with a NameError after the first epoch has been trained.
    raise ValueError(
        f"Unsupported scheduler {params['scheduler']!r}; "
        "expected 'StepLR' or 'ReduceLROnPlateau'"
    )

# cross-entropy loss: weighting of negative vs positive pixels
# Each class is weighted by (1 - its entry in the dataset's target distribution).
loss_weight = torch.DoubleTensor(1 - train_dataset.target_distribution)
# loss_weight = torch.DoubleTensor([0.01, 0.99])
# Record the weights with the other hyperparameters; done before any device move
# because .numpy() requires a CPU tensor.
params['loss_weights'] = loss_weight.numpy()
# weight_change = torch.from_numpy(train_dataset.target_distribution - loss_weight.numpy()) / (params["change_loss_weights_at"] - params['epochs'])

# Always move/cast the weights and the model, not only when CUDA is present:
# inputs are cast to `input_dtype` (float64), so a model left in float32 on a
# CPU-only machine would hit a dtype mismatch.
loss_weight = loss_weight.to(device)
# weight_change = weight_change.to(device)
unet = unet.to(device, dtype=input_dtype)

# Dice loss when requested, otherwise class-weighted cross-entropy averaged over pixels.
if params['use_dice_loss']:
    criterion = dice_loss
else:
    criterion = CrossEntropyLoss(weight=loss_weight, reduction='mean')

CPU times: user 16.4 s, sys: 5.07 s, total: 21.5 s
Wall time: 26.2 s


In [3]:
# Credentials are deliberately not hard-coded: with no `api_key` argument the Comet
# SDK resolves the key from the COMET_API_KEY environment variable or its config file.
# SECURITY: any key that was previously committed in this notebook should be revoked.
experiment = Experiment(project_name="lung-segmentation", workspace="matthew42")

num_train, num_val = len(train_dataset), len(val_dataset)
params['num_samples'] = num_train + num_val
# Log the textual form of the augmentation pipelines with the hyperparameters.
params['target_transform'] = str(target_transform)
params['input_transform'] = str(input_transform)

experiment.log_parameters(params)
# Accumulate gradients over enough mini-batches to cover 128 samples per optimizer step.
num_accumulated_steps = 128 // params['batch_size']

try:
    with experiment.train():

        print(f'Number of training images:\t{num_train}\nNumber of validation images:\t{num_val}')

        for epoch in range(params['epochs']):

            unet, running_loss = train_one_epoch(unet, train_dataloader, optimizer,
                                                 criterion, num_accumulated_steps=num_accumulated_steps,
                                                 device=device, input_dtype=input_dtype, target_dtype=target_dtype,
                                                 **params)

            if params['use_dice_loss']:
                print(f'[Epoch {epoch+1:03d} Training]\tDice Loss:\t\t{running_loss:.4f}')
            else:
                print(f'[Epoch {epoch+1:03d} Training]\tCross-Entropy Loss:\t{running_loss:.4f}')
            # Log the training loss for both loss types (it used to be logged only
            # on the cross-entropy branch).
            experiment.log_metric("Running Loss", running_loss, epoch=epoch, step=epoch, include_context=False)

            f1_mean, jaccard_mean = validate(unet, val_dataloader, epoch, device,
                                             batch_freq=3, epoch_freq=10,
                                             experiment=experiment, **params)

            # ReduceLROnPlateau is driven by the validation F1 score; other schedulers step per epoch.
            if params['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(f1_mean)
            else:
                scheduler.step()
            print(f'[Epoch {epoch+1:03d} Validation]\tAverage F1 Score:\t{f1_mean:.4f}\tAverage Jaccard/IoU:\t{jaccard_mean:.4f}\n')

            experiment.log_metric('Validation Average F1 Score', f1_mean,
                                  epoch=epoch, include_context=False)
            experiment.log_metric('Validation Average Jaccard/IoU', jaccard_mean,
                                  epoch=epoch, include_context=False)

    #         if epoch >= params["change_loss_weights_at"]:
    #             criterion.weight = criterion.weight + weight_change

    # Save the final weights under a timestamped name and attach them to the experiment.
    date_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    filepath = os.path.join(directories['checkpoints'], f'unet_lung_{date_time}.pth')
    # torch.save(unet.state_dict(), filepath)
    torch.save({
    #     'epoch': epoch,
        'model_state_dict': unet.state_dict(),
    #     'optimizer_state_dict': optimizer.state_dict(),
    #     'scheduler_state_dict': scheduler.state_dict(),
        }, filepath)
    experiment.log_asset(filepath, copy_to_tmp=False)
finally:
    # Close the Comet experiment even if training fails or is interrupted.
    experiment.end()

COMET INFO: old comet version (3.1.14) detected. current: 3.1.15 please update your comet lib with command: `pip install --no-cache-dir --upgrade comet_ml`
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/matthew42/lung-segmentation/5e4b5944b5d643799a84e5ff22af56c5



Number of training images:	223
Number of validation images:	56
[Epoch 001 Training]	Cross-Entropy Loss:	1.8398
[Epoch 001 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 002 Training]	Cross-Entropy Loss:	1.8328
[Epoch 002 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 003 Training]	Cross-Entropy Loss:	1.8223
[Epoch 003 Validation]	Average F1 Score:	0.0000	Average Jaccard/IoU:	0.0000

[Epoch 004 Training]	Cross-Entropy Loss:	1.8053
[Epoch 004 Validation]	Average F1 Score:	0.0115	Average Jaccard/IoU:	0.0058

[Epoch 005 Training]	Cross-Entropy Loss:	1.7891
[Epoch 005 Validation]	Average F1 Score:	0.4659	Average Jaccard/IoU:	0.2329

[Epoch 006 Training]	Cross-Entropy Loss:	1.7479
[Epoch 006 Validation]	Average F1 Score:	0.3597	Average Jaccard/IoU:	0.1798

[Epoch 007 Training]	Cross-Entropy Loss:	1.6055
[Epoch 007 Validation]	Average F1 Score:	0.4341	Average Jaccard/IoU:	0.2170

[Epoch 008 Training]	Cross-Entropy Loss:	1.5024
[Epoch 008 Val

COMET INFO: Uploading stats to Comet before program termination (may take several seconds)
COMET INFO: Waiting for completion of the file uploads (may take several seconds)
COMET INFO: Still uploading
COMET INFO: Still uploading
COMET INFO: Still uploading
COMET INFO: Still uploading


In [None]:
# Sanity-check pass over the training set with one optimizer step per mini-batch
# (no gradient accumulation). Relies on `unet`, `optimizer`, `criterion`,
# `train_dataloader`, `device` and the dtypes defined in the setup cell.
unet.train()
torch.set_grad_enabled(True)
optimizer.zero_grad()
running_loss = 0.0  # sum of the per-batch losses, kept as a Python float
use_dice_loss = False

# Wrap the dataloader itself (not the enumerate object) so tqdm can read its
# length and show a total.
for i, data in enumerate(tqdm(train_dataloader)):

    input_images, targets = data

    input_images = input_images.to(device, input_dtype)
    targets = targets.to(device, target_dtype)
    outputs = unet(input_images)

    if use_dice_loss:
        # Keep only the channel-1 log-probabilities, with a singleton channel axis.
        outputs = F.log_softmax(outputs, dim=1)
        outputs = outputs[:, 1, :, :].unsqueeze(dim=1)
        loss = criterion(outputs, targets)
    else:
        # Drop a singleton dimension at index 1 of the targets, if present.
        targets = targets.squeeze(dim=1)
        loss = criterion(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # .item() converts to a float so no tensor is kept alive on the device.
    running_loss += loss.item()

In [None]:
# Sanity-check pass over the training set with gradient accumulation: gradients
# from `num_accumulated_steps` consecutive mini-batches are summed before each
# optimizer step (about 128 samples per step). Relies on objects defined in the
# setup cell.
num_accumulated_steps = 128 // params['batch_size']

unet.train()
torch.set_grad_enabled(True)
optimizer.zero_grad()
running_loss = 0.0  # sum of unscaled per-batch losses; turned into a mean below
use_dice_loss = False
num_batches = 0

for i, data in enumerate(train_dataloader):

    input_images, targets = data

    input_images = input_images.to(device, input_dtype)
    targets = targets.to(device, target_dtype)
    outputs = unet(input_images)

    if use_dice_loss:
        # Keep only the channel-1 log-probabilities, with a singleton channel axis.
        outputs = F.log_softmax(outputs, dim=1)
        outputs = outputs[:, 1, :, :].unsqueeze(dim=1)
        loss = criterion(outputs, targets)
    else:
        # Drop a singleton dimension at index 1 of the targets, if present.
        targets = targets.squeeze(dim=1)
        loss = criterion(outputs, targets)

    running_loss += loss.item()
    num_batches += 1

    # Scale so the accumulated gradient is the mean over the window.
    (loss / num_accumulated_steps).backward()

    # Step once a full window has been accumulated. Using (i + 1) avoids the
    # premature step after the very first batch that `i % n == 0` caused.
    if (i + 1) % num_accumulated_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Flush the gradients of a trailing, incomplete window.
# NOTE: those gradients are still scaled by 1/num_accumulated_steps, so the last
# step is proportionally smaller than a full one.
if num_batches % num_accumulated_steps != 0:
    optimizer.step()
    optimizer.zero_grad()

# Report the mean per-batch loss for the pass.
running_loss /= max(num_batches, 1)
print(running_loss)

In [None]:
# Visual check that an image and its mask receive the same random augmentation
# when the RNGs are re-seeded identically before each call.
image, target = train_dataset[0]
cc = transforms.Compose([
    K.RandomAffine(0, shear=(-5, 5)),
    K.RandomHorizontalFlip(),
    K.CenterCrop(230),
])
seed = np.random.randint(2147483647)


def _apply_with_seed(transform, tensor, seed_value):
    """Seed Python's and torch's RNGs with `seed_value`, then apply `transform` to `tensor`."""
    random.seed(seed_value)
    torch.manual_seed(seed_value)
    return transform(tensor)


image = _apply_with_seed(cc, image, seed)
target = _apply_with_seed(cc, target, seed)

# Left panel: augmented image, moved to channels-last for imshow.
plt.subplot(1, 2, 1)
image_pixels = image.cpu().squeeze().permute(1, 2, 0).numpy()
plt.imshow(image_pixels)

# Right panel: augmented mask.
plt.subplot(1, 2, 2)
target_pixels = target.cpu().squeeze().numpy()
plt.imshow(target_pixels)