## Neural networks for segmentation

In [None]:
# ! wget https://www.dropbox.com/s/jy34yowcf85ydba/data.zip?dl=0 -O data.zip
# ! unzip -q data.zip

Your next task is to train a neural network to segment cell edges.

Here is an example of input data with corresponding ground truth:

In [None]:
import scipy as sp
import scipy.misc
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
# Human HT29 colon-cancer cells
# NOTE: `scipy.misc.imread` was removed in SciPy 1.2. It was a thin wrapper
# around Pillow, so Pillow is used directly here.
from PIL import Image

plt.figure(figsize=(10,8))
plt.subplot(1,2,1)
with Image.open('BBBC018_v1_images-fixed/train/00735-actin.DIB.bmp') as img:
    im = np.asarray(img)
plt.imshow(im)
plt.subplot(1,2,2)
with Image.open('BBBC018_v1_outlines/train/00735-cells.png') as img:
    mask = np.asarray(img)
plt.imshow(mask, 'gray')

This time you aren't provided with any code snippets, just the input data and the target metric: intersection-over-union (IoU) (see the implementation below).

You should train a neural network to predict the mask of edge pixels (pixels in the ground-truth images with a value greater than 0).

Use everything you've learnt by now: 
* any architectures for semantic segmentation (encoder-decoder like or based on dilated convolutions)
* data augmentation (you will need that since train set consists of just 41 images)
* fine-tuning

There is only one thing you're not allowed to do: train your network on the test set.

Your final solution will consist of an ipython notebook with code (for final network training + any experiments with data) and an archive with png images with network predictions for test images (one-channel images, 0 - for non-edge pixels, any non-zero value for edge pixels).

To forestall questions about a baseline: let's say that a good network should be able to segment images with IoU >= 0.29. This is not a strict criterion for a full-points solution, but try to obtain better numbers.

Practical notes:
* There is a hard data class imbalance in dataset, so the network output will be biased toward "zero" class. You can either tune the minimal probability threshold for "edge" class, or add class weights to increase the cost of edge pixels in optimized loss.
* The dataset is small, so actively use data augmentation: rotations, flips, random contrast and brightness.
* Better spend time on experiments with the neural network than on postprocessing tricks (e.g. test-time augmentation).
* Keep in mind that the network architecture defines the receptive field of each output pixel. If the network input is smaller than the receptive field of an output pixel, then you can probably drop some layers without loss of quality. It is OK to modify off-the-shelf architectures.

Good luck!

In [None]:
def calc_iou(prediction, ground_truth):
    """Compute the dataset-level intersection-over-union of binary masks.

    Pixels with a value greater than 0 count as positive. Intersection and
    union are accumulated over all images before dividing. The result is
    therefore one IoU for the whole set, not a mean of per-image IoUs.

    Args:
        prediction: sequence of per-image predicted masks (array-likes).
        ground_truth: sequence of ground-truth masks, indexed in step with
            `prediction`.

    Returns:
        float: intersection / union. Returns NaN when neither input contains
        a positive pixel (including empty input), because IoU is undefined
        in that case.
    """
    intersection, union = 0, 0
    for i, pred in enumerate(prediction):
        pred_pos = np.asarray(pred) > 0
        gt_pos = np.asarray(ground_truth[i]) > 0
        # Count in exact integers: float32 cannot represent every integer
        # above 2**24, so float32 accumulation can drift on large totals.
        intersection += int(np.count_nonzero(np.logical_and(pred_pos, gt_pos)))
        union += int(np.count_nonzero(np.logical_or(pred_pos, gt_pos)))
    if union == 0:
        # No positive pixel anywhere: IoU is undefined.
        return float('nan')
    return intersection / union

I used a U-Net with some layers pretrained (taken from the pretrained VGG13_bn) and two loss functions: Jaccard and `BCEWithLogitsLoss` (with a `pos_weight` of 10-25). With either loss, the best model reached a validation IoU > 0.25.

In [None]:
import torch
from torchvision import transforms as T, models as M
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os, glob, time, copy, random
from PIL import Image

