Coded by Lujia Zhong @lujiazho<br>
Reference: https://github.com/milesial/Pytorch-UNet

The pixel-wise loss weight map described in the [U-Net paper](https://arxiv.org/pdf/1505.04597) is not implemented here, because it is pre-computed from the cell-segmentation ground-truth labels.

In [1]:
import time
import torch
import torch.nn as nn

class DownSampling(nn.Module):
    """Encoder stage: 2x2 max-pool, then two unpadded 3x3 conv + ReLU layers.

    The pooling halves H and W, and each unpadded 3x3 convolution trims a
    further 2 pixels from both, so an input of spatial size (H, W) comes out
    as (H / 2 - 4, W / 2 - 4) with ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.max_pool = nn.MaxPool2d(2)

        # Both convolutions share the same geometry: 3x3, no padding, no bias.
        conv_kwargs = dict(kernel_size=3, padding=0, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.act2 = nn.ReLU()

    def forward(self, x):
        """Pool ``x`` and run it through the double convolution.

        Raises:
            AssertionError: if the height or width of ``x`` is odd.
        """
        height, width = x.shape[-2], x.shape[-1]
        assert height % 2 == 0 and width % 2 == 0, "Both H and W must be even."

        pooled = self.max_pool(x)
        hidden = self.act1(self.conv1(pooled))
        return self.act2(self.conv2(hidden))

class UpSampling(nn.Module):
    """Decoder stage: learned 2x upsampling, skip concatenation, double conv.

    ``x1`` (the deeper feature map) is upsampled with a transposed convolution
    that halves its channel count. The skip connection ``x2`` is centre-cropped
    to the upsampled size, concatenated along the channel axis, and passed
    through two unpadded 3x3 conv + ReLU layers.

    ``conv1`` expects ``in_channels`` input channels, so ``x2`` must carry
    ``in_channels - in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # ConvTranspose2d has trainable kernels; UpSampling2D is interpolation: bilinear/nearest...
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0, bias=False)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0, bias=False)
        self.act2 = nn.ReLU()

    def forward(self, x1, x2):
        """Upsample ``x1``, merge it with the cropped skip ``x2``, and convolve.

        Args:
            x1: feature map from the deeper level, shape (N, in_channels, h, w).
            x2: skip-connection feature map whose spatial size is at least the
                upsampled size (2h, 2w).

        Returns:
            Tensor of shape (N, out_channels, 2h - 4, 2w - 4).
        """
        x1 = self.up(x1)

        up_h, up_w = x1.shape[-2], x1.shape[-1]
        top = (x2.shape[-2] - up_h) // 2
        left = (x2.shape[-1] - up_w) // 2

        # Crop with explicit end indices: a negative-stop slice such as
        # ``pad:-pad`` collapses to an empty tensor when pad == 0 and leaves
        # one pixel too many when the size difference is odd.
        skip = x2[:, :, top:top + up_h, left:left + up_w]
        x = torch.cat([skip, x1], dim=1)

        x = self.conv1(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.act2(x)

        return x

class UNet(nn.Module):
    """U-Net built from unpadded convolutions, applied tile by tile.

    ``forward`` takes a single unbatched image of shape (C, H, W). The image
    is reflection-padded and cut into a grid of overlapping patches, every
    patch is run through the encoder/decoder as one batch, and the per-patch
    outputs are stitched back into a (n_classes, H, W) logit map.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()

        self.in_conv = nn.Sequential(
            nn.Conv2d(n_channels, 64, kernel_size=3, padding=0, bias=False),
            nn.Conv2d(64, 64, kernel_size=3, padding=0, bias=False),
        )
        # Non-linearity applied after each input convolution in forward().
        # It is kept outside the Sequential so the parameter names
        # (in_conv.0.weight, in_conv.1.weight) stay unchanged.
        self.in_act = nn.ReLU()

        self.down1 = DownSampling(64, 128)
        self.down2 = DownSampling(128, 256)
        self.down3 = DownSampling(256, 512)
        self.down4 = DownSampling(512, 1024)

        self.up1 = UpSampling(1024, 512)
        self.up2 = UpSampling(512, 256)
        self.up3 = UpSampling(256, 128)
        self.up4 = UpSampling(128, 64)

        self.out_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits for one image.

        The shape comments below trace a (3, 776, 776) input through a model
        with n_classes == 10.
        """
        # torch.Size([3, 776, 776])

        x = self.preprocess(x) # torch.Size([4, 3, 572, 572])

        # Conv -> ReLU twice, like every other stage; without the ReLUs the
        # two stacked convolutions would collapse into a single linear map.
        x1 = x
        for conv in self.in_conv:
            x1 = self.in_act(conv(x1))   # torch.Size([4, 64, 568, 568]) after both

        x2 = self.down1(x1)    # torch.Size([4, 128, 280, 280])
        x3 = self.down2(x2)    # torch.Size([4, 256, 136, 136])
        x4 = self.down3(x3)    # torch.Size([4, 512, 64, 64])
        x5 = self.down4(x4)    # torch.Size([4, 1024, 28, 28])

        x = self.up1(x5, x4)   # torch.Size([4, 512, 52, 52])
        x = self.up2(x, x3)    # torch.Size([4, 256, 100, 100])
        x = self.up3(x, x2)    # torch.Size([4, 128, 196, 196])
        x = self.up4(x, x1)    # torch.Size([4, 64, 388, 388])

        x = self.out_conv(x)   # torch.Size([4, 10, 388, 388])

        logits = self.postprocess(x)   # torch.Size([10, 776, 776])

        return logits

    # pad and slice into patches
    def preprocess(self, x, slice_ = 2):
        """Reflection-pad an image and cut it into overlapping patches.

        Args:
            x: image tensor of shape (C, H, W).
            slice_: number of patches per spatial axis; H and W must both be
                divisible by it.

        Returns:
            Tensor of shape (slice_ * slice_, C, H // slice_ + 184,
            W // slice_ + 184), patches ordered row by row.

        Raises:
            AssertionError: if H or W is not divisible by ``slice_``.
        """
        _, H, W = x.shape
        assert H % slice_ == 0 and W % slice_ == 0, "Cannot split imgs."

        # As in the paper: the network trims 184 pixels per axis
        # (572 -> 388), so each patch carries 92 pixels of context per side.
        pad = 92
        new_H, new_W = H // slice_, W // slice_

        x = nn.ReflectionPad2d(pad)(x)

        # In padded coordinates, a patch whose core starts at (top, left) of
        # the original image spans [top, top + new_H + 2 * pad) and likewise
        # along the width.
        imgs = [
            x[:, top:top + new_H + 2 * pad, left:left + new_W + 2 * pad]
            for top in range(0, H, new_H)
            for left in range(0, W, new_W)
        ]

        return torch.stack(imgs)

    def postprocess(self, x):
        """Stitch a row-major square grid of patches back into one map.

        Args:
            x: tensor of shape (n * n, C, H, W) holding the patches row by row.

        Returns:
            Tensor of shape (C, n * H, n * W).

        Raises:
            AssertionError: if the number of patches is not a perfect square.
        """
        patch_num, ch, H, W = x.shape

        # round() instead of int(): guards against a float square root that
        # lands just below the exact integer.
        n = round(patch_num ** 0.5)
        assert n * n == patch_num, "Number of patches must be a perfect square."

        x = x.reshape(n, n, ch, H, W)
        # (row, col, C, H, W) -> (C, row, H, col, W) so that merging adjacent
        # axes places each patch at its grid position.
        x = x.permute(2, 0, 3, 1, 4).contiguous()

        return x.view(ch, H*n, W*n)
        

# Model configuration: 3 input channels, 10 output classes.
n_channels, n_classes = 3, 10
model = UNet(n_channels, n_classes)

In [2]:
# Loss and optimiser: per-pixel cross-entropy, SGD with momentum and weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-5,
    momentum=0.99,
    weight_decay=5e-4,
)

iterarions = 1
begin = time.time()

# Training
for iterarion in range(iterarions):
    # One random image and one random label map per step (batch == 1).
    x = torch.rand(n_channels, 776, 776)
    y = torch.randint(0, n_classes, (1, 776, 776))

    pred = model(x)
    # The model returns unbatched (n_classes, H, W) logits while the target
    # is (1, H, W), so a leading batch axis is added to the prediction.
    loss = criterion(pred[None], y)

    if iterarion % 1 == 0:
        print('Iterarion:', '%2d,' % (iterarion + 1), 'loss =', '{:.4f}'.format(loss))

    # Clear stale gradients, back-propagate, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(f"{(time.time() - begin)/iterarions:.4f}s / iterarion")

Iterarion:  1, loss = 2.3068
51.9413s / iterarion
