In [None]:
# Module imports
import torch
from torch import nn

In [None]:
def conv(in_channels, out_channels, kernel_size=3, padding=1):
    """Build a Conv2d -> BatchNorm2d -> ReLU stage.

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the convolution and
            normalized by the batch-norm layer.
        kernel_size: Kernel size forwarded to ``Conv2d``.
        padding: Padding forwarded to ``Conv2d``. With the defaults
            (3x3 kernel, padding 1) the spatial size is preserved.

    Returns:
        A ``torch.nn.Sequential`` holding the three layers in order.
    """
    convolution = torch.nn.Conv2d(
        in_channels, out_channels, kernel_size, padding=padding
    )
    normalization = torch.nn.BatchNorm2d(out_channels)
    # inplace=True lets the activation overwrite the batch-norm output.
    activation = torch.nn.ReLU(inplace=True)
    return torch.nn.Sequential(convolution, normalization, activation)

class UNet(nn.Module):
    """Small convolutional encoder-decoder for dense per-pixel prediction.

    The encoder runs four ``conv`` stages, each followed by 2x2 max pooling,
    so the spatial resolution is reduced by a factor of 16. The decoder
    mirrors it with three ``conv`` stages and a final plain convolution,
    each followed by 2x bilinear upsampling. A sigmoid maps the result
    into the open interval (0, 1).

    Note: despite the name, there are no skip connections between the
    encoder and the decoder; the two halves are chained sequentially.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the predicted map.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Channel widths of the four encoder stages (reversed in the decoder).
        c = [16, 32, 64, 128]
        self.encoder = nn.Sequential(
            conv(in_channels, c[0]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[0], c[1]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[1], c[2]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[2], c[3]),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.decoder = nn.Sequential(
            conv(c[3], c[2]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            conv(c[2], c[1]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            conv(c[1], c[0]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            # Output projection deliberately has no BatchNorm/ReLU: a ReLU
            # here would make every logit non-negative, which pins the
            # sigmoid output to [0.5, 1) and makes values below 0.5
            # unreachable.
            nn.Conv2d(c[0], out_channels, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
        )

    def forward(self, x):
        """Run the encoder and decoder and squash the result with a sigmoid.

        Args:
            x: Batched image tensor of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, H', W')`` with values in
            (0, 1), where ``H' = 16 * (H // 16)`` and likewise for ``W'``;
            the output matches the input size only when ``H`` and ``W``
            are multiples of 16.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        # Functional sigmoid avoids constructing a new module on every call.
        return torch.sigmoid(x)

# Instantiate the network with single-channel input and output, spelled out
# explicitly (these are the same values as the constructor defaults).
model = UNet(in_channels=1, out_channels=1)