In [1]:
import torch
import torch.nn as nn
from torch.nn import Sequential, Conv2d, ReLU, MaxPool2d

In [2]:
image = torch.rand((1,1,512,512))

In [3]:
class Unet(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, img):
        ########### DOWN ###########
        x0 = self.down_conv(1,64)(img)
        print("x0:", x0.size())

        x1 = self.down_layer(64,128)(x0)
        print("x1:", x1.size())

        x2 = self.down_layer(128,256)(x1)
        print("x2:", x2.size())

        x3 = self.down_layer(256,512)(x2)
        print("x3:", x3.size())

        x4 = self.down_layer(512,1024)(x3)
        print("x4:", x4.size())

        ############ UP ############
        y3 = self.up_layer(x4, x3, 1024, 512)
        print("y3:", y3.size())

        y2 = self.up_layer(y3, x2, 512, 256)
        print("y2:", y2.size())

        y1 = self.up_layer(y2, x1, 256, 128)
        print("y1:", y1.size())
    
        y0 = self.up_layer(y1, x0, 128, 64)
        print("y0:", y0.size())

        res = Conv2d(64, 2, kernel_size=1, padding=0)(y0)
        print("r:", res.size())


    def down_conv(self, cin, cout):
        return Sequential(Conv2d(cin, cout, kernel_size=3, padding=1),ReLU(),
                          Conv2d(cout,cout, kernel_size=3, padding=1), ReLU())

    def down_layer(self, cin, cout):
        return Sequential(MaxPool2d(kernel_size=2, stride=2),
                        self.down_conv(cin,cout))
        
    def up_conv(self, cin, cout):
        return nn.ConvTranspose2d(cin, cout, kernel_size=2, stride=2)
    
    def up_layer(self, up, across, cin, cout):
        y = self.up_conv(cin, cout)(up)
        combined = torch.cat([across, y], 1)
        return self.down_conv(cin, cout)(combined)
    # def up_layer(self, cin, cout)

In [4]:
u = Unet()
u(image)

x0: torch.Size([1, 64, 512, 512])
x1: torch.Size([1, 128, 256, 256])
x2: torch.Size([1, 256, 128, 128])
x3: torch.Size([1, 512, 64, 64])
x4: torch.Size([1, 1024, 32, 32])
y3: torch.Size([1, 512, 64, 64])
y2: torch.Size([1, 256, 128, 128])
y1: torch.Size([1, 128, 256, 256])
y0: torch.Size([1, 64, 512, 512])
r: torch.Size([1, 2, 512, 512])


In [5]:
print(u)

Unet()
