In [5]:
%pip install torch


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [11]:
import torch
from torch import nn
from torch import optim

In [67]:
class UNetContractingBlock(nn.Module):
    def __init__(self, input_channels, output_channels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # two back to back convolutional layers
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=output_channels,
                  kernel_size=(3, 3))
        self.conv2 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                               kernel_size=(3, 3))
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

    def forward(self, x) -> torch.Tensor:
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)

        return x


In [87]:
class UNetExpandingBlock(nn.Module):

    def __init__(self, input_channels, output_channels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.upsample = nn.ConvTranspose2d(in_channels=input_channels, out_channels=output_channels,
                                        kernel_size=(2, 2), stride=2)
        self.conv1 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                               kernel_size=(3, 3))
        self.conv2 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                               kernel_size=(3, 3))
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)

        return x

In [91]:
class UNet(nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

        depth = 4

        # contracting path
        self.contracting_blocks = []
        start_out_channels = 64
        for i in range(depth):
            if i != 0:
                next_out_channels = start_out_channels * 2
                self.contracting_blocks.append(
                    UNetContractingBlock(input_channels=start_out_channels, output_channels=next_out_channels)
                )
                start_out_channels = next_out_channels
            else:
                self.contracting_blocks.append(
                    UNetContractingBlock(input_channels=1, output_channels=start_out_channels)
                )

        # intermediate conv block, no maxpool
        self.intermediate_conv = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(3, 3)),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=(3, 3)),
        )

        # expanding path
        self.expanding_blocks = []
        start_in_channels = 1024
        for _ in range(depth):
            next_in_channels = start_in_channels // 2
            self.expanding_blocks.append(
                UNetExpandingBlock(input_channels=start_in_channels, output_channels=next_in_channels)
            )
            start_in_channels = next_in_channels

        # last convolution
        self.conv_last = nn.Conv2d(in_channels=64, out_channels=2,
                                   kernel_size=(1, 1,))

    def forward(self, x) -> torch.Tensor:
        print(self.contracting_blocks)
        print(self.expanding_blocks)
        for block in self.contracting_blocks:
            x = block(x)

        x = self.intermediate_conv(x)

        for block in self.expanding_blocks:
            x = block(x)

        x = self.conv_last(x)

        return x

In [92]:
model = UNet()

In [93]:
t = torch.rand(1, 1, 572, 572)  # batch size, channels, height, width
out = model(t)

type(out), out.shape, out

[UNetContractingBlock(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padd

(torch.Tensor,
 torch.Size([1, 2, 388, 388]),
 tensor([[[[-0.0152, -0.0153, -0.0153,  ..., -0.0153, -0.0153, -0.0153],
           [-0.0155, -0.0151, -0.0155,  ..., -0.0151, -0.0155, -0.0152],
           [-0.0152, -0.0153, -0.0152,  ..., -0.0153, -0.0152, -0.0153],
           ...,
           [-0.0155, -0.0151, -0.0155,  ..., -0.0151, -0.0155, -0.0152],
           [-0.0152, -0.0153, -0.0152,  ..., -0.0153, -0.0152, -0.0153],
           [-0.0155, -0.0151, -0.0155,  ..., -0.0151, -0.0155, -0.0151]],
 
          [[-0.1414, -0.1413, -0.1414,  ..., -0.1413, -0.1414, -0.1413],
           [-0.1414, -0.1415, -0.1414,  ..., -0.1415, -0.1414, -0.1415],
           [-0.1414, -0.1414, -0.1414,  ..., -0.1414, -0.1414, -0.1413],
           ...,
           [-0.1414, -0.1415, -0.1414,  ..., -0.1415, -0.1414, -0.1415],
           [-0.1414, -0.1414, -0.1414,  ..., -0.1414, -0.1414, -0.1413],
           [-0.1414, -0.1416, -0.1415,  ..., -0.1416, -0.1415, -0.1416]]]],
        grad_fn=<ConvolutionBackward0>))