In [1]:
# !pip install -q git+https://github.com/matjesg/deepflash2.git

In [2]:
# !pip install zarr

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import warnings
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

In [5]:
# imports
import numpy as np, pandas as pd, segmentation_models_pytorch as smp
import albumentations as alb
import torch
import torch.nn as nn

import cv2, json

import torch.utils.data as D

from tqdm.notebook import tqdm

In [6]:
from config.global_vars import *
from datasets.hubdataset import HubDataset

In [7]:
# Model
def create_model():
    """Build the U-Net segmentation network, move it to CUDA and set train mode.

    The encoder name, encoder weights and input channel count come from the
    names ENCODER_NAME, ENCODER_WEIGHTS and CHANNELS defined outside this
    function.  The network has two output classes and no final activation,
    so it emits raw logits.  When more than one CUDA device is visible the
    network is wrapped in ``nn.DataParallel``.

    Returns:
        The (possibly DataParallel-wrapped) model, on CUDA, in train mode.
    """
    unet_kwargs = dict(
        encoder_name=ENCODER_NAME,
        encoder_weights=ENCODER_WEIGHTS,
        activation=None,
        in_channels=CHANNELS,
        classes=2,
    )
    net = smp.Unet(**unet_kwargs)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Let's use", gpu_count, "GPUs!")
        net = nn.DataParallel(net)

    # .cuda() moves the module in place; .train() returns the module itself
    net.cuda()
    return net.train()

model = create_model()

### Configuration

In [8]:
# Dice loss in multilabel mode, fed raw logits (from_logits=True).
dice_loss = smp.losses.DiceLoss(mode='multilabel', from_logits=True)
# NOTE(review): the next line renames the DiceLoss *class* itself, not only this
# instance; presumably a readable __name__ is wanted for logging — confirm the
# class-level assignment is really needed in addition to the instance-level one.
smp.losses.DiceLoss.__name__ = 'Dice Loss'
dice_loss.__name__ = 'Dice Loss'

# Jaccard (IoU) loss with the same settings; this is the loss the training loop uses.
jaccard_loss = smp.losses.JaccardLoss(mode='multilabel', from_logits=True)
smp.losses.JaccardLoss.__name__ = 'Jaccard Loss'
jaccard_loss.__name__ = 'Jaccard Loss'

# Cross-entropy term used by dice_ce_loss below.
cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Weight of the Dice term in dice_ce_loss; the cross-entropy term gets 1 - LOSS_FACTOR.
LOSS_FACTOR = 0.2
def dice_ce_loss(y_pred, y):
    """Blend Dice loss and cross-entropy loss, weighted by LOSS_FACTOR.

    Args:
        y_pred: model output passed unchanged to both loss terms.
        y: target mask; it is cast to long before use.

    Returns:
        ``dice * LOSS_FACTOR + cross_entropy * (1 - LOSS_FACTOR)``.

    NOTE(review): the cross-entropy target is the sum of ``y`` over dim 1.
    That equals a class-index map only if ``y`` has a single channel (or at
    most one non-background channel set per pixel) — confirm against the
    mask layout produced by the dataset.
    """
    mask = y.long()
    class_target = mask.sum(1)
    dice_term = dice_loss(y_pred, mask) * LOSS_FACTOR
    ce_term = cross_entropy_loss(y_pred, class_target) * (1 - LOSS_FACTOR)
    return dice_term + ce_term
#     return dice_loss(y_pred.sigmoid(), y)

