- https://arxiv.org/abs/1505.04597
    - U-Net: Convolutional Networks for Biomedical Image Segmentation


In [5]:
from IPython.display import Image
import torch
from torch import nn

## basics

- image input image output
    - 最开始要解决的是医学图像分割（medical image segmentation）
        - input：image, output：segmentation masks

    $$
    \text{MSE}=\frac1n\sum_{i=1}^n(Y_i - \hat Y_i)^2
    $$

    - 还可以用于图像超分辨率（super-resolution）
    - diffusion models
- unet: encoder & decoder
    - encoder: extracting features from input images
    - decoder: up sampling intermediate features and producing the final output
    - encoder & decoder are symmetrical and connected by paths
        - U-shape Net

## architecture

- encoder
    - Repeated 3*3 conv + ReLU layers
        - 572 \* 572 -> 570 \* 570 -> 568 \* 568
    - 2*2 maxpooling layers to downsample
        - 568 -> 284
        - 280 -> 140
    - double channels with conv after maxpooling
        - 64 -> 128 -> 256
- decoder
    - repeated 3*3 conv + ReLU layers
    - Upsampling, followed by 2*2 conv layer
    - halve channels after upsampling conv
- connections: bottleneck & connecting paths
    - bottleneck
    - connecting paths：
        - 添加 encoder 的细节信息；
        - 实现上，在 depth/channels 上做拼接（concatenate）；
        - 注意：下文的 `BasicUNet` 为了简化，用逐元素相加代替了拼接；


In [3]:
# U-Net architecture figure (paper: https://arxiv.org/abs/1505.04597).
# The image is referenced by URL (Hugging Face blog assets), not stored in the notebook.
Image(url='https://huggingface.co/blog/assets/78_annotated-diffusion/unet_architecture.jpg', width=700)

In [4]:
# Second UNet diagram, served through GitHub's camo image proxy. The hex-encoded tail of
# the URL decodes to:
#   https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/unet-model.png
Image(url='https://camo.githubusercontent.com/2c676413dda1f487521dd5c1e5c4b35b8cfbf06d50880e15660ea44bd76eac6f/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f68756767696e67666163652f646f63756d656e746174696f6e2d696d616765732f7265736f6c76652f6d61696e2f756e65742d6d6f64656c2e706e67', 
              width=600)

## UNet from scratch

In [8]:
class BasicUNet(nn.Module):
    """A minimal UNet: three conv stages down, three conv stages up, additive skips.

    Encoder channels are ``in_channels -> 32 -> 64 -> 64``, with a 2x max-pool after
    each of the first two stages. Decoder channels are ``64 -> 64 -> 32 -> out_channels``,
    with a 2x upsample before each of the last two stages. Every convolution is 5x5
    with padding 2 (so it preserves the spatial size) and is followed by SiLU,
    including the final output layer.

    Skip connections are element-wise additions of the pre-pooling encoder
    activations, not channel concatenations as in the original U-Net paper.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced image.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.SiLU()  # Activation applied after every convolution
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Run the encoder/decoder and return a tensor with the input's spatial size.

        Args:
            x: Batched image tensor of shape ``(N, in_channels, H, W)``. ``H`` and
                ``W`` must be at least 4 so that two 2x poolings leave a non-empty
                feature map; they do not have to be multiples of 4.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        skips = []
        # Every down stage except the last stores its output and is then pooled.
        last_down = len(self.down_layers) - 1
        for i, layer in enumerate(self.down_layers):
            x = self.act(layer(x))
            if i < last_down:
                skips.append(x)  # Kept for the matching decoder stage
                x = self.downscale(x)

        for i, layer in enumerate(self.up_layers):
            if i > 0:  # The first up layer works at the bottleneck resolution
                skip = skips.pop()
                x = self.upscale(x)
                if x.shape[-2:] != skip.shape[-2:]:
                    # MaxPool2d(2) floors odd sizes (e.g. 15 -> 7), so doubling
                    # gives 14, not 15. Resize to the skip's exact size so the
                    # addition below is well defined for any input size.
                    x = nn.functional.interpolate(x, size=skip.shape[-2:])
                x = x + skip  # Additive skip connection (out-of-place)
            x = self.act(layer(x))

        return x

In [9]:
# Smoke test: push a random batch of 8 single-channel 28x28 images through a
# default BasicUNet (1 channel in, 1 channel out) and display the output shape.
unet = BasicUNet()
x = torch.rand(8, 1, 28, 28)
unet(x).shape

torch.Size([8, 1, 28, 28])

In [17]:
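# Uncomment to install torchsummary (used by the `summary(...)` cell below) into the kernel's environment.
# %pip install --upgrade torchsummary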

- (Conv2d, silu, maxpool) \* 2
    - (1, 28, 28) => **(32, 28, 28)** => (32, 14, 14)
    - (32, 14, 14) => **(64, 14, 14)** => (64, 7, 7)
- (Conv2d, silu)
    - (64, 7, 7) => (64, 7, 7)
- (conv2d, silu)
    - (64, 7, 7) => (64, 7, 7)
- (upsample, add skip, Conv2d, silu) \* 2（加粗的 shape 是 skip connection 相加的位置）
    - (64, 7, 7) => **(64, 14, 14)** => (32, 14, 14)
    - (32, 14, 14) => **(32, 28, 28)** => (1, 28, 28)
    

In [18]:
# Total number of scalar parameters across every tensor registered on the model.
sum(p.numel() for p in unet.parameters())

309057

In [16]:
# Print a per-layer table of output shapes and parameter counts for one
# (1, 28, 28) input, evaluated on CPU. torchsummary is a third-party package;
# see the commented-out install cell earlier in the notebook.
from torchsummary import summary
summary(unet, input_size=(1, 28, 28), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             832
              SiLU-2           [-1, 32, 28, 28]               0
         MaxPool2d-3           [-1, 32, 14, 14]               0
            Conv2d-4           [-1, 64, 14, 14]          51,264
              SiLU-5           [-1, 64, 14, 14]               0
         MaxPool2d-6             [-1, 64, 7, 7]               0
            Conv2d-7             [-1, 64, 7, 7]         102,464
              SiLU-8             [-1, 64, 7, 7]               0
            Conv2d-9             [-1, 64, 7, 7]         102,464
             SiLU-10             [-1, 64, 7, 7]               0
         Upsample-11           [-1, 64, 14, 14]               0
           Conv2d-12           [-1, 32, 14, 14]          51,232
             SiLU-13           [-1, 32, 14, 14]               0
         Upsample-14           [-1, 32,