<a href="https://colab.research.google.com/github/kevinkevin556/Dlchemist/blob/main/unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **U-Net on ISBI 2012 Dataset (Electron Microscopic stacks)**

In [1]:
! pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
Installing collected packages: einops
Successfully installed einops-0.6.0


In [None]:
import os
from google.colab import drive

# Mount your google drive to save training checkpoints.
drive.mount('/content/gdrive')

# Prepare the checkpoint directory and build PATH.
directory = "/content/gdrive/MyDrive/Colab Checkpoints/U-Net/"
os.makedirs(directory, exist_ok=True)

PATH = os.path.join(directory, "state_dict.ckpt")

Mounted at /content/gdrive


In [2]:
import numpy as np
import torch
import torch.nn as nn
from torch import tensor
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# **Section 1. Set up**

## **Dataset**

In [3]:
!gdown https://downloads.imagej.net/ISBI-2012-challenge.zip
!unzip /content/ISBI-2012-challenge.zip -x / -d .

Downloading...
From: https://downloads.imagej.net/ISBI-2012-challenge.zip
To: /content/ISBI-2012-challenge.zip
100% 31.5M/31.5M [00:03<00:00, 8.40MB/s]
Archive:  /content/ISBI-2012-challenge.zip
 extracting: ./test-volume.tif       
 extracting: ./test-labels.tif       
 extracting: ./train-labels.tif      
 extracting: ./train-volume.tif      
 extracting: ./challenge-error-metrics.bsh  


In [4]:
from PIL import Image
from torch.utils.data import Dataset

class ISBI2012(Dataset):
    def __init__(self, volume_path, labels_path, indices=None, transform=None, target_transform=None):
        self.volume =  Image.open(volume_path)
        self.labels =  Image.open(labels_path)
        assert self.volume.n_frames == self.labels.n_frames
        self.indices = indices

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Report the subset size when indices are given, not the full stack.
        if self.indices is not None:
            return len(self.indices)
        return self.volume.n_frames

    def __getitem__(self, idx):
        # Map the subset position to the frame number in the TIFF stack.
        if self.indices is not None:
            idx = int(self.indices[idx])

        self.volume.seek(idx)
        image = self.volume
        if self.transform is not None:
            image = self.transform(image)

        self.labels.seek(idx)
        label = self.labels
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

In [5]:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

total_size = 30
valid_size = total_size // 5
rng = np.random.default_rng(42)
indices = rng.permutation(total_size)
# The last `valid_size` shuffled indices are held out, so the two sets are disjoint.
train_id, valid_id = indices[:-valid_size], indices[-valid_size:]

train_dataset = ISBI2012('./train-volume.tif', './train-labels.tif', train_id,
                         transform=ToTensor(),
                         target_transform=ToTensor())
valid_dataset = ISBI2012('./train-volume.tif', './train-labels.tif', valid_id,
                         transform=ToTensor(),
                         target_transform=ToTensor())
test_dataset = ISBI2012('./test-volume.tif', './test-labels.tif',
                         transform=ToTensor(),
                         target_transform=ToTensor())


print("\nData Size:")
print("* Training set   => ", len(train_dataset), "images")
print("* Validation set => ", len(valid_dataset), "images")
print("* Testing set    => ", len(test_dataset), "images")

print("\nImage Shape (C, H, W):")
print("* Training image:", train_dataset[0][0].shape)
print("* Validation image: ", valid_dataset[0][0].shape)
print("* Testing image: ", test_dataset[0][0].shape)



Data Size:
* Training set   =>  24 images
* Validation set =>  6 images
* Testing set    =>  30 images

Image Shape (C, H, W):
* Training image: torch.Size([1, 512, 512])
* Validation image:  torch.Size([1, 512, 512])
* Testing image:  torch.Size([1, 512, 512])
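
A hold-out split like the one above is easy to get subtly wrong, so a quick check that the two index sets are disjoint and together cover every slice can be worthwhile. The sketch below uses the standard-library `random` module rather than NumPy so that it runs anywhere; `split_indices` is a hypothetical helper, not part of the notebook, and it assumes `total_size` is at least 5.

```python
import random

def split_indices(total_size, seed=42):
    """Shuffle range(total_size) and hold out one fifth for validation."""
    indices = list(range(total_size))
    random.Random(seed).shuffle(indices)
    valid_size = total_size // 5
    # The last `valid_size` shuffled indices form the validation set.
    return indices[:-valid_size], indices[-valid_size:]

train_ids, valid_ids = split_indices(30)
print(len(train_ids), len(valid_ids))           # 24 6
print(set(train_ids).isdisjoint(valid_ids))     # True
```
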


In [6]:
train_dataset[0][0]

tensor([[[0.2902, 0.3529, 0.2667,  ..., 0.5451, 0.4471, 0.5176],
         [0.2471, 0.2902, 0.1882,  ..., 0.4235, 0.4196, 0.4039],
         [0.1451, 0.1686, 0.1647,  ..., 0.3961, 0.3922, 0.3137],
         ...,
         [0.5490, 0.5294, 0.4824,  ..., 0.8353, 0.7451, 0.7255],
         [0.3686, 0.4118, 0.4863,  ..., 0.7882, 0.7608, 0.7529],
         [0.3804, 0.3529, 0.3882,  ..., 0.7804, 0.7804, 0.7882]]])

## **DataLoaders**

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True)

## **HyperModule**

In [8]:
!gdown https://github.com/kevinkevin556/Dlchemist/raw/main/hypermodule/hypermodule.py
from hypermodule import HyperModule

Downloading...
From: https://github.com/kevinkevin556/Dlchemist/raw/main/hypermodule/hypermodule.py
To: /content/hypermodule.py
  0% 0.00/2.07k [00:00<?, ?B/s]8.60kB [00:00, 15.6MB/s]       


## **Network Architecture**

<img src="https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-07_at_9.08.00_PM_rpNArED.png"  width="70%">

**This figure leaves several puzzles to solve:**
1. Where does the input size 576x576 come from?
2. How should the up-convolution be implemented?

Public implementations handle the up-convolution in different ways:

* 2x2 Transposed Convolution (stride 2)
  * [milesial/Pytorch-UNet](https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py) ![](https://img.shields.io/github/stars/milesial/Pytorch-UNet?style=social)
  * [yassouali/pytorch-segmentation](https://github.com/yassouali/pytorch-segmentation/blob/588d074507377f4ff2ae33a4df3d911ae2840315/models/unet.py) ![](https://img.shields.io/github/stars/yassouali/pytorch-segmentation?style=social)
  * [meetps/pytorch-semseg](https://github.com/meetps/pytorch-semseg/blob/801fb200547caa5b0d91b8dde56b837da029f746/ptsemseg/models/unet.py) ![](https://img.shields.io/github/stars/meetps/pytorch-semseg?style=social)
* Upsample + 3x3 Convolution (padding 1)
  * [LeeJunHyun/Image_Segmentation](https://github.com/LeeJunHyun/Image_Segmentation/blob/db34de21767859e035aee143c59954fa0d94bbcd/network.py) ![](https://img.shields.io/github/stars/LeeJunHyun/Image_Segmentation?style=social)
* Upsample + 1x1 Convolution
  * [jvanvugt/pytorch-unet](https://github.com/jvanvugt/pytorch-unet/blob/master/unet.py) ![](https://img.shields.io/github/stars/jvanvugt/pytorch-unet?style=social)
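
All three approaches listed above should double the spatial size of a feature map, which can be checked with the usual output-size formulas for convolutions and transposed convolutions (dilation 1, no output padding). The helpers `conv_out` and `conv_transpose_out` are illustrative names, not functions from any of the linked repositories, and `n = 28` is just an example size.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Spatial size after a standard convolution (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

def conv_transpose_out(n, kernel, stride=1, padding=0):
    """Spatial size after a transposed convolution (dilation 1, no output padding)."""
    return (n - 1) * stride - 2 * padding + kernel

n = 28  # example side length of the feature map entering the up-convolution

variants = {
    "2x2 transposed conv, stride 2": conv_transpose_out(n, kernel=2, stride=2),
    "upsample x2 + 3x3 conv, padding 1": conv_out(2 * n, kernel=3, padding=1),
    "upsample x2 + 1x1 conv": conv_out(2 * n, kernel=1),
}
for name, size in variants.items():
    print(f"{name:35s} {n} -> {size}")
```

The variants agree on the output size and differ only in how the new pixels are produced: learned directly by the transposed convolution, or interpolated first and then mixed by an ordinary convolution.
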


In [32]:
from torch import cat
from torch.nn import Sequential, Conv2d, ReLU, MaxPool2d, ConvTranspose2d
from torchvision.transforms import CenterCrop
from torch.nn.functional import interpolate


class UpSample(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='nearest'):
        super().__init__()
        self.size, self.scale_factor = size, scale_factor
        self.mode = mode

    def forward(self, x):
        if self.size == '2d+1':
          size = tuple([2*d+1 for d in x.shape[2:]])
          return interpolate(x, size=size, mode=self.mode)
        else:
          return interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode)


class UpConv(nn.Module):
  def __init__(self, in_channels, out_channels, prac=0, relu=False):
      super().__init__()
      if prac == 0:
        self.net = Sequential(
            UpSample(size='2d+1', mode='bilinear'), 
            Conv2d(in_channels, out_channels, 2, 1)
        )
      elif prac == 1:
        self.net = Sequential(ConvTranspose2d(in_channels, out_channels, 2, 2))
      elif prac == 2:
        self.net = Sequential(
            UpSample(scale_factor=2, mode='bilinear'),
            Conv2d(in_channels, out_channels, 3, 1, 1)
          )
      elif prac == 3:
        self.net = Sequential(
          UpSample(scale_factor=2, mode='bilinear'),
          Conv2d(in_channels, out_channels, 1, 1, 0)
        )
      else:
          raise ValueError("No implementation to assigned value of argument 'prac'")
      
      if relu:
        # add_module requires a name for the submodule as its first argument.
        self.net.add_module("relu", ReLU())
  
  def forward(self, x):
      return self.net(x)


class LeftConv(nn.Module):
  def __init__(self, in_channels, out_channels):
      super().__init__()
      self.net = Sequential(
          Conv2d(in_channels, out_channels, kernel_size=3), ReLU(),
          Conv2d(out_channels, out_channels, kernel_size=3), ReLU()
      )
      self.maxpool = MaxPool2d(kernel_size=2, stride=2)
      

  def forward(self, input):
      input = self.net(input)
      out_downward = self.maxpool(input)
      out_rightward = input   # copy
      return out_downward, out_rightward


class RightConv(nn.Module):
  def __init__(self, in_channels, out_channels):
      super().__init__()
      self.net = Sequential(
          Conv2d(in_channels, out_channels, kernel_size=3), ReLU(),
          Conv2d(out_channels, out_channels, kernel_size=3), ReLU()
      )
      self.up_conv = UpConv(in_channels, out_channels)
    
  def forward(self, bottom_in, left_in):
      bottom_in = self.up_conv(bottom_in)
      
      n, c, h, w = bottom_in.shape
      crop = CenterCrop((h, w))
      left_in = crop(left_in)
      
      input = cat((left_in, bottom_in), dim=1)
      out_upward = self.net(input)
      return out_upward


class Unet(nn.Module):
    def __init__(self):
        super().__init__()
        self.left1, self.right1 = LeftConv(1, 64),  RightConv(64*2, 64)
        self.left2, self.right2 = LeftConv(64, 128),  RightConv(128*2, 128)
        self.left3, self.right3 = LeftConv(128, 256), RightConv(256*2, 256)
        self.left4, self.right4 = LeftConv(256, 512), RightConv(512*2, 512)
        self.bottom = Sequential(
            Conv2d(512, 1024, kernel_size=3),   ReLU(),
            Conv2d(1024, 1024, kernel_size=3),  ReLU(),
        )
        self.conv_out = Conv2d(64, 2, kernel_size=1)
        self.net = Sequential(
            self.left1, self.left2, self.left3, self.left4,
            self.bottom, 
            self.right4, self.right3, self.right2, self.right1,
            self.conv_out
        )
    
    def forward(self, x):
        d1, r1 = self.left1(x)
        d2, r2 = self.left2(d1)
        d3, r3 = self.left3(d2)
        d4, r4 = self.left4(d3)
        
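        u5 = self.bottom(d4)
        
        # RightConv.forward takes (bottom_in, left_in): the deeper feature map
        # first, then the skip connection from the contracting path.
        u4 = self.right4(u5, r4)
        u3 = self.right3(u4, r3)
        u2 = self.right2(u3, r2)
        u1 = self.right1(u2, r1)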

        out = self.conv_out(u1)
        return out
    
    def __repr__(self):
        return self.net.__repr__()

In [34]:
X = torch.randn(1, 1,576, 576)
model = Unet()
skip_outputs = []   # contracting-path feature maps awaiting their skip connection
print(f"{'Image shape:':25s} {list(X.shape)}")

for layer in model.net:
    if type(layer) is LeftConv:
      X, R = layer(X)
      skip_outputs.append(R)
    elif type(layer) is RightConv:
      L = skip_outputs.pop()
      X = layer(X, L)
    else:
      X = layer(X)
    print(f"{layer.__class__.__name__+' output shape:':25s} {list(X.shape)}")

del X, model

Image shape:              [1, 1, 576, 576]
LeftConv output shape:    [1, 64, 286, 286]
LeftConv output shape:    [1, 128, 141, 141]
LeftConv output shape:    [1, 256, 68, 68]
LeftConv output shape:    [1, 512, 32, 32]
Sequential output shape:  [1, 1024, 28, 28]
RightConv output shape:   [1, 512, 52, 52]
RightConv output shape:   [1, 256, 100, 100]
RightConv output shape:   [1, 128, 196, 196]
RightConv output shape:   [1, 64, 388, 388]
Conv2d output shape:      [1, 2, 388, 388]
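
The shapes printed above follow from simple arithmetic: each unpadded 3x3 convolution trims 2 pixels, each 2x2 max-pooling halves the size (flooring odd sizes), and each up-convolution doubles it. The sketch below traces that arithmetic for the 576x576 input used here and for the 572x572 tile shown in the architecture figure. `unet_sizes` is a hypothetical helper written for this illustration, and it assumes every up-convolution exactly doubles the side length.

```python
def unet_sizes(n, depth=4):
    """Trace the side length of a square input through an unpadded U-Net."""
    sizes = [n]
    pooled_odd = False
    for _ in range(depth):              # contracting path
        n -= 4                          # two unpadded 3x3 convolutions
        pooled_odd |= (n % 2 == 1)      # 2x2 max-pooling floors odd sizes
        n //= 2
        sizes.append(n)
    n -= 4                              # bottom block: two 3x3 convolutions
    sizes.append(n)
    for _ in range(depth):              # expansive path
        n = 2 * n - 4                   # up-convolution doubles, two 3x3 convs trim 4
        sizes.append(n)
    return sizes, pooled_odd

for side in (572, 576):
    sizes, pooled_odd = unet_sizes(side)
    note = "an odd size was pooled" if pooled_odd else "all pooled sizes even"
    print(side, sizes, note)
```

Both inputs end at 388x388, but only 572 keeps every pooled feature map even-sized. With 576, the third level pools a 137x137 map and silently drops a row and a column.
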


# Reference

* A [YouTube](https://www.youtube.com/watch?v=poY_nGzEEWM&ab_channel=Computerphile) clip from the **Computerphile** channel gives a clear explanation of bicubic interpolation.

* https://towardsdatascience.com/understanding-u-net-61276b10f360