In [9]:
class CONFIG():
    """Experiment settings collected as class attributes.

    The class body is evaluated once, top to bottom, when this cell runs:
    ``optimizer`` is built from the global ``model`` and from the
    ``max_learning_rate`` / ``weight_decay`` attributes defined just above it,
    so the order of those lines matters.
    """
    
    # data paths (absolute paths on the training machine)
    data_path = Path('/home/jupyter/data_2/')
    data_path_zarr = Path('/home/jupyter/train_scale2')
    mask_preproc_dir = '/home/jupyter/masks_scale2'
    
    # deepflash2 dataset
    # scale = 1.5 # data is already downscaled by 2, so absolute downscale is 3
    scale = 1 # data is already downscaled by 2; with scale = 1 the absolute downscale presumably stays 2
    tile_shape = (TILE_SHAPE, TILE_SHAPE)
    padding = (0,0) # Border overlap for prediction
    n_jobs = NUM_WORKERS  # passed as num_workers to the DataLoaders
    sample_mult = 300 # presumably tiles sampled per image, per epoch (an earlier note said 100) - TODO confirm
    val_length = 500 # Randomly sample 500 validation tiles
    # pair of 3-element arrays; presumably per-channel (mean, std) - TODO confirm
    stats = np.array([0.61561477, 0.5179343 , 0.64067212]), np.array([0.2915353 , 0.31549066, 0.28647661])
    
    # deepflash2 augmentation options
    zoom_sigma = 0.1
    flip = True
    max_rotation = 360
    deformation_grid_size = (150,150)
    deformation_magnitude = (10,10)

    # pytorch model (segmentation_models_pytorch)
    encoder_name = ENCODER_NAME
    encoder_weights = ENCODER_WEIGHTS
    in_channels = 3  # NOTE(review): create_model() uses CHANNELS instead - confirm they agree
    classes = 2
    
    # fastai Learner (the fastai cells are commented out; the manual loop reads these)
    mixed_precision_training = True
    batch_size = 5
    weight_decay = 0.01
    loss_func = dice_loss  # NOTE(review): the manual training loop calls jaccard_loss directly, not this
#     metrics = [Iou(), Dice_f1()]
    max_learning_rate = 1e-3
    epochs = 12
    # `model` here still resolves to the global model: the class attribute is only bound on the next line
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_learning_rate, weight_decay=weight_decay)
    model = model
    arch = 'unet'
    
    patience = 8  # only referenced by the commented-out early-stopping code
    
cfg = CONFIG()

In [10]:
# Albumentations augmentations
# Inspired by https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter
# deepflash2 augmentations are only affine transformations
# NOTE(review): `tfms` is not referenced by any other cell visible in this notebook;
# the dataset below is built with `transform` instead.
# With probability 0.3 exactly one of the listed entries is applied (OneOf picks a
# single child; the children's own p values act as relative weights).
tfms = alb.OneOf([
    alb.HueSaturationValue(10,15,10),
    alb.CLAHE(clip_limit=2),
    alb.RandomBrightnessContrast(),
    # nested choice between three blur variants
    alb.OneOf([
        alb.MotionBlur(p=0.2),
        alb.MedianBlur(blur_limit=3, p=0.1),
        alb.Blur(blur_limit=3, p=0.1),
    ], p=0.2)
], p=0.3)

### Datasets

In [11]:
# Build the training dataset: data root, cached slice file, augmentation
# pipelines and tiling parameters are all handed to HubDataset.
root_dir = cfg.data_path
slices_path = SLICES_PATH
# Training-time augmentation: resize to the tile size, then geometric,
# distortion, blur and colour perturbations.
transform = alb.Compose([
        alb.Resize(TILE_SHAPE, TILE_SHAPE, p=1.0),
        alb.HorizontalFlip(),
        alb.VerticalFlip(),
        alb.RandomRotate90(),
        alb.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=20, p=0.9, 
                         border_mode=cv2.BORDER_REFLECT),
        # one distortion, applied 30 % of the time
        alb.OneOf([
            alb.OpticalDistortion(p=0.4),
            alb.GridDistortion(p=.1, border_mode=cv2.BORDER_REFLECT),
            # NOTE(review): IAA* transforms were deprecated in later albumentations
            # releases - confirm against the installed version
            alb.IAAPiecewiseAffine(p=0.4),
        ], p=0.3),
        # one blur, applied 30 % of the time
        alb.OneOf([
            alb.MotionBlur(p=0.2),
            alb.MedianBlur(blur_limit=3, p=0.1),
            alb.Blur(blur_limit=3, p=0.1),
        ], p=0.3),
        # one colour/contrast change, applied 50 % of the time
        alb.OneOf([
            alb.HueSaturationValue(10,15,10),
            alb.CLAHE(clip_limit=3),
            alb.RandomBrightnessContrast(),
            alb.RandomGamma()
        ], p=0.5)
    ], p=1.0)

