# U-Net: Convolutional Networks for Biomedical Image Segmentation
📎 [Link to Paper](https://arxiv.org/abs/1505.04597)

## ✍️ Authors

Olaf Ronneberger, Philipp Fischer, and Thomas Brox

## Difficulty
⚫⚫⚪⚪⚪

## 📝 Abstract

There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.

## 🏛️ Network Architecture
![U-Net Network Architecture](img/architecture.png)

## 🔥 PyTorch Implementation

### 📦 Imports

In [None]:
import torch
import torch.nn as nn
from copy import deepcopy

from torchvision.transforms import CenterCrop

> The network architecture is illustrated in Figure 1. It consists of a contracting path (left side) and an expansive path (right side).

In [None]:
class UNet(nn.Module):
    
    def __init__(self, contracting_path: nn.Module, expansive_path: nn.Module):
        super().__init__()
        self.contracting_path = contracting_path
        self.expansive_path = expansive_path
    
    def forward(self, x):
        x, feature_maps = self.contracting_path(x)
        return self.expansive_path(x, feature_maps)

### Contractive Path

Note that the contracting path (left side) has two outputs: `x` and `feature_maps`. `x` is the pooled output of the last block of the `contracting_path`, and `feature_maps` holds the 4 pre-pooling feature maps that the `expansive_path` needs for its skip connections (gray arrows), ordered from deepest to shallowest.

> The contracting path follows the typical architecture of a convolutional network. It consists of the repeated application of two 3x3 convolutions (unpadded convolutions), [...] At each downsampling step we double the number of feature channels.
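
To make the effect of unpadded convolutions concrete, here is a minimal sketch of the size arithmetic along the contracting path. The helper names `conv_output_size` and `contracting_path_sizes` are hypothetical (they are not part of the notebook), and the 572x572 input is the tile size shown in Figure 1 of the paper.

```python
def conv_output_size(size: int, kernel_size: int = 3, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1


def contracting_path_sizes(size: int, n_blocks: int = 4):
    """Return the pre-pooling size of every block and the size after the last pooling."""
    sizes = []
    for _ in range(n_blocks):
        # two unpadded 3x3 convolutions: each one removes 2 pixels
        size = conv_output_size(conv_output_size(size))
        sizes.append(size)
        # 2x2 max pooling with stride 2 halves the size (rounding down)
        size = size // 2
    return sizes, size


sizes, pooled = contracting_path_sizes(572)
print(sizes, pooled)  # [568, 280, 136, 64] 32

# the number of feature channels doubles at each downsampling step
channels = [64 * 2 ** i for i in range(4)]
print(channels)  # [64, 128, 256, 512]
```

Every block loses 4 pixels per dimension before pooling, which is why the skip feature maps are larger than the upsampled maps they are later merged with.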

In [None]:
class ContractivePath(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_blocks = nn.ModuleList([
            ConvBlock(in_channels=in_channels, out_channels=64),
            ConvBlock(in_channels=64, out_channels=128),
            ConvBlock(in_channels=128, out_channels=256),
            ConvBlock(in_channels=256, out_channels=512),
        ])
    
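    def forward(self, x: torch.Tensor) -> tuple:
        feature_maps = []
        for block in self.conv_blocks:
            x = block(x)
            # keep the pre-pooling output for the skip connection
            feature_maps.append(x)
            x = self.pooling(x)

        # reverse so the deepest feature map comes first,
        # which is the order the expansive path consumes them in
        return x, feature_maps[::-1]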

> ... each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2 for downsampling.
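
As a small illustration of this step, the following sketch applies a ReLU and a 2x2 max pooling with stride 2 to a hand-written 4x4 input. The tensor values are arbitrary and chosen only so the result is easy to check by hand.

```python
import torch
import torch.nn as nn

# one image, one channel, 4x4 pixels
x = torch.tensor([[[[-1.0,  2.0,  0.5, -3.0],
                    [ 4.0, -2.0,  1.0,  0.0],
                    [ 0.0,  1.0, -1.0, -2.0],
                    [ 3.0, -4.0,  2.0,  5.0]]]])

# ReLU sets all negative values to zero
activated = torch.relu(x)

# 2x2 max pooling with stride 2 keeps the maximum of each non-overlapping 2x2 window
pooled = nn.MaxPool2d(kernel_size=2, stride=2)(activated)

print(pooled.shape)  # torch.Size([1, 1, 2, 2])
print(pooled[0, 0])  # tensor([[4., 1.], [3., 5.]])
```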

In [None]:
class ConvBlock(nn.Module):
    
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 0):
        super().__init__()
        # bias is redundant because each convolution is followed by batch normalization
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        # Note: batch normalization is an addition of this implementation;
        # the paper only uses convolution followed by ReLU.
        self.conv_block = nn.Sequential(
            self.conv1,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            self.conv2,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv_block(x)

### Expansive Path

> Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU.
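
A single expansive step as described in the quote can be sketched as follows. This is an illustration under stated assumptions, not the notebook's code: `center_crop` is a hypothetical helper written with plain slicing, the tensor shapes are example values, and the up-convolution is realized with `nn.ConvTranspose2d` as in the `up_convs` list of the notebook.

```python
import torch
import torch.nn as nn


def center_crop(t: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Crop the central (height, width) region of a (N, C, H, W) tensor."""
    _, _, h, w = t.shape
    top = (h - height) // 2
    left = (w - width) // 2
    return t[:, :, top:top + height, left:left + width]


x = torch.randn(1, 1024, 28, 28)     # output of the deepest block (example shape)
skip = torch.randn(1, 512, 64, 64)   # matching feature map from the contracting path

# the up-convolution doubles the spatial size and halves the number of channels
up_conv = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
up = up_conv(x)
print(up.shape)  # torch.Size([1, 512, 56, 56])

# crop the skip feature map to the size of the upsampled map, then concatenate along channels
cropped = center_crop(skip, up.shape[-2], up.shape[-1])
merged = torch.cat((cropped, up), dim=1)
print(merged.shape)  # torch.Size([1, 1024, 56, 56])
```

The concatenated tensor has 1024 channels again, which the following two 3x3 convolutions reduce to 512.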

In [None]:
class ExpansivePath(nn.Module):
    
    def __init__(self, n_classes):
        super().__init__()
        self.conv_blocks = nn.ModuleList([
            ConvBlock(in_channels=512, out_channels=1024),
            ConvBlock(in_channels=1024, out_channels=512),
            ConvBlock(in_channels=512, out_channels=256),
            ConvBlock(in_channels=256, out_channels=128),
            ConvBlock(in_channels=128, out_channels=64),
        ])

        self.up_convs = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2),
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2),
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=n_classes, kernel_size=1, stride=1)
        ])

            
    def forward(self, x, feature_maps):
        
        x = self.conv_blocks[0](x)
        x = self.up_convs[0](x)

        # the last entry of up_convs is the final 1x1 convolution, not an up-convolution
        for block, up_conv, feature_map in zip(self.conv_blocks[1:], self.up_convs[1:], feature_maps):
            size = feature_map.size()[-2:]  # (height, width) of the skip feature map
            # The paper crops the skip feature map to the size of x, e.g.
            # feature_map = CenterCrop(x.size()[-2:])(feature_map)
            # This implementation instead resizes x to the size of the skip feature map.
            x = nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=True)
            # concatenate feature map and x along the channel dimension
            x = torch.cat((feature_map, x), dim=1)
            # pass through conv block and up conv
            x = block(x)
            x = up_conv(x)
        return x

> At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes. In total the network has 23 convolutional layers.
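
The sketch below illustrates two points from the quote. First, a 1x1 convolution acts like a linear layer applied independently to the 64-component feature vector of every pixel. Second, one way to arrive at the 23 layers is to count the 3x3 convolutions, the up-convolutions, and the final 1x1 convolution. The names `head`, `features`, and `n_classes = 2` are illustrative assumptions.

```python
import torch
import torch.nn as nn

n_classes = 2  # example value
head = nn.Conv2d(in_channels=64, out_channels=n_classes, kernel_size=1)

features = torch.randn(1, 64, 5, 5)
logits = head(features)
print(logits.shape)  # torch.Size([1, 2, 5, 5])

# the same result for one pixel, computed as a plain matrix-vector product
weight = head.weight.view(n_classes, 64)
pixel = features[0, :, 2, 3]
manual = weight @ pixel + head.bias
print(torch.allclose(manual, logits[0, :, 2, 3], atol=1e-5))  # True

# counting the convolutional layers of the architecture in Figure 1
conv3x3 = 2 * 4 + 2 + 2 * 4  # contracting blocks + bottom block + expansive blocks
n_up_convs = 4
final_1x1 = 1
total = conv3x3 + n_up_convs + final_1x1
print(total)  # 23
```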

## 🏗️ Creating the Network

In [None]:
# This function creates a model for U-Net architecture
# It takes the number of input channels and the number of classes as parameters
# It returns an instance of the U-Net model
def make_model(in_channels: int, n_classes: int) -> nn.Module:
    contractive_path = ContractivePath(in_channels = in_channels)
    expansive_path = ExpansivePath(n_classes = n_classes)
    u_net = UNet(contractive_path, expansive_path)
    return u_net

In [None]:
# Create an instance of the U-Net model
u_net = make_model(in_channels = 3, n_classes = 1)

In [None]:
# generate pseudo batch of images
img = torch.randn(8, 3, 256, 256)

In [None]:
u_net(img).shape
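
The spatial size of this output can be traced by hand. The sketch below follows the forward pass of the classes defined above (unpadded convolutions, pooling, and resizing `x` to each skip feature map before concatenation) using plain integer arithmetic. The function names are hypothetical helpers and are not part of the notebook.

```python
def conv_out(size: int, kernel: int = 3, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def notebook_unet_output_size(size: int, depth: int = 4) -> int:
    """Trace the spatial size through the U-Net variant implemented in this notebook."""
    skips = []
    for _ in range(depth):
        size = conv_out(conv_out(size))  # two unpadded 3x3 convolutions
        skips.append(size)               # stored for the skip connection
        size //= 2                       # 2x2 max pooling, stride 2

    size = conv_out(conv_out(size))      # bottom block of the "U"
    size *= 2                            # first up-convolution

    for i, skip in enumerate(reversed(skips)):
        size = skip                      # x is interpolated to the skip size
        size = conv_out(conv_out(size))  # two unpadded 3x3 convolutions
        if i < depth - 1:
            size *= 2                    # up-convolution (the last step uses a 1x1 conv instead)
    return size


print(notebook_unet_output_size(256))  # 248
```

Because `x` is resized to the shallowest skip feature map in the last step, the output ends up 8 pixels smaller than the input in each dimension, so a 256x256 image yields a 248x248 prediction.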

## 💥 U-Net in Action!

In [None]:
from torchvision.datasets import VOCSegmentation
from torchvision import transforms

In [None]:
dataset = VOCSegmentation(
    root='data',
    year='2012',
    image_set='train',
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
    ]),
    target_transform=transforms.Compose([
        transforms.ToTensor(),
        # 248x248 matches the output size of the network for 256x256 inputs
        transforms.Resize((248, 248)),
        # binarize the mask: 1 if pixel is > 0 else 0
        transforms.Lambda(lambda x: torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x)))
    ])
)

In [None]:
from torch.utils.data import DataLoader
dl = DataLoader(dataset, batch_size=8, shuffle=True)

In [None]:
from tqdm import tqdm

def train(model, dl, loss_fn, optimizer, device, epochs):
    model.train()
    model.to(device)
    for epoch in range(epochs):
        print(f'Epoch {epoch+1}/{epochs}')
        pbar = tqdm(dl)
        for batch in pbar:
            # move images and labels to the device
            images, labels = [b.to(device) for b in batch]
            # zero the gradients
            optimizer.zero_grad()
            # forward pass (the model returns raw logits)
            preds = model(images)
            # compute loss
            loss = loss_fn(preds, labels)
            # backward pass
            loss.backward()
            # update weights
            optimizer.step()
            # show the current loss in the progress bar
            pbar.set_description(f'Loss: {loss.item():.3f}')

In [None]:
train(u_net, dl, nn.BCEWithLogitsLoss(), torch.optim.Adam(u_net.parameters(), lr=0.001), 'cuda', 10)
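
The training call uses `nn.BCEWithLogitsLoss`, which expects raw logits and applies the sigmoid internally. The following sketch, with random tensors standing in for predictions and masks, checks that it matches `nn.BCELoss` applied to explicitly sigmoid-activated outputs. This is also why the model has no sigmoid layer and `.sigmoid()` is only applied at prediction time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# random stand-ins for model outputs (logits) and binary target masks
logits = torch.randn(2, 1, 4, 4)
targets = torch.randint(0, 2, (2, 1, 4, 4)).float()

# sigmoid + binary cross-entropy fused into one numerically stable operation
loss_from_logits = nn.BCEWithLogitsLoss()(logits, targets)

# the same loss computed in two separate steps
loss_from_probs = nn.BCELoss()(torch.sigmoid(logits), targets)

print(torch.allclose(loss_from_logits, loss_from_probs, atol=1e-5))  # True
```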

In [None]:
# tensor to image
def tensor_to_image(tensor):
    image = transforms.ToPILImage()(tensor)
    return image

In [None]:
batch = next(iter(dl))

In [None]:
img = batch[0][0]
mask = batch[1][0]

In [None]:
tensor_to_image(img)

In [None]:
tensor_to_image(mask)

In [None]:
u_net.eval()  # use the running BatchNorm statistics for inference
with torch.no_grad():
    pred = u_net(img.to('cuda').unsqueeze(0)).sigmoid()[0]

In [None]:
tensor_to_image(pred)

In [None]:
tensor_to_image(torch.where(pred > 0.5, torch.ones_like(pred), torch.zeros_like(pred)))

In [None]:
# show a mask as a semi-transparent overlay on top of an image
def show_overlay(img, mask, alpha=0.5):
    import cv2
    import matplotlib.pyplot as plt
    import numpy as np
    # convert to numpy
    img = img.detach().cpu().numpy()
    mask = mask.detach().cpu().numpy()
    # transpose image and mask from (channels, height, width) to (height, width, channels)
    img = img.transpose(1, 2, 0)
    mask = mask.transpose(1, 2, 0)
    # append two all-zero channels so the single-channel mask is drawn in red
    mask = np.concatenate([mask, np.zeros_like(mask), np.zeros_like(mask)], axis=-1)
    # resize mask to the image size; cv2.resize expects (width, height)
    mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
    # create a figure
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    # show the image, then the mask on top of it
    ax.imshow(img)
    ax.imshow(mask, alpha=alpha)

##  📚 References