In [1]:
%matplotlib inline
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset
import glob
import os
import zarr
import random
import datetime
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import h5py
from imgaug import augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from imgaug.augmentables.heatmaps import HeatmapsOnImage
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchsummary import summary
from utils.colormap import *
from unet_fov import *
from utils.mean_shift import MeanShift
from skimage import measure
# run on the first visible GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# let cuDNN benchmark the available convolution algorithms and keep the fastest one
torch.backends.cudnn.benchmark = True

Tile and Stitch
==================================

When applying one of our deep learning models on microscopy data, it might be that the images are too big to be stored on the GPU and predicted at once. In such situations, the image is subdivided into **tiles**, the **tiles** are fed into the DL model and the respective predictions are stitched together to form an image of the same resolution as the original microscopy image that was subject to inference. 

For a more detailed reference, please see https://arxiv.org/pdf/2101.05846.pdf


Data
-------
For this task we use again a subset of the data used in the kaggle data science bowl 2018 challenge
(https://www.kaggle.com/c/data-science-bowl-2018/)

![image.png](utils/attachment/image.png)
All images show nuclei recorded using different microscopes and lighting conditions.
There are 30 images in the training set, 8 in the validation set and 16 in the test set.

#TODO: Create the KaggleDSB_dataset
-------
We will create the KaggleDSB_dataset, a subclass which inherits from torch.utils.data.Dataset.

When you only have a limited amount of training data, data augmentation is essential to get good results.

TODO: Implement the part of **define_augmentation** that augments the training data on the fly during training. Think about what kind of augmentation to use (e.g. flips, rotation, elastic). Use the imgaug library (https://imgaug.readthedocs.io/en/latest/); it provides a very extensive list of available augmentations.

In [2]:
# unpack the compressed dataset archive into the current working directory
from shutil import unpack_archive
archive_path = os.path.join('datasets', 'data_kaggle.tar.gz')
unpack_archive(archive_path, './')

In [None]:
class KaggleDSB_dataset(Dataset):
    """(subset of the) kaggle data science bowl 2018 dataset.
    The data is loaded from disk on the fly and in parallel using the torch dataset class.
    This enables the use of datasets that would not fit into main memory and dynamic augmentation.
    Args:
        root_dir (string): Directory with all the images.
        data_type (string): train/val/test, select subset of images
        prediction_type (string): selects which ground truth volume is loaded, one of
            two_class/affinities/sdt/three_class/metric_learning. Defaults to "two_class";
            this notebook passes "metric_learning".
        net_input_size (list): the input tile size of your UNet; if given, every sample is
            cropped with self.crop
        padding_size (int): the number of pixels to pad on each side of the image before augmentation and cropping;
            if given, every sample is padded with self.pad
        cache: if True, all samples are loaded once in the constructor and kept in memory, default: False
    """
    # zarr dataset that holds the ground truth for every supported prediction type
    _LABEL_KEYS = {
        "two_class": 'volumes/gt_fgbg',
        "affinities": 'volumes/gt_affs',
        "sdt": 'volumes/gt_tanh',
        "three_class": 'volumes/gt_threeclass',
        "metric_learning": 'volumes/gt_labels',
    }

    def __init__(self,
                 root_dir,
                 data_type,
                 prediction_type="two_class",
                 net_input_size=None,
                 padding_size=None,
                 cache = False
                ):
        self.data_type = data_type
        # glob returns the files in arbitrary order; sorting keeps the sample indices reproducible
        self.files = sorted(glob.glob(os.path.join(root_dir, data_type, "*.zarr")))
        self.prediction_type = prediction_type
        self.net_input_size = net_input_size
        self.padding_size = padding_size
        self.define_augmentation()
        self.cache = cache
        if cache:
            self.cached_data = [self.load_sample(filename) for filename in self.files]

    def __len__(self):
        """Number of samples (zarr files) in the selected subset."""
        return len(self.files)
    
    def define_augmentation(self):
        """Set up self.transform (training augmentation), self.crop and self.pad."""
        # an augmenter *instance* is needed: augment_sample calls self.transform(image=..., ...)
        self.transform = iaa.Identity()
        self.crop = None
        self.pad = None

        ###########################################################################
        # TODO (optional): Define your augmentation pipeline and uncomment the    #
        # following code                                                          #
        ###########################################################################
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        
        # define self.transform by looking into the imgaug package reference
        
        # self.transform = iaa.Sequential([
        #     ...,
        #     ...,
        #    ...
        # ], random_order=True)
        
        
        # if self.net_input_size is not None:
        #     self.crop = ... which augmentation?

        # if self.padding_size is not None:
        #     self.pad =  ... which augmentation
                    
            
        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        ###########################################################################
        #                             END OF YOUR CODE                            #
        ###########################################################################
        
    def get_filename(self, idx):
        """Return the path of the zarr file that backs sample idx."""
        return self.files[idx]
        
    def __getitem__(self, idx):
        """Load, normalize, (pad,) (augment,) (crop) one sample and return (raw, label) tensors."""
        if self.cache:
            raw, label = self.cached_data[idx]
        else:
            raw, label = self.load_sample(self.get_filename(idx))
        raw = self.normalize(raw)
        if self.padding_size is not None:
            if self.pad is None:
                raise NotImplementedError(
                    "padding_size is set but self.pad is not defined; "
                    "define it in define_augmentation")
            raw = self.pad(images = raw) # CHW -> CHW
            label = self.pad(images = label) # CHW -> CHW
        # augment for training
        if self.data_type == "train":
            raw = np.transpose(raw, [1,2,0]) # CHW -> HWC
            label = np.transpose(label, [1,2,0]) # CHW -> HWC
            raw, label = self.augment_sample(raw, label) # HWC -> HWC
            raw = np.transpose(raw, [2,0,1]) # HWC -> CHW
            label = np.transpose(label, [2,0,1]) # HWC -> CHW
        if self.net_input_size is not None:
            if self.crop is None:
                raise NotImplementedError(
                    "net_input_size is set but self.crop is not defined; "
                    "define it in define_augmentation")
            # raw and label are stacked so that both receive exactly the same crop
            num_raw_channels = raw.shape[0]
            stacked = np.concatenate([raw, label], axis = 0).copy() # C1+C2 HW
            stacked = np.transpose(stacked, [1,2,0]) # CHW -> HWC
            stacked = self.crop.augment_image(stacked) # HWC -> HWC
            stacked = np.transpose(stacked, [2,0,1]) # HWC -> CHW
            raw, label = stacked[:num_raw_channels], stacked[num_raw_channels:] # split
        raw, label = torch.Tensor(raw), torch.Tensor(label)
        return raw, label
    
    def augment_sample(self, raw, label):
        """Apply self.transform jointly to raw and label (both HWC) and return the augmented arrays."""
        # stores float label (sdt) differently than integer label (rest)
        if self.prediction_type in ["sdt"]:
            label = HeatmapsOnImage(label, shape=raw.shape, min_value=-1.0, max_value=1.0)
            raw, label = self.transform(image=raw, heatmaps=label)
        else:
            label = label.astype(np.int32)
            label = SegmentationMapsOnImage(label, shape=raw.shape)
            raw, label = self.transform(image=raw, segmentation_maps=label)
            
        label = label.get_arr() 
        # some pytorch version have problems with negative indices introduced by e.g. flips
        # just copying fixes this
        label = label.copy()
        raw = raw.copy()
        return raw, label
    
    def normalize(self, raw):
        """Return a z-normalized float32 copy of raw (zero mean, unit standard deviation).

        The input is left untouched, so cached samples are not modified and integer
        images can be normalized. A constant image is only centered (its std is 0).
        """
        raw = np.asarray(raw).astype(np.float32) # astype returns a copy
        raw -= raw.mean()
        std = raw.std()
        if std > 0:
            raw /= std
        return raw
    
    def load_sample(self, filename):
        """Read the raw image and the ground truth selected by prediction_type from a zarr file.

        Returns:
            (raw, label): raw as stored on disk, label converted to float32.
        Raises:
            ValueError: if prediction_type is not one of the supported types.
        """
        if self.prediction_type not in self._LABEL_KEYS:
            raise ValueError(
                "unknown prediction_type {!r}, expected one of {}".format(
                    self.prediction_type, sorted(self._LABEL_KEYS)))
        data = zarr.open(filename)
        raw = np.array(data['volumes/raw'])
        label = np.array(data[self._LABEL_KEYS[self.prediction_type]])
        label = label.astype(np.float32)
        return raw, label

Loss
-------
Here we will just use the metric learning approach with the well-known discriminative loss that you are already familiarized with in the instance segmentation part of the exercises.

### #TODO: Metric Learning ###
In metric learning your model learns to predict an embedding vector for each pixel. These embedding vectors are learned such that vectors from pixels belonging to the same instance are similar to each other and dissimilar to the embedding vectors of other instances and the background. It can also be thought of as learning a false coloring where each instance is colored with a unique but arbitrary color.  
![metric_learning.png](utils/attachment/metric_learning.png)

TODO: Please fill in the missing code according to the knowledge you learnt from the instance_segmentation.ipynb exercise.


In [None]:
from utils.disc_loss import DiscriminativeLoss

#hint: see here for torch tensor types: https://pytorch.org/docs/stable/tensors.html

# selects the ground truth volume the dataset loads ('volumes/gt_labels' for "metric_learning")
prediction_type = "metric_learning"
# out_channels is passed to the final 1x1 convolution in the network template further below.
# activation, loss_fn and dtype are not read by any cell visible in this notebook
# (presumably they mirror the training setup of the instance segmentation exercise - TODO confirm).
###########################################################################
# TODO                                                                    #
###########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
out_channels = 
activation = 
loss_fn = 
dtype = 
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###########################################################################
#                             END OF YOUR CODE                            #
###########################################################################

Create our input datasets, ground truth labels are chosen depending on the type:

In [None]:
# dataset configuration
root = 'data_kaggle_test'
padding_size = 46
batch_size = 1
net_input_size = [332,332]

# training and validation samples are padded and then cropped to the network input size
data_train = KaggleDSB_dataset(
    root,
    "train",
    prediction_type=prediction_type,
    padding_size=padding_size,
    net_input_size=net_input_size,
    cache=False,
)
data_val = KaggleDSB_dataset(
    root,
    "val",
    prediction_type=prediction_type,
    padding_size=padding_size,
    net_input_size=net_input_size,
    cache=False,
)
# test samples are only padded (no net_input_size, hence no cropping)
data_test = KaggleDSB_dataset(
    root,
    "test",
    prediction_type=prediction_type,
    padding_size=padding_size,
)

# data loaders: only the training data is shuffled
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader = DataLoader(data_val, batch_size=1, pin_memory=True)
test_loader = DataLoader(data_test, batch_size=1)

Let's have a look at some of the raw data and labels:


In [None]:
# repeatedly execute this cell to get different images
# (a random training sample is drawn on every execution; iterating over the
# dataset and stopping after the first item would always return sample 0)
sample_idx = random.randrange(len(data_train))
image, label = data_train[sample_idx]

label = np.squeeze(label, 0)
if prediction_type == "affinities":
    # collapse the two affinity channels into one image for display
    label = label[0] + label[1]

fig=plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 2, 1)
ax.set_title("raw")
ax.imshow(np.squeeze(image), cmap='gray')
ax = fig.add_subplot(1, 2, 2)
ax.set_title("label")
ax.imshow(np.squeeze(label), cmap='gist_earth')
plt.show()

#TODO: Define our U-Net
==============
As before, we define our neural network architecture and can choose the depth and number of feature maps at the first convolution. Our CNN used 'same' padding before, which assures that spatial dimensions are not reduced by the convolution operation. Unfortunately, this padding scheme renders the network location-aware, since it can learn to calculate the distance of most pixels in the image to the padded zeros at the image boundaries. This location awareness then leads to discontinuities at the stitching boundaries. Therefore we'll use 'valid' padding instead of 'same' padding.

Valid padding actually means that input filter maps are not padded at all and therefore the spatial dimensions of the filter maps reduce after every convolution. This leads to the network predicting smaller tiles than the ones that got fed into the network as you can see in the network summary below, e.g choosing the net_input_size = $[332,332]$, our output tile size will be $[148,148]$ with our defined UNet structure. Therefore the input tiles need to overlap such that the output tiles align.

TODO: You need to build a U-Net that satisfies the following conditions; otherwise you will run into errors when running the next cell.
- 1. Use valid padding.
- 2. The number of feature maps in the first layer should be 32 as well as the number of out feature maps.
- 3. The feature maps should go through 4 times downsampling and 4 times upsampling with the scale factor=2.
- 4. fmap_inc_factor which refers to the number of feature maps between layers should be set to 2.

In [None]:
from unet_fov import UNet

# fixed seed so that the weight initialisation is reproducible
torch.manual_seed(42)
net_input_size = [332,332]

# hint: see network requirements above and torch.nn.Sequential for building network
# hint: we defined out channels in a previous cell. how many output channels do we get from the Network?
# hint: how can we get to the right number of output channels? 

###########################################################################
# TODO: Define the net and uncomment the following code                   #
# Please define a UNet which use valid padding                             #
###########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
# d_factors = 
# net = torch.nn.Sequential(
#     UNet(in_channels=,
#     num_fmaps=,
#     fmap_inc_factors=,
#     downsample_factors=d_factors,
#     activation='ReLU',
#     padding=,
#     num_fmaps_out=,
#     constant_upsample=False
#     ),
#     torch.nn.Conv2d(in_channels=, out_channels=out_channels, kernel_size=1, padding=0, bias=True))

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###########################################################################
#                             END OF YOUR CODE                            #
###########################################################################

# NOTE: `net` only exists once the template above has been filled in and uncommented;
# until then the lines below raise a NameError.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)
# print the layer-wise output shapes for a single-channel input tile of size net_input_size
summary(net, (1, net_input_size[0], net_input_size[1]))

Training
=======

You already know the training loop from the other notebooks. For this exercise we already trained a network like the one that you just defined for 60,000 training steps. You can load this CNN by executing the next cell.

In [None]:
#load pretrained network (checkpoint unet_60000)

# The checkpoint is used directly as the network, i.e. it is a complete pickled model and
# not just a state dict. Unpickling executes code from the file, so only load checkpoints
# from a trusted source. The model classes (unet_fov) must be importable for this to work.
checkpoint_path = "utils/unet_60000"
try:
    # newer PyTorch releases default to weights_only=True, which rejects pickled modules
    net = torch.load(checkpoint_path, map_location=device, weights_only=False)
except TypeError:
    # older PyTorch releases do not know the weights_only argument
    net = torch.load(checkpoint_path, map_location=device)

Postprocessing
=============

Here we first tile the input image and feed the tiles to the network to generate our predictions. Then the tiles are cropped and stitched to form the final embeddings. In the lecture you learned that U-Nets are shift equivariant only for shifts that are multiples of $f^l$, where $f$ is the downsampling factor and $l$ the number of downsampling layers. Since our network does pooling with window size and stride $2$ and has $4$ pooling layers, $f^l=2^4=16$. When stitching output crops of this U-Net whose spatial dimensions are not multiples of $16$, you might observe discontinuities at the stitching boundaries that lead to false splits.

Try running the below cell and then try the questions in the following cell!

In [None]:
%reload_ext autoreload
%autoreload 2
from utils.label import *
from utils.evaluate import *

crop_size =  [50, 50]
output_size = [148,148]
padding_size = 0
data_test = KaggleDSB_dataset(root, "test", prediction_type=prediction_type)
test_loader = DataLoader(data_test, batch_size=1)

# set flag
net.eval()
# set hyperparameters
prediction_type == "metric_learning"
fg_thresh = 0.7
seed_thresh = None
    
def unpad(pred, padding_size):
    """Remove padding_size pixels from every side of the first two axes of pred.

    Args:
        pred: array whose first two axes are the spatial ones (e.g. HWC).
        padding_size (int): number of pixels to strip per side; 0 or None returns
            pred unchanged (a plain slice `0:-0` would return an empty array).
    Returns:
        The cropped view of pred.
    """
    if not padding_size:
        return pred
    return pred[padding_size:-padding_size,padding_size:-padding_size]

avg = 0.0
num_evaluated = 0
# offset between the top-left corner of an input tile and the top-left corner of the
# (smaller, valid-padding) network output for that tile
pad_top_left = (net_input_size[0]-output_size[0])//2
# only the top-left crop of every output tile is kept, so the crop size is the tiling stride
if crop_size:
    output_size = crop_size
for idx, (image, gt_labels) in enumerate(test_loader):
    image = image.to(device)
    # zero padding: pad_top_left at the top/left aligns the network output with the image,
    # 300 px at the bottom/right make sure the last tile of every row/column is complete
    # (net_input_size - pad_top_left = 240 px are needed for the sizes used here)
    image_padded = F.pad(image, (pad_top_left,300,pad_top_left,300))
    H,W = image.shape[-2:]
    # NOTE(review): inference runs with autograd enabled. Wrapping the loop in
    # torch.no_grad() would reduce memory, but the questions below this cell discuss
    # the memory behaviour, so it is left unchanged.
    patched_pred = None
    for h in range(0, H, output_size[0]):
        for w in range(0, W, output_size[1]):
            image_tmp = image_padded[:,:,h:h+net_input_size[0],w:w+net_input_size[1]]
            pred = net(image_tmp)
            if crop_size:
                pred = pred[:,:,:crop_size[0],:crop_size[1]]
            if patched_pred is None:
                # channel count and margin are taken from the first tile; the margin of one
                # tile lets the last row/column of tiles overhang the image
                patched_pred = torch.zeros(1, pred.shape[1], H+pred.shape[2], W+pred.shape[3])
            patched_pred[:,:,h:h+pred.shape[2],w:w+pred.shape[3]] = pred
    # drop the overhang: the stitched prediction has the size of the input image
    pred = patched_pred[:,:,:H,:W]
    image = np.squeeze(image.cpu())
    gt_labels = np.squeeze(gt_labels)
    pred = np.squeeze(pred.cpu().detach().numpy(),0)
    #if prediction_type in ["three_class", "affinities","two_class","sdt"]:
    if padding_size and padding_size>0:
        pred = unpad(np.transpose(pred,(1,2,0)), padding_size)
        pred = np.transpose(pred,(2,0,1))
    labelling, surface = label(pred, prediction_type, fg_thresh=fg_thresh, seed_thresh=seed_thresh)
    ap, precision, recall, tp, fp, fn = evaluate(labelling, data_test.get_filename(idx))
    avg += ap
    num_evaluated += 1
    print(np.min(surface), np.max(surface))
    labelling = labelling.astype(np.uint8)
    print("average precision: {}, precision: {}, recall: {}".format(ap, precision, recall))
    print("true positives: {}, false positives: {}, false negatives: {}".format(tp, fp, fn))
    if prediction_type == "metric_learning":
        # rescale every embedding channel to [0, 1] so that the embedding can be shown as colours
        surface = surface - np.min(surface, axis=(1,2))[:,np.newaxis,np.newaxis]
        channel_max = np.max(surface, axis=(1,2))[:,np.newaxis,np.newaxis]
        channel_max[channel_max == 0] = 1 # constant channel: avoid division by zero
        surface = surface / channel_max
        surface = np.transpose(surface, (1,2,0))
    
    fig=plt.figure(figsize=(16, 8))
    ax = fig.add_subplot(1, 4, 1)
    ax.set_title("raw")
    ax.imshow(np.squeeze(image))
    ax = fig.add_subplot(1, 4, 2)
    ax.set_title("gt labels")
    ax.imshow(np.squeeze(1.0-gt_labels))
    
    ax = fig.add_subplot(1, 4, 3)
    ax.set_title("prediction")
    ax.imshow(np.squeeze(1.0-surface))
    ax = fig.add_subplot(1, 4, 4)
    ax.set_title("pred segmentation")
    ax.imshow(np.squeeze(labelling), cmap=rand_cmap, interpolation="none")
    # mark the stitching boundaries between neighbouring tiles
    bs = output_size[0]
    lw = 1
    of = -0.5
    for col in range(bs, labelling.shape[-1], bs):
        ax.plot([col+of, col+of], [0, labelling.shape[-2]-1], color='red', linestyle='--', dashes=(5, 10), linewidth=lw)
    for row in range(bs, labelling.shape[-2], bs):
        ax.plot([0, labelling.shape[-1]-1], [row+of, row+of], color='red', linestyle='--', dashes=(5, 10),linewidth=lw)

    plt.show()
    # only the first test image is processed; remove the break to evaluate the whole test set
    break
if num_evaluated:
    avg /= num_evaluated
print("average precision on test set: {}".format(avg))

In [None]:
# Try changing the crop_size above and compare the different results.
# What happens when we use a crop size that is a multiple of 16?
# What happens when we use a low crop size (e.g [20,20])? Why?
# Try iterating over full data loader and see different results

In [None]:
# solutions

# the discontinuities at the stitching boundaries disappear if the crop size is a multiple of 16 (they appear otherwise)
# we get out of memory errors on the gpu with a low crop size - iterating over more patches
# remove break statement