# Pipeline handed to HubDataset as `valid_transform`.
# NOTE(review): it still contains random flips and 90-degree rotations, so
# validation tiles are not identical from one epoch to the next.
valid_transform = alb.Compose([
        alb.Resize(TILE_SHAPE, TILE_SHAPE, p=1.0),
        alb.HorizontalFlip(),
        alb.VerticalFlip(),
        alb.RandomRotate90()
    ], p=1.0)
window = WINDOW
overlap = OVERLAP
threshold = THRESHOLD
# Keyword arguments shared by the training and validation HubDataset instances.
ds_2_kwargs = {
    'mode': 'train',
    'valid_transform': valid_transform,
    'shifting': False,
    'rebuild_slices': False
}

train_ds_2 = HubDataset(root_dir, slices_path, transform, window, overlap, threshold, **ds_2_kwargs)

Reading cached slices, files and masks


  0%|          | 0/15 [00:00<?, ?it/s]

  s = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


  0%|          | 0/583 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
train_ds_2

In [None]:
import matplotlib.pyplot as plt

def plot_img_mask(score, raw_image, index):
    """Draw a mask and its image side by side in one figure.

    Args:
        score: array shown in the left panel (the mask), as accepted by ``imshow``.
        raw_image: channel-first image array; axis 0 is moved to the last
            position before it is shown in the right panel.
        index: value printed in both panel titles.
    """
    fig, ax = plt.subplots(ncols=2, figsize=(15,15))
    ax[0].imshow(score)
    ax[0].set_title(f'Mask {index}')
    ax[0].set_axis_off()
    # imshow expects channels last
    ax[1].imshow(np.moveaxis(raw_image, 0, -1))
    ax[1].set_title(f'Image {index}')
    ax[1].set_axis_off()

In [None]:
# Preview the first ten training samples; the sample index goes into the titles.
for i in range(10):
    image, mask = train_ds_2[i]
    plot_img_mask(mask.squeeze().numpy(), image.numpy(), i)

In [None]:
image.shape, mask.shape, mask.dtype

### DataLoaders

In [None]:
# Hold out a random 5 % of the dataset indices for validation; the rest is training.
# NOTE(review): no random seed is set in the visible cells, so the split changes between runs.
n_samples = len(train_ds_2)
idx_all = np.arange(n_samples)
valid_idx = np.random.choice(n_samples, int(n_samples * 0.05), replace=False)
# idx_all is 0..n-1, so deleting by position removes exactly the validation indices
train_idx = np.delete(idx_all, valid_idx)

In [None]:
# Training subset views train_ds_2 through the training indices.
train_ds = D.Subset(train_ds_2, train_idx)
# A second HubDataset instance backs the validation subset, so that changing
# its mode below does not touch train_ds_2.
valid_ds_2 = HubDataset(root_dir, slices_path, transform, window, overlap, threshold, **ds_2_kwargs)
valid_ds = D.Subset(valid_ds_2, valid_idx)
# valid_ds.dataset is valid_ds_2; it was constructed with mode 'train' and is switched here.
valid_ds.dataset.mode = 'valid'

In [None]:
# Preview the first ten validation samples; the sample index goes into the titles.
for i in range(10):
    image, mask = valid_ds[i]
    print(image.shape, mask.shape, type(image))
    plot_img_mask(mask.squeeze().numpy(), image.numpy(), i)

In [None]:
# Set the flag on both underlying HubDataset objects (train_ds.dataset is
# train_ds_2, valid_ds.dataset is valid_ds_2).
# NOTE(review): presumably this makes the dataset return one mask channel per
# class, which the argmax over dim 1 in the metrics relies on - confirm in HubDataset.
train_ds.dataset.is_convert_to_multiclass = True
valid_ds.dataset.is_convert_to_multiclass = True

