<a href="https://colab.research.google.com/github/nguyenanhtienabcd/AIO2024_EXERCISE/blob/feature%2FMODULE7-WEEK1/m07w01_ex1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

Tạo feature map ban đầu từ ảnh đầu vào bằng một lớp tích chập 1×1 (không dùng bias), theo sau là hàm kích hoạt LeakyReLU.

In [None]:
class FirstFeature(nn.Module):
    """Entry layer: a bias-free 1x1 convolution followed by LeakyReLU.

    Args:
        input_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, input_channels, out_channels):
        super().__init__()
        # kernel 1 / stride 1 / padding 0 keeps the spatial size unchanged.
        pointwise = nn.Conv2d(
            input_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        activation = nn.LeakyReLU()
        self.conv = nn.Sequential(pointwise, activation)

    def forward(self, x):
        """Run ``x`` through the convolution and the activation."""
        features = self.conv(x)
        return features

In [None]:
class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution stages, each with BatchNorm and LeakyReLU.

    Both convolutions use padding 1 and stride 1, so the spatial size of the
    input is preserved; only the channel count changes.

    Args:
        input_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by each convolution.
    """

    def __init__(self, input_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=input_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      padding=1,
                      stride=1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            # The second convolution consumes the output of the first one,
            # so its input width is out_channels, not input_channels.
            nn.Conv2d(in_channels=out_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      padding=1,
                      stride=1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        return self.conv(x)


In [None]:
class Encoder(nn.Module):
    """Downsampling stage: 2x2 max pooling followed by a ``ConvBlock``.

    Args:
        input_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the ``ConvBlock``.
    """

    def __init__(self, input_channels, out_channels):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(
            # The class is spelled MaxPool2d; ``nn.Maxpool2d`` does not exist
            # and raised AttributeError on construction.
            nn.MaxPool2d(kernel_size=2),
            ConvBlock(input_channels, out_channels),
        )

    def forward(self, x):
        """Pool ``x`` and run the pooled tensor through the ``ConvBlock``."""
        return self.conv(x)

In [None]:
class Decoder(nn.Module):
    """Upsampling stage that merges a skip connection.

    The input is bilinearly upsampled by a factor of 2 and reduced from
    ``input_channels`` to ``out_channels`` with a 1x1 convolution
    (+ BatchNorm + LeakyReLU). The result is concatenated with ``skip`` along
    the channel dimension and passed through a
    ``ConvBlock(input_channels, out_channels)``.

    Args:
        input_channels: Number of channels of the incoming tensor; also the
            channel count the ``ConvBlock`` expects after concatenation.
        out_channels: Number of channels of the returned tensor.
    """

    def __init__(self, input_channels, out_channels):
        super(Decoder, self).__init__()
        self.conv = nn.Sequential(
            # The first positional parameter of UpsamplingBilinear2d is
            # ``size``; passing 2 positionally forced a fixed 2x2 output.
            # The scale factor has to be given by keyword.
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(in_channels=input_channels,
                      out_channels=out_channels,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        self.conv_block = ConvBlock(input_channels, out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate it with ``skip`` and refine the result.

        ``skip`` must match the upsampled ``x`` in every dimension except the
        channel one, and must carry ``input_channels - out_channels`` channels
        so that the concatenation has ``input_channels`` channels.
        """
        x = self.conv(x)
        x = torch.cat((x, skip), dim=1)
        x = self.conv_block(x)
        return x

In [None]:
class FinalOutput(nn.Module):
    """Output head: a single bias-free 1x1 convolution with no activation.

    Args:
        input_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels of the returned tensor.
    """

    def __init__(self, input_channels, out_channels):
        super().__init__()
        projection = nn.Conv2d(
            in_channels=input_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Kept inside a Sequential so the parameter stays at ``conv.0.weight``.
        self.conv = nn.Sequential(projection)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels."""
        projected = self.conv(x)
        return projected

In [None]:
class Unet(nn.Module):
    """U-Net assembled from the blocks defined in this file.

    A ``FirstFeature`` layer is followed by four ``Encoder`` stages
    (64 -> 128 -> 256 -> 512 -> 1024 channels), four ``Decoder`` stages that
    each receive the matching encoder-side output as a skip connection
    (1024 -> 512 -> 256 -> 128 -> 64 channels), and a ``FinalOutput`` head.

    Args:
        n_channels: Number of channels of the input tensor.
        n_classes: Number of channels of the output tensor.
    """

    def __init__(self, n_channels=3, n_classes=3):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.first_feature = FirstFeature(n_channels, 64)

        # Contracting path.
        self.encoder1 = Encoder(64, 128)
        self.encoder2 = Encoder(128, 256)
        self.encoder3 = Encoder(256, 512)
        self.encoder4 = Encoder(512, 1024)

        # Expanding path.
        self.decoder1 = Decoder(1024, 512)
        self.decoder2 = Decoder(512, 256)
        self.decoder3 = Decoder(256, 128)
        self.decoder4 = Decoder(128, 64)

        self.final_output = FinalOutput(64, n_classes)

    def forward(self, x):
        """Run ``x`` through the encoder, the decoder and the output head."""
        # Every intermediate result is pushed on a stack so the decoder can
        # consume them in reverse order as skip connections.
        skips = [self.first_feature(x)]
        for encoder in (self.encoder1, self.encoder2,
                        self.encoder3, self.encoder4):
            skips.append(encoder(skips[-1]))

        out = skips.pop()  # deepest encoder output
        for decoder in (self.decoder1, self.decoder2,
                        self.decoder3, self.decoder4):
            out = decoder(out, skips.pop())

        return self.final_output(out)
