In [1]:
import torch

In [2]:
class DoubleConv(torch.nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    With ``padding=1`` and the default stride of 1, the spatial size is left
    unchanged; only the channel count goes from ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        ]
        # Attribute name ``step`` and layer order determine the state_dict keys
        # ("step.0.*", "step.2.*"); keep them stable for saved checkpoints.
        self.step = torch.nn.Sequential(*layers)

    def forward(self, X):
        """Apply conv -> ReLU -> conv -> ReLU to ``X``.

        ``X`` is whatever ``torch.nn.Conv2d`` accepts: (N, C, H, W) or (C, H, W)
        with C == in_channels.
        """
        out = self.step(X)
        return out

In [3]:
class UNet(torch.nn.Module):
    """A three-level U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    The encoder goes 64 -> 128 -> 256 -> 512 channels, halving the spatial size
    with ``MaxPool2d(2)`` between stages. The decoder upsamples bilinearly,
    concatenates the matching encoder feature map (skip connection) along the
    channel axis, and refines with a ``DoubleConv``. A final 1x1 convolution maps
    to ``out_channels``.

    Args:
        in_channels: channels of the input image (default 1, as before).
        out_channels: channels of the output map (default 1, as before).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder. Creation order is kept identical so seeded initialisation
        # and state_dict keys match the original layout.
        self.layer1 = DoubleConv(in_channels, 64)
        self.layer2 = DoubleConv(64, 128)
        self.layer3 = DoubleConv(128, 256)
        self.layer4 = DoubleConv(256, 512)

        # Decoder: input channels = upsampled features + skip-connection features.
        self.layer5 = DoubleConv(512 + 256, 256)
        self.layer6 = DoubleConv(256 + 128, 128)
        self.layer7 = DoubleConv(128 + 64, 64)
        # 1x1 projection to the requested number of output channels.
        self.layer8 = torch.nn.Conv2d(64, out_channels, 1)

        self.maxpool = torch.nn.MaxPool2d(2)

    @staticmethod
    def _up_and_concat(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s (H, W) and concatenate on channels.

        Resizing to the skip's exact size (instead of a fixed x2) keeps the
        concatenation valid when pooling floored an odd dimension. For even sizes
        the scale is exactly 2, so the result equals ``Upsample(scale_factor=2)``.
        """
        x = torch.nn.functional.interpolate(
            x, size=skip.shape[-2:], mode="bilinear", align_corners=False
        )
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Run the encoder/decoder.

        Args:
            x: 4-D tensor of shape (N, in_channels, H, W). H and W need not be
                divisible by 8.

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        # Encoder; MaxPool2d(2) floors odd sizes (H -> H // 2).
        x1 = self.layer1(x)                   # (N, 64,  H,      W)
        x2 = self.layer2(self.maxpool(x1))    # (N, 128, H // 2, W // 2)
        x3 = self.layer3(self.maxpool(x2))    # (N, 256, H // 4, W // 4)
        x4 = self.layer4(self.maxpool(x3))    # (N, 512, H // 8, W // 8)

        # Decoder with skip connections.
        x5 = self.layer5(self._up_and_concat(x4, x3))  # (N, 256, H // 4, W // 4)
        x6 = self.layer6(self._up_and_concat(x5, x2))  # (N, 128, H // 2, W // 2)
        x7 = self.layer7(self._up_and_concat(x6, x1))  # (N, 64,  H,      W)

        return self.layer8(x7)                # (N, out_channels, H, W)

In [4]:
# Seed torch's global RNG so the randomly initialised weights (and any random
# tensors drawn afterwards) are reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)
model = UNet()

In [8]:
# Smoke test: one random single-channel 256x256 image goes through the network,
# and the output must keep the batch and spatial size with a single channel.
random_input = torch.randn(1, 1, 256, 256)
output = model(random_input)
expected_shape = torch.Size([random_input.shape[0], 1, *random_input.shape[-2:]])
assert output.shape == expected_shape, (
    f"unexpected output shape {tuple(output.shape)}, expected {tuple(expected_shape)}"
)