# Training batches are shuffled; validation batches keep their order.
train_dl = D.DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.n_jobs)
valid_dl = D.DataLoader(valid_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.n_jobs)

# train_dl = D.DataLoader(train_ds_2, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.n_jobs)
# valid_dl = D.DataLoader(train_ds_2, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.n_jobs)

In [None]:
# Redundant: the same flag is already set on train_ds.dataset in the DataLoaders cell above.
train_ds.dataset.is_convert_to_multiclass = True

### Losses and Metrics

In [None]:
def calc_intersection_cardinality(y_pred, y, dims=(-2, -1)):
    """Count overlapping and total foreground pixels of prediction and target.

    Both tensors are reduced to class-index maps with ``argmax`` over dim 1.
    The product / sum of those index maps is only a true intersection /
    cardinality when there are exactly two classes (indices 0 and 1).

    Args:
        y_pred: per-class scores with the class axis at dim 1.
        y: per-class target with the class axis at dim 1.
        dims: axes of the index maps to sum over.

    Returns:
        Tuple ``(intersection, cardinality)`` of integer tensors.
    """
    pred_cls = torch.argmax(y_pred, 1)
    true_cls = torch.argmax(y, 1)
    # argmax already yields int64, which is also what an integer sum returns
    intersection = (pred_cls * true_cls).sum(dims)
    cardinality = (pred_cls + true_cls).sum(dims)
    return intersection, cardinality

def dice_metric(y_pred, y, epsilon = 1e-7, dims=(-2, -1)):
    """Mean Dice coefficient ``2*|A∩B| / (|A|+|B|)`` over the non-reduced axes.

    Args:
        y_pred: per-class scores with the class axis at dim 1.
        y: per-class target with the class axis at dim 1.
        epsilon: added to numerator and denominator, so an empty prediction
            with an empty target scores 1 instead of dividing by zero.
        dims: axes summed over when counting pixels.

    Returns:
        Scalar tensor with the mean Dice score.
    """
    # dims is forwarded so a caller-supplied value takes effect
    intersection, cardinality = calc_intersection_cardinality(y_pred, y, dims)
    dc = (2 * intersection + epsilon) / (cardinality + epsilon)
    return dc.mean()

def iou_metric(y_pred, y, epsilon = 1e-7, dims=(-2, -1)):
    """Mean intersection-over-union ``|A∩B| / |A∪B|`` over the non-reduced axes.

    Args:
        y_pred: per-class scores with the class axis at dim 1.
        y: per-class target with the class axis at dim 1.
        epsilon: added to numerator and denominator, so an empty prediction
            with an empty target scores 1 instead of dividing by zero.
        dims: axes summed over when counting pixels.

    Returns:
        Scalar tensor with the mean IoU score.
    """
    intersection, cardinality = calc_intersection_cardinality(y_pred, y, dims)
    # union = |A| + |B| - |A∩B|
    iou = (intersection + epsilon) / (cardinality - intersection + epsilon)
    return iou.mean()

# F-score metric object from segmentation_models_pytorch, built with its default arguments.
# NOTE(review): the validation loop calls it on the model's raw outputs (the model is
# built with activation=None) - confirm its default threshold/activation suit logits.
dice_metric_2 = smp.utils.metrics.Fscore()

### Training

In [None]:
# Order matters: the validation loop reads index 0 and 1 as the two Dice scores and
# index 2 as IoU, and uses the last entry for the IoU shown in the progress bar.
all_metrics = [dice_metric, dice_metric_2, iou_metric]

In [None]:
# from fastai.vision.all import *

In [None]:
# cross_entropy = CrossEntropyLossFlat(axis=1)

In [None]:
# from deepflash2.all import *
# cfg.metrics = [Iou(), Dice_f1()]

In [None]:
# dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=cfg.batch_size)
# if torch.cuda.is_available(): dls.cuda(), model.cuda()
# cbs = [SaveModelCallback(monitor='iou')]
# learn = Learner(dls, model, metrics=cfg.metrics, wd=cfg.weight_decay, loss_func=cross_entropy, opt_func=ranger, cbs=cbs)
# if cfg.mixed_precision_training: learn.to_fp16()

