In [1]:
import torch
from torch import nn
from torchvision import transforms

## Homework 6 - U-Net
*Lorenzo Basile*

In [2]:
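# Original U-Net architecture: all 3x3 convolutions are unpadded, so each one trims
# 2 pixels per spatial dimension and the network output is smaller than its input
# (512x512 -> 324x324 in the check below).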
class contracting_module(torch.nn.Module):
    """Encoder stage of the original (unpadded) U-Net.

    Stage ``index`` outputs ``2 ** (index + 6)`` channels (64, 128, ... for
    index 0, 1, ...). Stage 0 reads a 1-channel image directly. Every later
    stage first halves H and W with a 2x2 max-pool and then reads the
    ``output_channels // 2`` channels produced by the previous stage.
    Both 3x3 convolutions are unpadded, so each one trims 2 pixels from
    both spatial dimensions.
    """

    def __init__(self, index):
        super().__init__()
        self.output_channels = 2 ** (index + 6)
        width = self.output_channels
        is_first = index == 0
        # Only stage 0 skips the downsampling step.
        entry = nn.Identity() if is_first else nn.MaxPool2d(kernel_size=2)
        conv_in = 1 if is_first else width // 2
        self.layers = nn.Sequential(
            entry,
            nn.Conv2d(in_channels=conv_in, out_channels=width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the optional pooling, then conv-ReLU-conv-ReLU."""
        return self.layers(x)

class expanding_module(torch.nn.Module):
    """Decoder stage of the original (unpadded) U-Net.

    Stage ``index`` outputs ``2 ** (6 + index)`` channels. The input is first
    upsampled 2x with a transposed convolution that halves its channel count.
    Next, the skip connection from the matching encoder stage is center-cropped
    to the upsampled height and width and concatenated in front of it along the
    channel axis. Two unpadded 3x3 conv + ReLU layers follow.
    """

    def __init__(self, index):
        super().__init__()
        self.output_channels=2**(6+index)
        self.upsampling=nn.ConvTranspose2d(in_channels=self.output_channels*2, out_channels=self.output_channels, kernel_size=2, stride=2)
        self.layers=nn.Sequential(
            nn.Conv2d(in_channels=self.output_channels*2, out_channels=self.output_channels, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.output_channels, out_channels=self.output_channels, kernel_size=3, stride=1),
            nn.ReLU()
        )

    @staticmethod
    def _center_crop(t, height, width):
        """Crop the last two dims of ``t`` to ``(height, width)`` around the center.

        Offsets are rounded the same way as ``torchvision.transforms.CenterCrop``.
        Height and width are handled independently, so non-square maps work.

        Raises:
            ValueError: if ``t`` is smaller than the requested crop.
        """
        src_h, src_w = t.shape[-2], t.shape[-1]
        if src_h < height or src_w < width:
            raise ValueError(
                f"cannot center-crop a {src_h}x{src_w} map to {height}x{width}"
            )
        top = int(round((src_h - height) / 2.0))
        left = int(round((src_w - width) / 2.0))
        return t[..., top:top + height, left:left + width]

    def forward(self, x, copy):
        """Upsample ``x``, concatenate the cropped skip ``copy``, and convolve."""
        x=self.upsampling(x)
        # Crop both spatial dims; cropping to the width alone breaks non-square inputs.
        cropped_copy=self._center_crop(copy, x.shape[-2], x.shape[-1])
        x=torch.cat([cropped_copy, x], dim=1)
        return self.layers(x)

class UNet(torch.nn.Module):
    """Original (unpadded) U-Net built from ``contracting_module``/``expanding_module``.

    Encoder stages ``cm0``..``cm3`` keep their outputs as skip connections.
    ``cm4`` is the bottleneck. Decoder stages ``em3``..``em0`` consume those
    skips in reverse order, and a final 1x1 convolution maps 64 channels to
    ``out_channels``.

    Args:
        out_channels: channels produced by the final 1x1 convolution (default 2).
    """

    def __init__(self, out_channels=2):
        super().__init__()
        # Construction order is kept as-is so parameter initialisation consumes
        # the RNG in the same sequence as before.
        self.cm0=contracting_module(0)
        self.cm1=contracting_module(1)
        self.cm2=contracting_module(2)
        self.cm3=contracting_module(3)
        self.cm4=contracting_module(4)
        self.em0=expanding_module(0)
        self.em1=expanding_module(1)
        self.em2=expanding_module(2)
        self.em3=expanding_module(3)
        self.final_layer=nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck, decoder, and final 1x1 projection."""
        # Encoder: each stage's output is also the skip connection for the decoder.
        copy_0 = self.cm0(x)
        copy_1 = self.cm1(copy_0)
        copy_2 = self.cm2(copy_1)
        copy_3 = self.cm3(copy_2)
        x = self.cm4(copy_3)
        # Decoder: deepest skip first.
        x = self.em3(x, copy_3)
        x = self.em2(x, copy_2)
        x = self.em1(x, copy_1)
        x = self.em0(x, copy_0)
        return self.final_layer(x)

In [3]:
# Shape check for the unpadded U-Net: the output is smaller than the input.
net = UNet()
sample_input = torch.rand((1, 1, 512, 512))  # avoid shadowing the builtin `input`
print(sample_input.shape)
# Only the shape is needed, so skip building the autograd graph.
with torch.no_grad():
    output = net(sample_input)
print(output.shape)

torch.Size([1, 1, 512, 512])
torch.Size([1, 2, 324, 324])


In [4]:
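# Requested architecture: the same U-Net, but every 3x3 convolution uses padding=1 so the
# output keeps the input's spatial size, and the final layer has a single output channel.
# NOTE: this cell intentionally redefines the classes from the cell above.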
class contracting_module(torch.nn.Module):
    """Encoder stage of the padded ("same"-size) U-Net.

    Stage ``index`` outputs ``2 ** (index + 6)`` channels. Stage 0 reads a
    1-channel image directly. Every later stage first halves H and W with a
    2x2 max-pool and then reads the previous stage's ``output_channels // 2``
    channels. Both 3x3 convolutions use ``padding=1``, so they preserve H and W.
    """

    def __init__(self, index):
        super().__init__()
        self.output_channels = 2 ** (index + 6)
        width = self.output_channels
        is_first = index == 0
        # Only stage 0 skips the downsampling step.
        entry = nn.Identity() if is_first else nn.MaxPool2d(kernel_size=2)
        conv_in = 1 if is_first else width // 2
        self.layers = nn.Sequential(
            entry,
            nn.Conv2d(in_channels=conv_in, out_channels=width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the optional pooling, then conv-ReLU-conv-ReLU."""
        return self.layers(x)

class expanding_module(torch.nn.Module):
    """Decoder stage of the padded ("same"-size) U-Net.

    Stage ``index`` outputs ``2 ** (6 + index)`` channels. The input is first
    upsampled 2x with a transposed convolution that halves its channel count.
    Next, the skip connection from the matching encoder stage is center-cropped
    to the upsampled height and width and concatenated in front of it along the
    channel axis. Two ``padding=1`` 3x3 conv + ReLU layers follow.
    """

    def __init__(self, index):
        super().__init__()
        self.output_channels=2**(6+index)
        self.upsampling=nn.ConvTranspose2d(in_channels=self.output_channels*2, out_channels=self.output_channels, kernel_size=2, stride=2)
        self.layers=nn.Sequential(
            nn.Conv2d(in_channels=self.output_channels*2, out_channels=self.output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.output_channels, out_channels=self.output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

    @staticmethod
    def _center_crop(t, height, width):
        """Crop the last two dims of ``t`` to ``(height, width)`` around the center.

        Offsets are rounded the same way as ``torchvision.transforms.CenterCrop``.
        Height and width are handled independently, so non-square maps work.

        Raises:
            ValueError: if ``t`` is smaller than the requested crop.
        """
        src_h, src_w = t.shape[-2], t.shape[-1]
        if src_h < height or src_w < width:
            raise ValueError(
                f"cannot center-crop a {src_h}x{src_w} map to {height}x{width}"
            )
        top = int(round((src_h - height) / 2.0))
        left = int(round((src_w - width) / 2.0))
        return t[..., top:top + height, left:left + width]

    def forward(self, x, copy):
        """Upsample ``x``, concatenate the cropped skip ``copy``, and convolve."""
        x=self.upsampling(x)
        # With padding the skip only exceeds x when a pooled size was odd; crop
        # both spatial dims (cropping to the width alone breaks non-square inputs).
        cropped_copy=self._center_crop(copy, x.shape[-2], x.shape[-1])
        x=torch.cat([cropped_copy, x], dim=1)
        return self.layers(x)

class UNet(torch.nn.Module):
    """Padded ("same"-size) U-Net built from ``contracting_module``/``expanding_module``.

    Encoder stages ``cm0``..``cm3`` keep their outputs as skip connections.
    ``cm4`` is the bottleneck. Decoder stages ``em3``..``em0`` consume those
    skips in reverse order, and a final 1x1 convolution maps 64 channels to
    ``out_channels``.

    Args:
        out_channels: channels produced by the final 1x1 convolution (default 1).
    """

    def __init__(self, out_channels=1):
        super().__init__()
        # Construction order is kept as-is so parameter initialisation consumes
        # the RNG in the same sequence as before.
        self.cm0=contracting_module(0)
        self.cm1=contracting_module(1)
        self.cm2=contracting_module(2)
        self.cm3=contracting_module(3)
        self.cm4=contracting_module(4)
        self.em0=expanding_module(0)
        self.em1=expanding_module(1)
        self.em2=expanding_module(2)
        self.em3=expanding_module(3)
        self.final_layer=nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck, decoder, and final 1x1 projection."""
        # Encoder: each stage's output is also the skip connection for the decoder.
        copy_0 = self.cm0(x)
        copy_1 = self.cm1(copy_0)
        copy_2 = self.cm2(copy_1)
        copy_3 = self.cm3(copy_2)
        x = self.cm4(copy_3)
        # Decoder: deepest skip first.
        x = self.em3(x, copy_3)
        x = self.em2(x, copy_2)
        x = self.em1(x, copy_1)
        x = self.em0(x, copy_0)
        return self.final_layer(x)

In [5]:
# Shape check for the requested (padded) U-Net: the output must match the input size.
net = UNet()
sample_input = torch.rand((1, 1, 512, 512))  # avoid shadowing the builtin `input`
print(sample_input.shape)
# Only the shape is needed, so skip building the autograd graph.
with torch.no_grad():
    output = net(sample_input)
print(output.shape)
assert output.shape[-2:] == sample_input.shape[-2:], "padded U-Net must preserve H and W"

torch.Size([1, 1, 512, 512])
torch.Size([1, 1, 512, 512])
