In [1]:
# for processing image data we use fastai.vision
from fastai.vision.all import * 
# Training should run on the GPU, so first check whether a CUDA device is
# available. Check explicitly instead of assuming one:
# torch.cuda.get_device_name(0) raises when no CUDA device is present, which
# would break "Restart & Run All" on a CPU-only machine.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device available - falling back to CPU")
print(torch.__version__)

NVIDIA GeForce RTX 3070
2.3.0+cu121


In [2]:
# Fix the random seeds so runs are pseudo-random but reproducible.
# Every RNG the pipeline may draw from is seeded, not just PyTorch's.
SEED = 0
torch.manual_seed(SEED)  # PyTorch (seeds the CPU and all CUDA devices)
random.seed(SEED)        # Python's built-in `random` module
np.random.seed(SEED)     # NumPy's global (legacy) RNG

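The model architecture below is based on U-Net (Ronneberger et al., 2015, "U-Net: Convolutional Networks for Biomedical Image Segmentation"): https://arxiv.org/pdf/1505.04597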

In [87]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, bottleneck_channels: int | None = None) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding="same")
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding="same")
        self.relu = nn.ReLU()

        self.seq_stack = nn.Sequential(
            self.conv1,
            self.relu,
            self.conv2,
            self.relu,
        )

        if bottleneck_channels is not None:
            self.seq_stack.append(nn.Conv2d(in_channels=out_channels, out_channels=bottleneck_channels, kernel_size=1))
            self.seq_stack.append(self.relu)

    def forward(self, x):
        x = self.seq_stack(x)
        return x

class CustomUnet(nn.Module):
    """U-Net-style encoder/decoder producing one score map per class.

    Loosely follows U-Net (arXiv:1505.04597), with these differences:

    * All 3x3 convolutions use ``padding="same"``. Every level therefore keeps
      its spatial size, and skip connections are concatenated without cropping.
    * Upsampling uses ``nn.MaxUnpool2d`` with the indices recorded by the
      matching encoder ``nn.MaxPool2d`` step, instead of learned
      up-convolutions.
    * The channel reduction before each concatenation is done by the optional
      1x1 convolution (``bottleneck_channels``) at the end of ``DoubleConv``.

    Args:
        in_channels: Number of channels of the input image.
        out_classes: Number of output channels (one per class). No activation
            is applied to the output.

    Shapes:
        An input of shape ``(N, in_channels, H, W)`` yields an output of shape
        ``(N, out_classes, H, W)``. ``H`` and ``W`` may be odd, but must be at
        least 16 so the four 2x2 poolings leave a non-empty feature map.
    """

    def __init__(self, in_channels, out_classes) -> None:
        super().__init__()
        # Encoder: channels double at each level. The deepest block reduces
        # back to 512 so it matches the 512-channel skip tensor x4.
        self.double_conv_down_1 = DoubleConv(in_channels, 64)
        self.double_conv_down_2 = DoubleConv(64, 128)
        self.double_conv_down_3 = DoubleConv(128, 256)
        self.double_conv_down_4 = DoubleConv(256, 512)
        self.double_conv_down_5 = DoubleConv(512, 1024, bottleneck_channels=512)

        # Decoder: each input is the skip tensor (C channels) concatenated with
        # the upsampled tensor (C channels). The 1x1 bottleneck halves the
        # result so it matches the next, shallower skip tensor.
        self.double_conv_up_4 = DoubleConv(1024, 512, bottleneck_channels=256)
        self.double_conv_up_3 = DoubleConv(512, 256, bottleneck_channels=128)
        self.double_conv_up_2 = DoubleConv(256, 128, bottleneck_channels=64)
        self.double_conv_up_1 = DoubleConv(128, 64)

        # Final 1x1 convolution mapping the 64 feature channels to per-class
        # scores. (The attribute name is kept for state_dict compatibility.)
        self.bottleneck = nn.Conv2d(in_channels=64, out_channels=out_classes, kernel_size=1)

        # Parameter-free modules shared by all levels. Pooling returns its
        # argmax indices so unpooling can put values back at those positions.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

    def _up(self, x, indices, skip, block):
        """Unpool ``x`` to ``skip``'s spatial size, prepend ``skip`` along the
        channel axis, and refine the result with ``block``."""
        # output_size is required for odd spatial sizes. By default
        # MaxUnpool2d returns 2 * floor(s / 2) pixels, which is one pixel
        # short of the skip tensor and would make torch.concat fail.
        x = self.unpool(x, indices, output_size=skip.shape[-2:])
        return block(torch.concat([skip, x], dim=1))

    def forward(self, x):
        """Run the encoder/decoder; see the class docstring for shapes."""
        # Encoder: keep each level's output (skip connection) and the pooling
        # indices used to downsample it.
        x1 = self.double_conv_down_1(x)

        x2, indices_x2 = self.pool(x1)
        x2 = self.double_conv_down_2(x2)

        x3, indices_x3 = self.pool(x2)
        x3 = self.double_conv_down_3(x3)

        x4, indices_x4 = self.pool(x3)
        x4 = self.double_conv_down_4(x4)

        x5, indices_x5 = self.pool(x4)
        x5 = self.double_conv_down_5(x5)

        # Decoder: walk back up, unpooling with the indices recorded when the
        # corresponding skip tensor was pooled.
        x4 = self._up(x5, indices_x5, x4, self.double_conv_up_4)
        x3 = self._up(x4, indices_x4, x3, self.double_conv_up_3)
        x2 = self._up(x3, indices_x3, x2, self.double_conv_up_2)
        x1 = self._up(x2, indices_x2, x1, self.double_conv_up_1)

        return self.bottleneck(x1)

In [88]:
# Smoke test: run a random batch through the untrained network and check
# that the output keeps the input's spatial size, with one channel per class.
net = CustomUnet(3, 2)
x = torch.rand((4, 3, 224, 224))
pred = net(x)
assert pred.shape == (4, 2, 224, 224), f"unexpected output shape {tuple(pred.shape)}"
print(pred.shape)

torch.Size([4, 2, 224, 224])