In [None]:
# Fit
# learn.fit_one_cycle(cfg.epochs, lr_max=cfg.max_learning_rate)
# learn.recorder.plot_metrics()

In [None]:
# Manual mixed-precision training loop with per-epoch validation.
# The model with the best mean validation Dice so far is checkpointed.
best_metric = 0
scheduler = torch.optim.lr_scheduler.OneCycleLR(cfg.optimizer, max_lr=cfg.max_learning_rate,
                                                steps_per_epoch=len(train_dl), epochs=cfg.epochs)

# One GradScaler for the whole run: re-creating it every epoch would throw
# away the loss scale it has adapted so far.
scaler = torch.cuda.amp.GradScaler()  # mixed precision support

# torch.save does not create missing directories
Path('models').mkdir(parents=True, exist_ok=True)

for epoch in tqdm(range(cfg.epochs)):  # loop over the dataset multiple times

    # ---- training ----
    cfg.model.train()
    running_loss = 0.0
    iou_sum = 0.0
    tbar = tqdm(train_dl, position=0, leave=True)

    for i, data in enumerate(tbar):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.squeeze().float().to(DEVICE)
        # squeeze() also drops the batch axis when a batch holds a single
        # sample; such batches no longer line up with the inputs and are skipped
        if inputs.size(0) == labels.size(0):

            with torch.cuda.amp.autocast():
                outputs = cfg.model(inputs)
                loss = jaccard_loss(outputs, labels)
                iou = iou_metric(outputs, labels)

            scaler.scale(loss).backward()

            # Gradients must be unscaled before clipping, otherwise the
            # threshold is compared against loss-scaled gradient norms
            scaler.unscale_(cfg.optimizer)
            torch.nn.utils.clip_grad_norm_(cfg.model.parameters(), 1.0)

            scaler.step(cfg.optimizer)
            scale = scaler.get_scale()
            scaler.update()

            # The scale only shrinks when inf/nan gradients made scaler.step()
            # skip the optimizer step; only then is the scheduler held back
            optimizer_step_skipped = scale > scaler.get_scale()
            if not optimizer_step_skipped:
                scheduler.step()

            # zero the parameter gradients
            cfg.optimizer.zero_grad()

            # accumulate plain floats so no autograd history is kept alive
            running_loss += loss.item()
            iou_sum += iou.item()
        tbar.set_description(f'Train loss - {running_loss / (i + 1):.5f} iou - {iou_sum / (i + 1):.5f}')

    print(f'Train Epoch {epoch}: Training loss {running_loss / len(train_dl):.5F}')

    # ---- validation ----
    cfg.model.eval()
    running_loss = 0.0
    iou_sum = 0.0
    metric_list = [[] for _ in all_metrics]
    tbar = tqdm(valid_dl, position=0, leave=True)
    with torch.no_grad():
        for i, data in enumerate(tbar):

            inputs, labels = data
            inputs, labels = inputs.to(DEVICE), labels.float().to(DEVICE)

            outputs = cfg.model(inputs)

            loss = jaccard_loss(outputs, labels)
            running_loss += loss.item()

            for ml, m in zip(metric_list, all_metrics):
                ml.append(m(outputs, labels).item())

            # iou_metric is the last entry of all_metrics
            iou_sum += metric_list[-1][-1]

            tbar.set_description(f'Valid loss - {running_loss / (i + 1):.5f} iou - {iou_sum / (i + 1):.5f}')

    dice_metric_mean = np.array(metric_list[0]).mean()
    dice_metric_mean_2 = np.array(metric_list[1]).mean()
    iou_metric_mean = np.array(metric_list[2]).mean()

    if dice_metric_mean > best_metric:
        best_metric = dice_metric_mean
        print('Saving Model')
        # f-string so the epoch number is part of the file name
        torch.save(cfg.model.state_dict(), f'models/hubmap_best_model_{epoch}_unet_pdf.pth')

    print(f'Valid Epoch {epoch}: Validation loss {running_loss / len(valid_dl):.5F}; dice_metric: {dice_metric_mean:.5F} {dice_metric_mean_2:.5F}; iou: {iou_metric_mean:.5F}')