In [None]:
class CancerCellDataset(Dataset):
    """Paired (image, edge-mask) dataset read from two directories.

    Images are the ``*.bmp`` files in ``image_path`` and masks are the
    ``*.png`` files in ``mask_path``. They are paired by position after
    sorting the filenames. This pairs them correctly when image and mask
    names share the same sample prefix
    (e.g. '00735-actin.DIB.bmp' / '00735-cells.png').

    Args:
        image_path: directory containing the input ``.bmp`` images.
        mask_path: directory containing the ground-truth ``.png`` masks.
        image_transform: optional callable applied to the PIL image.
        mask_transform: optional callable applied to the PIL mask. Its result
            must be a tensor, because ``__getitem__`` binarizes it with
            ``(mask > 0).float()``.

    Raises:
        ValueError: if the directories contain different numbers of files.
    """

    def __init__(self, image_path, mask_path, image_transform=None, mask_transform=None):
        super(CancerCellDataset, self).__init__()

        # Sort both listings so that image i and mask i belong to the same sample.
        images = sorted(glob.glob(os.path.join(image_path, '*.bmp')))
        masks = sorted(glob.glob(os.path.join(mask_path, '*.png')))
        if len(images) != len(masks):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                'Found {} images in {!r} but {} masks in {!r}'.format(
                    len(images), image_path, len(masks), mask_path))

        self.data = list(zip(images, masks))

        self.image_transform = image_transform
        self.mask_transform = mask_transform

    @staticmethod
    def _apply_with_seed(transform, sample, seed):
        """Apply `transform` to `sample` with Python's and torch's CPU RNGs seeded.

        torchvision random transforms draw their parameters from the
        ``random`` module (older releases) or from torch's RNG (v0.8 and
        later), so both generators are seeded. The previous RNG states are
        restored afterwards, so the reseeding does not leak into the rest of
        the program.
        """
        py_state = random.getstate()
        torch_state = torch.get_rng_state()
        try:
            random.seed(seed)
            torch.default_generator.manual_seed(seed)
            return transform(sample)
        finally:
            random.setstate(py_state)
            torch.set_rng_state(torch_state)

    def __getitem__(self, idx):
        image_path, mask_path = self.data[idx]

        # PIL's Image.open is lazy and keeps the file open until the data is
        # loaded. Copying inside a `with` closes the handle deterministically.
        with Image.open(image_path) as img:
            image = img.copy()
        with Image.open(mask_path) as img:
            mask = img.copy()

        # Replay one seed for both transforms, so that random geometric
        # augmentations match between image and mask. This holds only if both
        # transforms draw the same random parameters in the same order.
        # The seed comes from torch rather than NumPy: DataLoader reseeds
        # torch in every worker, while older PyTorch versions leave NumPy's
        # state duplicated across forked workers.
        seed = int(torch.randint(0, 2147483647, (1,)).item())
        if self.image_transform is not None:
            image = self._apply_with_seed(self.image_transform, image, seed)
        if self.mask_transform is not None:
            mask = self._apply_with_seed(self.mask_transform, mask, seed)

        # Binarize: any non-zero mask value marks an edge pixel.
        mask = (mask > 0).float()
        return image, mask

    def __len__(self):
        return len(self.data)

In [None]:
# Prepare the data
# Some code parts may be styled after, or taken directly from, https://pytorch.org/tutorials/index.html

image_dir = './BBBC018_v1_images-fixed/'
mask_dir = './BBBC018_v1_outlines/'

stages = ['train', 'val']

# Random geometric augmentation shared by the image and mask pipelines:
# (degrees, translate, scale, shear). CancerCellDataset replays one random
# seed for both transforms. That keeps the mask aligned with the image only
# if both pipelines draw their random parameters from identical ranges in
# the same order, so the two RandomAffine calls must use the same arguments.
affine_args = (180, (.2, .2), (.8, 1.2), 20)

image_transforms = {
    'train': T.Compose([T.RandomHorizontalFlip(),
                        T.RandomAffine(*affine_args, Image.BILINEAR),
                        # Photometric jitter comes after the geometric steps,
                        # so its random draws do not disturb the ones shared
                        # with the mask pipeline.
                        T.ColorJitter(brightness=0.2, contrast=0.2),
                        T.ToTensor()]),
    'val'  : T.ToTensor()}
mask_transforms = {
    'train': T.Compose([T.RandomHorizontalFlip(),
                        # NOTE(review): bilinear resampling followed by the
                        # dataset's `> 0` threshold widens edges in training
                        # targets relative to the untransformed val targets;
                        # confirm this is intended.
                        T.RandomAffine(*affine_args, Image.BILINEAR),
                        T.ToTensor()]),
    'val'  : T.ToTensor()}


datasets = {stage: CancerCellDataset(os.path.join(image_dir, stage),
                                     os.path.join(mask_dir, stage),
                                     image_transform=image_transforms[stage],
                                     mask_transform=mask_transforms[stage])
            for stage in stages}

batch_sizes = {'train': 3,
               'val': 7}

