# U-Net

## Импорт библиотек

In [2]:
import torch
import torch.nn as nn
import gc

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Реализация U-Net

In [15]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by a ReLU.

    ``padding=1`` keeps the spatial size of the input unchanged.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), padding=1)
        # One parameter-free ReLU module is reused after both convolutions.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU and return the activated features."""
        x = self.act(self.conv1(x))
        # Return the activation's output explicitly so the result does not
        # depend on the ReLU being in-place (and mutating ``x`` as a side effect).
        return self.act(self.conv2(x))

In [16]:
class DownSamle(nn.Module):
    """Encoder stage of U-Net: a DoubleConv followed by 2x2 max-pooling.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = DoubleConv(in_channels, out_channels)
        self.maxpooling = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

    def forward(self, x):
        """Return a ``(skip, pooled)`` pair.

        ``skip`` is the DoubleConv output, kept for the matching decoder stage;
        ``pooled`` is that output max-pooled by 2, fed to the next encoder stage.
        """
        features = self.double_conv(x)
        pooled = self.maxpooling(features)
        return features, pooled

In [17]:
class UpSample(nn.Module):
    """Decoder stage of U-Net: upsample, concatenate a skip connection, DoubleConv.

    Args:
        in_channels: channels of the incoming (coarser) feature map. The
            concatenation of the upsampled map and the skip connection is fed
            to a DoubleConv expecting this many channels, so the skip tensor
            must carry ``in_channels - out_channels`` channels.
        out_channels: channels produced by the transposed convolution and by
            the final DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # 2x2 transposed convolution with stride 2 doubles height and width.
        self.convt = nn.ConvTranspose2d(in_channels, out_channels, (2, 2), stride=2)
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, prev_sample, downsample):
        """Upsample ``prev_sample``, join it with ``downsample`` and refine.

        Args:
            prev_sample: feature map from the previous (deeper) stage.
            downsample: skip-connection feature map from the encoder.
            Both are expected to be batched NCHW tensors (concatenation is
            along dim 1 and padding acts on dims 2 and 3).

        Returns:
            The DoubleConv output with ``out_channels`` channels and the
            spatial size of ``downsample``.
        """
        x = self.convt(prev_sample)
        # If an encoder input had an odd height/width, max-pooling floored it,
        # so the upsampled map is smaller than the skip tensor and the concat
        # below would fail. Zero-pad (split as evenly as possible) to match.
        diff_h = downsample.size(2) - x.size(2)
        diff_w = downsample.size(3) - x.size(3)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.concat([x, downsample], dim=1)
        return self.double_conv(x)

In [20]:
class U_Net(nn.Module):
    """U-Net segmentation network.

    Four encoder stages (64-128-256-512 channels), a 1024-channel bottleneck,
    four decoder stages with skip connections, and a 1x1 output convolution.

    Args:
        in_channels: number of channels of the input image.
        num_classes: number of output channels. The raw 1x1-convolution output
            is returned; no final activation is applied.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()

        self.layer_down_1 = DownSamle(in_channels, 64)
        self.layer_down_2 = DownSamle(64, 128)
        self.layer_down_3 = DownSamle(128, 256)
        self.layer_down_4 = DownSamle(256, 512)

        # Bottleneck: two 3x3 convolutions, each followed by a ReLU. Without
        # the activation the two convolutions compose into a single linear map.
        self.conv1 = nn.Conv2d(512, 1024, (3, 3), padding=1)
        self.conv2 = nn.Conv2d(1024, 1024, (3, 3), padding=1)
        # Parameter-free, so adding it leaves the state_dict keys unchanged.
        self.act = nn.ReLU(inplace=True)

        self.layer_up_4 = UpSample(1024, 512)
        self.layer_up_3 = UpSample(512, 256)
        self.layer_up_2 = UpSample(256, 128)
        self.layer_up_1 = UpSample(128, 64)

        self.conv_exit = nn.Conv2d(64, num_classes, (1, 1))

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return per-pixel outputs."""
        # Encoder: each stage returns (skip features, pooled features).
        to_lu1, ld1 = self.layer_down_1(x)
        to_lu2, ld2 = self.layer_down_2(ld1)
        to_lu3, ld3 = self.layer_down_3(ld2)
        to_lu4, ld4 = self.layer_down_4(ld3)

        c1 = self.act(self.conv1(ld4))
        c2 = self.act(self.conv2(c1))

        # Decoder: upsample and merge with the matching encoder skip.
        lu4 = self.layer_up_4(c2, to_lu4)
        lu3 = self.layer_up_3(lu4, to_lu3)
        lu2 = self.layer_up_2(lu3, to_lu2)
        lu1 = self.layer_up_1(lu2, to_lu1)

        return self.conv_exit(lu1)

In [25]:
# Build a U-Net for single-channel input with 5 output channels and place it
# on the selected device (Module.to returns the module itself).
unet_model = U_Net(in_channels=1, num_classes=5)
unet_model = unet_model.to(device)
unet_model

U_Net(
  (layer_down_1): DownSamle(
    (double_conv): DoubleConv(
      (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
    (maxpooling): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer_down_2): DownSamle(
    (double_conv): DoubleConv(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
    (maxpooling): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer_down_3): DownSamle(
    (double_conv): DoubleConv(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
   

In [26]:
# Shape smoke test: push one random 1x1x512x512 input through the model.
# Autograd is disabled because only the output shape is inspected; building
# the backward graph for a full-resolution U-Net pass would waste memory.
with torch.no_grad():
    inp = torch.rand([1, 1, 512, 512], dtype=torch.float32, device=device)
    # The output is produced on the model's device; no extra .to() is needed.
    pred = unet_model(inp)
pred.shape

torch.Size([1, 5, 512, 512])