print('Finished Training')

In [None]:
# metrics = [
#     smp.utils.metrics.IoU(),
#     dice_metric_2
# ]

# # create epoch runners 
# # it is a simple loop of iterating over dataloader`s samples
# train_epoch = smp.utils.train.TrainEpoch(
#     model, 
#     loss=jaccard_loss, 
#     metrics=metrics, 
#     optimizer=cfg.optimizer,
#     device=DEVICE,
#     verbose=True,
# )

# valid_epoch = smp.utils.train.ValidEpoch(
#     model, 
#     loss=jaccard_loss, 
#     metrics=metrics, 
#     device=DEVICE,
#     verbose=True,
# )

In [None]:
# max_score = 0

# for i in range(0, 4):
    
#     print('\nEpoch: {}'.format(i))
#     train_logs = train_epoch.run(train_dl)
#     valid_logs = valid_epoch.run(valid_dl)
    
#     # do something (save model, change lr, etc.)
#     if max_score < valid_logs['iou_score']:
#         max_score = valid_logs['iou_score']
#         torch.save(model.state_dict(), './best_model.pth')
#         print('Model saved!')
        
#     if i == 25:
#         optimizer.param_groups[0]['lr'] = 1e-5
#         print('Decrease decoder learning rate to 1e-5!')

In [None]:
# all_metrics = [dice_metric, dice_metric_2, iou_metric]

In [None]:
# def get_lr(optimizer):
#     for param_group in optimizer.param_groups:
#         return param_group['lr']
    
# def smooth_mask_2(mask, alpha = LABEL_SMOOTH):
#     return (1 - alpha) * mask + alpha / 2

# def train_epoch(model, dataloader, optim, criterion, scheduler, device="cpu", grad_accu_steps=GRAD_ACCU_STEPS):
    
#     train_loss = []
#     labels = []
#     outs = []
#     lrs = []
    
#     tbar = tqdm(dataloader, position=0, leave=True)
#     scaler = torch.cuda.amp.GradScaler() # mixed precision support
#     scale = None
#     for step, (image, target) in enumerate(tbar):
        
#         image, target = image.to(DEVICE), target.squeeze().float().to(DEVICE)
#         target = smooth_mask_2(target)
        
#         with torch.cuda.amp.autocast():
#             output = model(image)
#             loss = dice_ce_loss(output, target)
#             loss = loss  / grad_accu_steps
        
#         scaler.scale(loss).backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
#         if (step + 1) % grad_accu_steps == 0:
#             scaler.step(optim)
#             scale = scaler.get_scale()
#             scaler.update()
#             optim.zero_grad()
        
#         skip_lr_sched = (scale != scaler.get_scale())
#         if not skip_lr_sched:
#             scheduler.step()
        
#         loss_val = loss.item() * grad_accu_steps
#         iou_val = iou_metric(output, target)
#         train_loss.append(loss_val)
#         lrs.append(get_lr(optim))
        
#         tbar.set_description(f'loss - {loss_val:.5f} iou: {iou_val:.5f}')
        
#     print(f'Train loss: {np.array(train_loss).mean()}')
#     return train_loss, lrs

In [None]:
# def val_epoch(model, dataloader, criterion, epoch, device="cpu"):
#     model.eval()

#     valid_loss = []
#     num_corrects = 0
#     num_total = 0
#     labels = []
#     outs = []
#     metric_list = [[] for _ in all_metrics]

#     for item in dataloader:
#         image, target = item
#         image, target = image.to(DEVICE), target.float().to(DEVICE)

#         with torch.no_grad():
#             output = model(image)
#             loss = dice_ce_loss(output, target.squeeze())
#             for ml, m in zip(metric_list, all_metrics):
#                 m_res = m(output, target)
#                 ml.append(m_res.item())
#         valid_loss.append(loss.item())