# More workers than CPU cores only adds process start-up overhead; the
# dataset has just a few dozen images.
num_workers = min(32, os.cpu_count() or 1)

dataloaders = {stage: torch.utils.data.DataLoader(datasets[stage],
                                                  batch_size=batch_sizes[stage],
                                                  # Only the training order needs randomizing.
                                                  shuffle=(stage == 'train'),
                                                  num_workers=num_workers,
                                                  drop_last=False)
               for stage in stages}

dataset_sizes = {stage: len(datasets[stage]) for stage in stages}

# Fall back to the CPU when no GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# U-net from the paper https://arxiv.org/pdf/1505.04597.pdf
# Implemented using some ideas from
# https://github.com/milesial/Pytorch-UNet and https://github.com/ternaus/TernausNet

class DoubleConv2d(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    The 3x3 convolutions use padding 1, so the spatial size is unchanged;
    only the channel count goes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute names conv1/bn1/conv2/bn2 are overwritten by name when
        # UNet loads pretrained encoder layers; keep them stable.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

class Down(nn.Module):
    """U-Net encoder step: 2x2 max-pool, then a DoubleConv2d that doubles channels.

    Args:
        in_channels: channels of the input; the output has twice as many.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.double_conv = DoubleConv2d(in_channels, 2 * in_channels)

    def forward(self, x):
        pooled = self.max_pool(x)
        return self.double_conv(pooled)

class Up(nn.Module):
    """U-Net decoder step: upsample, halve channels, merge the skip tensor, refine.

    Args:
        in_channels: channels of the incoming (coarser) feature map. The skip
            tensor passed to ``forward`` should carry ``in_channels // 2``
            channels, so that the concatenation has ``in_channels`` channels
            for the DoubleConv2d that follows.
    """
    def __init__(self, in_channels):
        super(Up, self).__init__()
        # One pixel of zero padding on the right/bottom, so the 2x2
        # "up-convolution" below keeps the upsampled spatial size.
        self.pad = nn.ZeroPad2d((0,1,0,1))
        self.conv_up = nn.Conv2d(in_channels, in_channels//2, kernel_size=2)
        self.relu = nn.ReLU()
        self.double_conv = DoubleConv2d(in_channels, in_channels//2)

    def forward(self, x, copy):
        """Upsample `x` to the size of the skip tensor `copy` and fuse them.

        Args:
            x: coarser feature map with ``in_channels`` channels.
            copy: encoder skip-connection feature map with
                ``in_channels // 2`` channels.

        Returns:
            Tensor with ``in_channels // 2`` channels and the spatial size of
            `copy`.
        """
        # Up-convolution. Interpolating straight to the skip tensor's size,
        # instead of by a fixed factor of 2, keeps the concatenation valid when
        # an encoder max-pool floored an odd spatial size. For even sizes the
        # result is unchanged: with align_corners=True the sampling grid
        # depends only on the input and output sizes.
        x = F.interpolate(x, size=tuple(copy.shape[-2:]), mode='bilinear', align_corners=True)
        x = self.pad(x)
        x = self.relu(self.conv_up(x))

        # Concatenate skip features first, then the upsampled features.
        x = torch.cat([copy, x], dim=1)

        # Two Conv layers
        return self.double_conv(x)
    
class UNet(nn.Module):
    """U-Net for binary segmentation, optionally with a VGG13-BN encoder.

    Channel flow: 3 -> 64 (inp) -> 128 -> 256 -> 512 -> 1024 (down1..down4),
    then back to 64 through up1..up4 with skip connections. A 1x1 conv
    produces a single channel of raw logits. Four 2x max-poolings are
    applied, so H and W divisible by 16 are the safe choice.

    Args:
        pretrained_encoder: if truthy, copy the conv/BatchNorm layers of the
            first four stages of torchvision's ImageNet-pretrained vgg13_bn
            into ``inp``, ``down1``, ``down2`` and ``down3``.
    """
    def __init__(self, pretrained_encoder=False):
        super(UNet, self).__init__()
        self.inp = DoubleConv2d(3, 64)
        self.down1 = Down(64)
        self.down2 = Down(128)
        self.down3 = Down(256)
        self.down4 = Down(512)
        self.up1 = Up(1024)
        self.up2 = Up(512)
        self.up3 = Up(256)
        self.up4 = Up(128)
        self.out = nn.Conv2d(64, 1, kernel_size=1)

        if pretrained_encoder:
            # NOTE(review): ImageNet-pretrained weights were trained on
            # mean/std-normalized inputs; confirm the input pipeline matches.
            self._load_vgg13_bn_encoder()

    def _load_vgg13_bn_encoder(self):
        """Replace the first four encoder stages' conv/BN layers with vgg13_bn ones.

        vgg13_bn.features is laid out per stage as
        [conv, bn, relu, conv, bn, relu, maxpool] (7 modules). Stage k
        therefore starts at index 7*k. Its second conv/BN are at offsets +3
        and +4; offset +2 is a ReLU.
        """
        encoder = M.vgg13_bn(pretrained=True).features
        # The conv layers of a Down block live in its `double_conv` submodule.
        # Assigning them onto the Down block itself would only register
        # unused modules.
        targets = [self.inp,
                   self.down1.double_conv,
                   self.down2.double_conv,
                   self.down3.double_conv]
        for stage, block in enumerate(targets):
            start = 7 * stage
            block.conv1 = encoder[start]
            block.bn1 = encoder[start + 1]
            block.conv2 = encoder[start + 3]
            block.bn2 = encoder[start + 4]

    def forward(self, x):
        x1 = self.inp(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.out(x)
        return x

In [None]:
# Training loop
loss_history = {'train': [], 'val': []}
best_acc = 0.0

def train_model(model, criterion, optimizer, scheduler, num_epochs=10,  acc=0.0):
    """Train `model`, measuring validation IoU after every epoch.

    Uses the module-level `dataloaders`, `dataset_sizes`, `device` and
    `calc_iou`. Per-epoch mean losses are appended to the global
    `loss_history`. The state_dict of the epoch with the highest validation
    IoU is stored in the global `best_model`, and that IoU in the global
    `best_acc`.

    Args:
        model: network producing one logit channel per pixel.
        criterion: loss. A function named ``jaccard_loss`` is called as
            ``criterion(masks.long(), logits)``. Anything else (e.g. an
            ``nn.BCEWithLogitsLoss`` instance) is called as
            ``criterion(logits.view(-1), masks.view(-1))``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler stepped once per epoch, or None.
        num_epochs: number of epochs to run.
        acc: initial best validation IoU; weights are only recorded in
            `best_model` when an epoch beats it.

    Returns:
        The model as of the last epoch. The best weights are in the global
        `best_model`.
    """
    since = time.time()

    global best_model
    best_model = copy.deepcopy(model.state_dict())

    global best_acc
    best_acc = acc

    # Plain functions have __name__; loss modules such as nn.BCEWithLogitsLoss
    # are instances without one, so look it up safely.
    use_jaccard = getattr(criterion, '__name__', None) == 'jaccard_loss'

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            since_epoch = time.time()

            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            predictions = []
            ground_truths = []

            for inputs, masks in dataloaders[phase]:
                inputs = inputs.to(device)
                masks = masks.to(device)

                # Clear gradients left over from the previous step.
                optimizer.zero_grad()

                # Track history only in the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)

                    if use_jaccard:
                        loss = criterion(masks.long(), outputs)
                    else:
                        loss = criterion(outputs.view(-1), masks.view(-1))

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                # A logit >= 0 means a sigmoid probability >= 0.5. Masks are
                # kept for the whole epoch so IoU covers the full split.
                predictions.append((outputs.detach() >= 0).cpu().numpy())
                ground_truths.append(masks.cpu().numpy())

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_iou = calc_iou(np.concatenate(predictions),
                                 np.concatenate(ground_truths))
            print('{} Loss: {:.4f} IoU: {:.4f} Time: {:.1f} sec'.format(
                phase, epoch_loss, epoch_iou, time.time() - since_epoch))

            loss_history[phase].append(epoch_loss)

            # Keep a copy of the weights from the best validation epoch.
            if phase == 'val' and epoch_iou > best_acc:
                best_acc = epoch_iou
                best_model = copy.deepcopy(model.state_dict())

        # Advance the learning-rate schedule once per epoch, after the
        # optimizer updates (the order expected since PyTorch 1.1).
        if scheduler is not None:
            scheduler.step()

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val IoU: {:.4f}'.format(best_acc))

    return model

In [None]:
def maskshow(i=0):
    """Visualize the network's prediction for one validation sample.

    Runs the global `net` on sample `i` of `datasets['val']` and draws three
    panels:
    - predicted edge probabilities (sigmoid of the logits);
    - those probabilities rounded to 0/1;
    - the ground-truth mask.

    Args:
        i: index of the validation sample.
    """
    inp, mask = datasets['val'][i]

    # Predict in eval mode without autograd. In train mode the BatchNorm
    # layers would normalize with this single sample's statistics and fold
    # validation data into their running statistics as a side effect of
    # plotting. The caller's train/eval mode is restored afterwards.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            out = net(inp.to(device)[None, ...])
    finally:
        net.train(was_training)

    mask_pred = out[0, 0].sigmoid().cpu().numpy()
    mask = mask[0].cpu().numpy()

    plt.figure(figsize=(15,5))
    plt.subplot(1,3,1)
    plt.imshow(mask_pred, 'gray', vmax=1, vmin=0)
    plt.subplot(1,3,2)
    plt.imshow(mask_pred.round(), 'gray', vmax=1, vmin=0)
    plt.subplot(1,3,3)
    plt.imshow(mask, 'gray', vmax=1, vmin=0)

In [None]:
# Jaccard Loss - adapted from https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
# From the paper - https://www.cs.toronto.edu/~urtasun/publications/mattyus_etal_iccv17.pdf
def jaccard_loss(true, logits, eps=1e-7):
    """Compute the soft Jaccard (IoU) loss, ``1 - mean soft IoU``.

    The soft IoU is computed per class over the batch and all spatial axes,
    then averaged over classes. PyTorch optimizers minimize, so the loss is
    one minus that value: close to 0 for a perfect prediction and close to 1
    for no overlap.

    Args:
        true: integer class labels of shape [B, H, W] or [B, 1, H, W].
        logits: raw model outputs of shape [B, C, H, W]. With C == 1 the
            channel is the logit of the positive class (label 1), and the
            background probability is taken as 1 - sigmoid(logit).
        eps: added to the denominator for numerical stability.

    Returns:
        Scalar tensor with the Jaccard loss.
    """
    num_classes = logits.shape[1]
    # Accept labels with or without a singleton channel axis.
    labels = true.squeeze(1) if true.ndimension() == logits.ndimension() else true
    # Build the one-hot lookup table on the labels' device, so that indexing
    # it with (possibly CUDA) labels does not mix devices.
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1, device=labels.device)[labels]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        # Reorder to [positive, negative] to match `probas` below.
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes, device=labels.device)[labels]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    # Reduce over the batch and every spatial axis, keeping only the class
    # axis. Derive the axes from `logits` (always [B, C, ...]), not from
    # `true`, whose rank depends on the optional channel axis.
    dims = (0,) + tuple(range(2, logits.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    union = cardinality - intersection
    jacc_index = (intersection / (union + eps)).mean()
    return (1 - jacc_index)

In [None]:
# Build the U-Net with pretrained VGG13-BN encoder layers and move it to the device.
net = UNet(pretrained_encoder=True).to(device)

# NOTE(review): the original note read "pos to neg in the train dataset is
# .961 : .039", but pos_weight below is .961 / .039, i.e. negative/positive.
# So .961 is presumably the negative (non-edge) pixel fraction.
# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(.961 / .039, dtype=torch.float32, device=device))
criterion = jaccard_loss

optimizer = optim.SGD(net.parameters(),
                      lr=.01,
                      momentum=0.9,
                      weight_decay=0.0001)

# Multiply the learning rate by 0.1 every 7 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
# Network output before training, shown for validation sample 0
maskshow()

In [None]:
# Train for 20 epochs, then load the weights from the epoch with the best
# validation IoU. train_model returns the last-epoch model and keeps the best
# state_dict in the global `best_model`.
net = train_model(net, criterion, optimizer, scheduler, num_epochs=20)
net.load_state_dict(best_model)

In [None]:
# Network output after training, shown for validation sample 3
maskshow(3)

In [None]:
# Plot the per-epoch mean loss of both phases (epochs numbered from 1).
plt.figure(figsize=(15,8))
for phase in ('train', 'val'):
    history = loss_history[phase]
    epochs = np.arange(1, len(history) + 1)
    plt.plot(epochs, history, label=phase)
plt.title('Loss')
plt.xlabel('epoch')
plt.legend()
plt.grid()
# plt.ylim(.37, .45)
plt.show()

In [None]:
### Save the network
# torch.save(net, 'unet.pth.tar')
# torch.save(net.state_dict(), 'unet_dict.pth.tar')
# loss = np.concatenate((np.array(loss_history['train'])[:, None], np.array(loss_history['val'])[:, None]), axis=1)
# np.savetxt('unet_loss.txt', loss)

In [None]:
# net = UNet(pretrained_encoder=False)

# # load the model
# net.load_state_dict(torch.load('unet_dict.pth.tar'))
# net = net.to(device)

# # load the loss history
# loss_history_np = np.loadtxt('unet_loss.txt')
# loss_history = {'train': list(loss_history_np[:,0]), 'val': list(loss_history_np[:,1])}