
## The libraries

In [0]:
%matplotlib inline
%load_ext tensorboard
import os
import imageio
import math
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms, utils
from scipy.ndimage import binary_erosion

## Data loading and preprocessing

For this exercise we will again use the Kaggle 2018 Data Science Bowl data, but this time we will try to segment it with a state-of-the-art network.
Let's start by loading the data as before.

In [0]:
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1lEPEXGIxYeheiaHAp2G8rIK6Y3BGzVNN' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1lEPEXGIxYeheiaHAp2G8rIK6Y3BGzVNN" -O kaggle_data.zip && rm -rf /tmp/cookies.txt
!unzip -qq kaggle_data.zip && rm kaggle_data.zip

Now make sure that the data was extracted successfully: if everything went fine, you should have the folders `nuclei_train_data` and `nuclei_val_data` in your working directory. Check whether this is the case:

In [0]:
!ls -ltrh

__TASK__: Use `ls` to explore the contents of both folders. Running `ls your_folder_name` should show you what is stored in that folder.

How are the images stored? What format do they have? What about the ground truth (the annotation masks)? Which format are they stored in?

Hint: you can use the following function to display the images:

In [0]:
def show_one_image(image_path):
  image = imageio.imread(image_path)
  plt.imshow(image)
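If you want to check the format programmatically rather than only visually, a small helper along these lines can report the PIL mode, the array shape and the data type of a file. `describe_image` is a hypothetical helper, and the demo below writes a synthetic PNG to a temporary folder instead of touching the real data; on the real data you would pass it the path of one of the extracted files.

```python
import os
import tempfile

import numpy as np
from PIL import Image


def describe_image(image_path):
    """Return the PIL mode, array shape and dtype of an image file (hypothetical helper)."""
    with Image.open(image_path) as img:
        mode = img.mode              # e.g. 'L' (grayscale), 'RGB' or 'RGBA'
        array = np.asarray(img)      # (height, width) or (height, width, channels)
    return mode, array.shape, array.dtype


# Synthetic demo: write a tiny 4x5 RGBA PNG and inspect it
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, 'example.png')
Image.fromarray(np.zeros((4, 5, 4), dtype=np.uint8)).save(path)
print(describe_image(path))  # ('RGBA', (4, 5, 4), dtype('uint8'))
```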

The usual first step in any machine learning pipeline is writing a dataset: a class that fetches the training samples. In the previous exercises we did not have to worry about this, since we used the classic datasets available in the torchvision library. However, once you switch to your own data, you have to figure out how to fetch it yourself. Luckily, PyTorch already provides most of the functionality; what you need to write is a class that actually supplies the `DataLoader` with training samples: a `Dataset`.

Please take a moment to read about it [here](https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset) and [here](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class).

The main idea: any Dataset class should implement two methods: `__len__`, which returns the length of the dataset (the number of samples), and `__getitem__`, which, given an index, returns the input (image) and the target (ground truth).
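To make these two methods concrete, here is a minimal sketch of a dataset that does not read any files. `ToySegmentationDataset` is hypothetical: it generates random single-channel "images" and uses a simple threshold as the "mask", purely to show how `__len__`, `__getitem__` and a `DataLoader` fit together.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToySegmentationDataset(Dataset):
    """Hypothetical dataset of random images with thresholded masks."""

    def __init__(self, n_samples=8, size=32):
        self.n_samples = n_samples
        self.size = size

    def __len__(self):
        # the number of samples in the dataset
        return self.n_samples

    def __getitem__(self, idx):
        # seed with the index so the same idx always returns the same sample
        generator = torch.Generator().manual_seed(idx)
        image = torch.rand(1, self.size, self.size, generator=generator)
        mask = (image > 0.5).float()
        return image, mask


dataset = ToySegmentationDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True)
images, masks = next(iter(loader))   # the DataLoader calls __getitem__ and stacks the results
print(len(dataset), images.shape, masks.shape)
```

The `DataLoader` only relies on these two methods: it asks for indices in `range(len(dataset))`, fetches each sample via `dataset[idx]` and stacks the samples into a batch.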

You will not have to write one yourself for this exercise yet, but please read through the provided class carefully:


In [0]:
# any PyTorch dataset class should inherit from torch.utils.data.Dataset
class NucleiDataset(Dataset):
    """ A PyTorch dataset to load cell images and nuclei masks """
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir  # the directory with all the training samples
        self.samples = os.listdir(root_dir) # list the samples
        self.transform = transform    # transformations to apply to both inputs and targets
        #  transformations to apply just to inputs
        self.inp_transforms = transforms.Compose([transforms.Grayscale(), # some of the images are RGB
                                                  transforms.ToTensor(),
                                                  transforms.Normalize([0.5], [0.5])
                                                  ])
        # transformations to apply just to targets
        self.mask_transforms = transforms.ToTensor()

    # get the total number of samples
    def __len__(self):
        return len(self.samples)

    # fetch the training sample given its index
    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.samples[idx],
                                'images', self.samples[idx]+'.png')
        # we'll be using the Pillow library for reading files,
        # since many torchvision transforms operate on PIL images
        image = Image.open(img_path)
        image = self.inp_transforms(image)
        masks_dir = os.path.join(self.root_dir, self.samples[idx], 'masks')
        # masks directory has multiple images - one mask per nucleus
        masks_list = os.listdir(masks_dir)
        # create an all-zero mask with the same height and width as the image
        mask = torch.zeros(1, *image.shape[1:])
        # iterate through the images to sum them up to one mask
        for mask_name in masks_list:
            one_nuclei_mask = Image.open(os.path.join(masks_dir, mask_name))
            # erode the image by one pixel
            # TASK: guess why we are doing it
            one_nuclei_mask = binary_erosion(one_nuclei_mask)
            one_nuclei_mask = self.mask_transforms(one_nuclei_mask)
            # add this nucleus to the mask
            mask += one_nuclei_mask
        if self.transform is not None:
            image, mask = self.transform([image, mask])
        return image, mask
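To see what the mask-building loop does in isolation, here is a toy version with two synthetic, non-overlapping square nuclei; the canvas size and the shapes are made up for illustration. With its default cross-shaped structuring element, `binary_erosion` removes every foreground pixel that has a background 4-neighbour, so each 5x5 square shrinks to 3x3 before being added to the combined mask.

```python
import numpy as np
import torch
from scipy.ndimage import binary_erosion

# Two synthetic, non-overlapping 5x5 "nuclei" on a 7x13 canvas
nucleus_masks = []
for col_start in (1, 7):
    m = np.zeros((7, 13), dtype=bool)
    m[1:6, col_start:col_start + 5] = True
    nucleus_masks.append(m)

# Accumulate the eroded masks into one (1, H, W) tensor, as in NucleiDataset.__getitem__
mask = torch.zeros(1, 7, 13)
for m in nucleus_masks:
    eroded = binary_erosion(m)                        # boolean array, one pixel thinner
    mask += torch.from_numpy(eroded).float()[None]    # [None] adds the channel axis

print(int(mask.sum()))  # 18: each 5x5 square (25 pixels) shrinks to 3x3 (9 pixels)
```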

Now let's load the dataset and visualize it with a simple function:

In [0]:
TRAIN_DATA_PATH = 'nuclei_train_data'
train_data = NucleiDataset(TRAIN_DATA_PATH)

In [0]:
def show_dataset(dataset):
    idx = np.random.randint(0, len(dataset))    # take a random sample
    img, mask = dataset[idx]                    # get the image and the nuclei masks
    f, axarr = plt.subplots(1, 2)               # make two plots on one figure
    axarr[0].imshow(img[0])                     # show the image
    axarr[1].imshow(mask[0])                    # show the masks
    _ = [ax.axis('off') for ax in axarr]        # remove the axes
    plt.show()

In [0]:
show_dataset(train_data)



As you have probably noticed if you ran the cell enough times, some of the images are really large! What happens if we load them into memory and run the model on them? We might run out of memory. That is why, when training networks on images or volumes, one normally has to be careful about their size, and in practice you would want to limit it. Another reason to restrict the size: if we want to train in batches (which makes training faster and more stable), all the images in a batch need to have the same size. That is why we prefer to either resize or crop them.
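The batching constraint is easy to check directly: for tensor samples, the default collate function of the `DataLoader` builds a batch by stacking the samples along a new first dimension, which fails when their shapes differ. A quick check with two dummy tensors (the sizes are arbitrary):

```python
import torch

small = torch.zeros(1, 256, 256)   # dummy (channels, height, width) image
large = torch.zeros(1, 520, 696)   # a larger dummy image

stack_failed = False
try:
    torch.stack([small, large])    # stacking is what the default collate does
except RuntimeError as err:
    stack_failed = True
    print("cannot batch images of different sizes:", err)

# after cropping the larger image to the same size, stacking works
batch = torch.stack([small, large[:, :256, :256]])
print(batch.shape)  # torch.Size([2, 1, 256, 256])
```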

Here is a function (well, actually a class) that applies a random crop. Notice that we apply it to the image and the mask simultaneously, to make sure they still correspond despite the randomness.

In case anybody is wondering why we bother writing a whole class for this instead of simply cropping the images directly in the dataset: we want to keep the code modular. We want to write one dataset object and then try all possible transforms with this one dataset. Similarly, we want to write one `RandomCrop` transform object and then reuse it for any other image datasets we might have in the future.


In [0]:
class RandomCrop(object):
    """Randomly crop the input image and the target mask at the same position"""
    def __init__(self, crop_size):
        # check if the crop size is of a valid type
        assert isinstance(crop_size, (int, tuple, list))
        if isinstance(crop_size, int):
            # if the crop size is an integer, we use the same size for both dimensions
            self.output_size = (crop_size, crop_size)
        else:
            assert len(crop_size) == 2
            self.output_size = tuple(crop_size)

    # this method makes our class callable
    def __call__(self, sample):
        # we need to crop both input and mask at the same time
        assert len(sample) == 2
        image, mask = sample
        # the tensors are laid out as (channels, height, width)
        h, w = image.shape[1:]
        new_h, new_w = self.output_size
        # choose a random top-left corner (the upper bound of randint is exclusive);
        # along a dimension smaller than the crop size, the image is not cropped
        top = np.random.randint(0, max(h - new_h, 0) + 1)
        left = np.random.randint(0, max(w - new_w, 0) + 1)
        # crop and return
        image = image[:, top: top + new_h, left: left + new_w]
        mask = mask[:, top: top + new_h, left: left + new_w]
        return image, mask

PS: PyTorch (via torchvision) already provides a large collection of data transforms, so if you need one, check [here](https://pytorch.org/docs/stable/torchvision/transforms.html). The biggest problem with them is that they are clearly split into transforms applied to PIL images (remember, we initially load the images as `PIL.Image`?) and transforms applied to `torch.Tensor`s (remember, we converted the images into tensors by calling `transforms.ToTensor()`?). This can be very annoying if you need to convert your images to tensors before applying other transforms, or if you don't want to use the PIL library at all.

In [0]:
train_data = NucleiDataset(TRAIN_DATA_PATH, RandomCrop(256))

In [0]:
show_dataset(train_data)

And the same for the validation data:

In [0]:
VAL_DATA_PATH = 'nuclei_val_data'
val_data = NucleiDataset(VAL_DATA_PATH, RandomCrop(256))

In [0]:
show_dataset(val_data)

## The model: U-net

Now we need to define the architecture of the model. This time we will use a [U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/), which has proven very successful at segmenting biological and medical images.

The architecture diagram of the model (see the link above) precisely describes all the building blocks you need to create it. All of them can be found in the list of PyTorch layers (modules) [here](https://pytorch.org/docs/stable/nn.html#convolution-layers).

The U-net has an encoder-decoder structure:

In the encoder pass, the input image is successively downsampled via max-pooling. In the decoder pass it is upsampled again via transposed convolutions.
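A minimal sketch of one downsampling and one upsampling step, assuming a 2x2 max-pooling and a 2x2 transposed convolution with stride 2 (the sizes used in the original U-Net); the feature-map shape is hypothetical:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 16, 64, 64)    # hypothetical feature map: (batch, channels, H, W)

down = nn.MaxPool2d(kernel_size=2)                         # encoder step: halves H and W
up = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2)   # decoder step: doubles H and W

pooled = down(x)
restored = up(pooled)
print(pooled.shape, restored.shape)  # torch.Size([1, 16, 32, 32]) torch.Size([1, 16, 64, 64])
```

Unlike max-pooling, the transposed convolution has learnable weights, so the network learns how to upsample.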

In addition, it has skip connections that pass the feature maps from each encoder level to the corresponding decoder level, where they are concatenated with the upsampled features.
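Here is a sketch of how a skip connection fits into the forward pass, reduced to a single level. `TinyUNet` and `conv_block` are hypothetical names, and the sketch simplifies the original design: it uses padded 3x3 convolutions (`padding=1`) so that encoder and decoder feature maps have the same size and can be concatenated without cropping, whereas the original U-Net uses unpadded convolutions and crops the skip features.

```python
import torch
import torch.nn as nn


def conv_block(in_channels, out_channels):
    """Two 3x3 convolutions with ReLU; padding=1 keeps H and W unchanged."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """One-level U-Net sketch: encoder, bottleneck, decoder with one skip connection."""

    def __init__(self, in_channels=1, out_channels=1, features=16):
        super().__init__()
        self.encoder = conv_block(in_channels, features)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(features, 2 * features)
        self.upsample = nn.ConvTranspose2d(2 * features, features, kernel_size=2, stride=2)
        # the decoder sees the upsampled features concatenated with the skip features
        self.decoder = conv_block(2 * features, features)
        self.head = nn.Conv2d(features, out_channels, kernel_size=1)

    def forward(self, x):
        skip = self.encoder(x)                      # (N, F, H, W)
        bottom = self.bottleneck(self.pool(skip))   # (N, 2F, H/2, W/2)
        up = self.upsample(bottom)                  # (N, F, H, W)
        merged = torch.cat([up, skip], dim=1)       # skip connection: (N, 2F, H, W)
        return self.head(self.decoder(merged))      # (N, out_channels, H, W)


model = TinyUNet()
out = model(torch.randn(2, 1, 64, 64))
print(out.shape)  # torch.Size([2, 1, 64, 64])
```

A full U-Net repeats the encoder and decoder levels several times, keeping one skip tensor per level; the input height and width must then be divisible by 2 to the power of the number of pooling steps.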