#     avg_loss = np.array(valid_loss).mean()
#     print(f'Epoch {epoch} - valid loss: {avg_loss}')
#     dice_metric_mean = np.array(metric_list[0]).mean()
#     dice_metric_mean_2 = np.array(metric_list[1]).mean()
#     iou_metric_mean = np.array(metric_list[2]).mean()
#     return valid_loss, dice_metric_mean, avg_loss, iou_metric_mean, dice_metric_mean_2

In [None]:
# def train(epochs, train_dl, valid_dl, model, optimizer, scheduler, loss_fn, experiment_name, patience = 6, best_model = 'best_model.pth'):
    
#     best_model_path = Path("models")
#     best_model_path.mkdir(parents=True, exist_ok=True)
#     report_path = Path("reports")
#     report_path.mkdir(parents=True, exist_ok=True)
#     best_loss = 100.0
#     best_metric = 0
#     train_losses = []
#     valid_losses = []
#     accumulated_lrs = []
#     accumulated_dice_metrics = []
#     early_stop_counter = 0
#     messages = []

#     for epoch in tqdm(range(epochs), position=0, leave=True):
#         train_loss, lrs = train_epoch(model, train_dl, optimizer, loss_fn, scheduler, DEVICE)
#         valid_loss, dice_metric_mean, avg_loss, iou_metric_mean, dice_metric_mean_2 = val_epoch(model, valid_dl, loss_fn, epoch, DEVICE)
#         train_losses += train_loss
#         valid_losses.append(np.array(valid_loss).mean())
#         accumulated_lrs += lrs
#         accumulated_dice_metrics.append(dice_metric_mean)
#         if best_metric < dice_metric_mean:
#             best_metric = dice_metric_mean
#             print('Saving model')
#             if torch.cuda.device_count() > 1:
#                 torch.save(model.module.state_dict(), best_model_path/best_model)
#             else:
#                 torch.save(model.state_dict(), best_model_path/best_model)
#             early_stop_counter = 0
#         else:
#             early_stop_counter += 1
#         if best_loss > avg_loss:
#             best_loss = avg_loss
#         print(f'Epoch {epoch} - val best loss {best_loss} dice metric ({dice_metric_mean}, {dice_metric_mean_2}) iou metric ({iou_metric_mean}).')
#         messages.append({
#             'epoch': epoch,
#             'avg_loss': avg_loss,
#             'best_loss': best_loss,
#             'dice_metric_mean': dice_metric_mean,
#             'dice_coeff_mean': dice_metric_mean_2,
#             'iou_metric_mean': iou_metric_mean
#         })
#         with open(report_path/f'{experiment_name}', 'w') as outfile:
#             json.dump(messages, outfile)
#         if early_stop_counter >= patience:
#             print('Stopping early')
#             break
    
#     return train_losses, valid_losses, accumulated_lrs, accumulated_dice_metrics

In [None]:
# def learn(experiment_name, lr=1e-3, epochs=10, patience=7):
#     scheduler = torch.optim.lr_scheduler.OneCycleLR(cfg.optimizer, max_lr=lr,
#                                                     steps_per_epoch=len(train_dl), epochs=epochs)
#     train_losses, valid_losses, accumulated_lrs, accumulated_dice_metrics = train(epochs, 
#                                                                                   train_dl, 
#                                                                                   valid_dl,
#                                                                                   cfg.model,
#                                                                                   cfg.optimizer, 
#                                                                                   scheduler,
#                                                                                   cfg.loss_func,
#                                                                                   experiment_name,
#                                                                                   patience = patience)
#     return train_losses, valid_losses, accumulated_lrs, accumulated_dice_metrics

In [None]:
# train_losses, valid_losses, accumulated_lrs, accumulated_dice_metrics = learn(f'hub_map_pdf_sample_pytorch_{cfg.arch}_{ENCODER_NAME}_b{cfg.batch_size}', 
#                                                                               cfg.max_learning_rate, 
#                                                                               cfg.epochs, 
#                                                                               cfg.patience)

In [None]:
train_dl