In [1]:
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from load_data.prepared_custom_ds import CustomDataset
from utilities.config_load import load_config

In [2]:
# Directory containing the project's YAML configuration files.
# NOTE(review): relative path -- presumably resolved against the notebook's working directory; confirm in load_config.
CONFIG_PATH = "configs/"
# Read "config.yaml" through the project helper; the notebook displays the result as a nested dict in the next cell.
config = load_config(CONFIG_PATH, "config.yaml")

In [3]:
# Display the loaded configuration (dataset paths, image sizes, project path, model entry).
# NOTE(review): the displayed 'project_path' is an absolute, machine-specific path -- consider making it relative/configurable.
config

{'dataset': {'train_img_path': 'dataset/train_v2',
  'test_img_path': 'dataset/test_v2',
  'mask_path': 'dataset/train_masks',
  'reshaped_img_path': 'dataset/reshaped_img',
  'dir_path': 'dataset'},
 'original_img_size': 768,
 'new_img_size': 256,
 'project_path': 'C:/Users/da4nik/Segmentation',
 'model': None}

In [None]:
def ConvBlock(first_chanels, second_chanels, kernel_size, dropout_rate):
    """Assemble the double-convolution unit used by the U-Net stages.

    The returned container holds, in order:
    BatchNorm2d -> Conv2d -> ReLU -> Dropout -> BatchNorm2d -> Conv2d -> ReLU.
    Both convolutions are created with ``padding='same'``.

    Args:
        first_chanels: Number of channels of the incoming feature map.
        second_chanels: Number of channels produced by both convolutions.
        kernel_size: Kernel size forwarded to both ``nn.Conv2d`` layers.
        dropout_rate: Drop probability of the ``nn.Dropout`` between the two stages.

    Returns:
        nn.Sequential: The seven layers listed above.
    """

    def _norm_conv_relu(in_channels):
        # One "normalise -> convolve -> activate" stage; always emits second_chanels channels.
        return [
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, second_chanels, kernel_size, padding='same'),
            nn.ReLU(inplace=True),
        ]

    layers = _norm_conv_relu(first_chanels)
    layers.append(nn.Dropout(dropout_rate))
    layers.extend(_norm_conv_relu(second_chanels))
    return nn.Sequential(*layers)

In [None]:
class Unet_Encoder(nn.Module):
    """Contracting path of the U-Net: four ConvBlock stages, each followed by 2x2 max pooling.

    Channel widths grow as 3 -> nkernels -> 2*nkernels -> 4*nkernels -> 8*nkernels.

    Args:
        kernel_size: Kernel size passed to every ConvBlock.
        dropout_rate: Dropout probability passed to every ConvBlock.
        nkernels: Number of output channels of the first stage.
    """

    def __init__(self, kernel_size, dropout_rate, nkernels):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.nkernels = nkernels
        self.conv1 = ConvBlock(3, nkernels, self.kernel_size, self.dropout_rate)
        self.conv2 = ConvBlock(nkernels, nkernels*2, self.kernel_size, self.dropout_rate)
        self.conv3 = ConvBlock(nkernels*2, nkernels*4, self.kernel_size, self.dropout_rate)
        self.conv4 = ConvBlock(nkernels*4, nkernels*8, self.kernel_size, self.dropout_rate)
        self.maxpool_list = nn.ModuleList([nn.MaxPool2d(kernel_size=2) for _ in range(4)])
        # conv_list holds the same module objects as conv1..conv4 (no extra parameters);
        # it only exists so forward() can iterate over the stages.
        self.conv_list = nn.ModuleList([self.conv1, self.conv2, self.conv3, self.conv4])

    def init_weights(self):
        """Xavier-uniform initialise Linear/Conv2d weights and set their biases to 0.01.

        Not called automatically by ``__init__``; invoke it explicitly after construction.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.01)

    def forward(self, input):
        """Run the four encoder stages.

        Args:
            input: Image batch with 3 channels, laid out as (N, 3, H, W).

        Returns:
            tuple: ``(list_skips, out)`` where ``list_skips`` contains the four
            pre-pooling feature maps ordered from the shallowest (full resolution)
            to the deepest stage, and ``out`` is the result of the last pooling
            (8*nkernels channels, H and W each halved four times).
        """
        list_skips = []
        out = input
        for conv, pool in zip(self.conv_list, self.maxpool_list):
            skip = conv(out)
            list_skips.append(skip)
            out = pool(skip)
        return list_skips, out

In [None]:
class Unet_Decoder(nn.Module):
    """Bottleneck plus expanding path of the U-Net.

    Four stages, each a ConvBlock followed by a stride-2 transposed convolution
    that doubles the spatial size and halves the channel count. From the second
    stage on, the running feature map is first concatenated (along the channel
    axis) with the matching encoder skip.

    Args:
        kernel_size: Kernel size passed to every ConvBlock.
        dropout_rate: Dropout probability passed to every ConvBlock.
        nkernels: Base channel width; must match the encoder that produced the inputs.
    """

    def __init__(self, kernel_size, dropout_rate, nkernels):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.nkernels = nkernels
        self.conv5 = ConvBlock(nkernels*8, nkernels*16, self.kernel_size, self.dropout_rate)
        self.conv6 = ConvBlock(nkernels*16, nkernels*8, self.kernel_size, self.dropout_rate)
        self.conv7 = ConvBlock(nkernels*8, nkernels*4, self.kernel_size, self.dropout_rate)
        self.conv8 = ConvBlock(nkernels*4, nkernels*2, self.kernel_size, self.dropout_rate)
        # Same module objects as conv5..conv8, collected for iteration in forward().
        self.conv_list = nn.ModuleList([self.conv5, self.conv6, self.conv7, self.conv8])
        # Up-sampling layers: channel multipliers 16->8, 8->4, 4->2, 2->1 (times nkernels).
        self.convt_list = nn.ModuleList([
            nn.ConvTranspose2d(
                nkernels * 2 ** (4 - i),
                nkernels * 2 ** (3 - i),
                kernel_size=(2, 2),
                stride=(2, 2),
            )
            for i in range(4)
        ])

    def init_weights(self):
        """Xavier-uniform initialise Linear/Conv2d/ConvTranspose2d weights; biases are set to 0.01.

        Not called automatically by ``__init__``; invoke it explicitly after construction.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.01)

    def forward(self, input, list_skips):
        """Decode the encoder output back towards the input resolution.

        Args:
            input: Deepest encoder feature map, (N, 8*nkernels, h, w).
            list_skips: Encoder skips ordered shallow -> deep, as returned by
                ``Unet_Encoder.forward``. Entries 3, 2 and 1 are consumed (in that
                order); entry 0 is not used by this module.

        Returns:
            Tensor with ``nkernels`` channels and 16x the spatial size of ``input``.
        """
        out = input
        for i, (conv, upconv) in enumerate(zip(self.conv_list, self.convt_list)):
            if i > 0:
                # Skips are joined on the channel axis (dim=1). Concatenating on dim=0
                # would stack them along the batch axis and leave the channel count
                # at half of what the following ConvBlock expects.
                out = torch.cat((out, list_skips[4 - i]), dim=1)
            out = upconv(conv(out))
        return out

In [11]:
# Sanity check of the channel multipliers used for the transposed convolutions:
# stage i goes from 2**(4 - i) to half of that, i.e. 2**(3 - i).
for i in range(4):
    print(2 ** (4 - i), "=>", 2 ** (3 - i))

16 => 8
8 => 4
4 => 2
2 => 1


In [None]:
class Model_Unet(nn.Module):
    """Single-module U-Net: four down-sampling stages, a bottleneck, four up-sampling stages.

    Channel plan (n = nkernels):
        encoder    3 -> n -> 2n -> 4n -> 8n, 2x2 max pooling after each stage
        bottleneck 8n -> 16n
        decoder    each step up-samples with a stride-2 transposed convolution
                   (halving the channels), concatenates the matching encoder skip
                   on the channel axis and applies a ConvBlock: 16n -> 8n -> 4n -> 2n -> n
        head       1x1 convolution n -> out_channels

    Args:
        num_layers: Accepted for interface compatibility; the depth is fixed at
            four pooling levels and this value is only stored.
        kernel_size: Kernel size passed to every ConvBlock.
        dropout_rate: Dropout probability passed to every ConvBlock.
        nkernels: Number of channels produced by the first stage.
        out_channels: Number of channels produced by the output head (default 1).
    """

    def __init__(self, num_layers, kernel_size, dropout_rate, nkernels, out_channels=1):
        super().__init__()
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.nkernels = nkernels
        # Contracting path + bottleneck.
        self.conv1 = ConvBlock(3, nkernels, self.kernel_size, self.dropout_rate)
        self.conv2 = ConvBlock(nkernels, nkernels*2, self.kernel_size, self.dropout_rate)
        self.conv3 = ConvBlock(nkernels*2, nkernels*4, self.kernel_size, self.dropout_rate)
        self.conv4 = ConvBlock(nkernels*4, nkernels*8, self.kernel_size, self.dropout_rate)
        self.conv5 = ConvBlock(nkernels*8, nkernels*16, self.kernel_size, self.dropout_rate)
        self.maxpool_list = nn.ModuleList([nn.MaxPool2d(kernel_size=2) for _ in range(4)])
        # Up-sampling layers: channel multipliers 16->8, 8->4, 4->2, 2->1 (times nkernels).
        self.convt_list = nn.ModuleList([
            nn.ConvTranspose2d(
                nkernels * 2 ** (4 - i),
                nkernels * 2 ** (3 - i),
                kernel_size=(2, 2),
                stride=(2, 2),
            )
            for i in range(4)
        ])
        # Expanding path: each block consumes "up-sampled + skip" channels.
        self.conv6 = ConvBlock(nkernels*16, nkernels*8, self.kernel_size, self.dropout_rate)
        self.conv7 = ConvBlock(nkernels*8, nkernels*4, self.kernel_size, self.dropout_rate)
        self.conv8 = ConvBlock(nkernels*4, nkernels*2, self.kernel_size, self.dropout_rate)
        self.conv9 = ConvBlock(nkernels*2, nkernels, self.kernel_size, self.dropout_rate)
        # NOTE(review): the draft referenced an undefined `Last(...)` here. A 1x1 convolution
        # emitting raw logits (no sigmoid/softmax) is assumed -- confirm against the loss
        # function used for training.
        self.conv10 = nn.Conv2d(nkernels, out_channels, kernel_size=1)

    def init_weights(self):
        """Xavier-uniform initialise Linear/Conv2d/ConvTranspose2d weights; biases are set to 0.01.

        Not called automatically by ``__init__``; invoke it explicitly after construction.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.01)

    def forward(self, input):
        """Map an image batch to per-pixel logits.

        Args:
            input: Tensor laid out as (N, 3, H, W); H and W must be multiples of 16.

        Returns:
            Tensor of shape (N, out_channels, H, W) holding raw (un-activated) scores.

        Raises:
            ValueError: If H or W is not divisible by 16.
        """
        height, width = input.shape[-2:]
        if height % 16 or width % 16:
            # Four 2x poolings floor odd sizes, so the up-sampled maps would no longer
            # line up with the skips and torch.cat would fail deep inside the network.
            raise ValueError(
                f"input height and width must be multiples of 16, got {height}x{width}"
            )

        skip1 = self.conv1(input)
        skip2 = self.conv2(self.maxpool_list[0](skip1))
        skip3 = self.conv3(self.maxpool_list[1](skip2))
        skip4 = self.conv4(self.maxpool_list[2](skip3))
        out = self.conv5(self.maxpool_list[3](skip4))

        decoder_convs = (self.conv6, self.conv7, self.conv8, self.conv9)
        encoder_skips = (skip4, skip3, skip2, skip1)  # deepest skip is merged first
        for upconv, conv, skip in zip(self.convt_list, decoder_convs, encoder_skips):
            # Merge on the channel axis (dim=1); dim=0 would stack along the batch axis.
            out = conv(torch.cat((upconv(out), skip), dim=1))
        return self.conv10(out)

In [3]:
# Scratch 2x2 max-pooling layer used below to check output shapes.
# NOTE(review): the name `max` shadows Python's builtin max() for the rest of the kernel session;
# a later cell calls it as `max(...)`, so any rename has to be applied to both cells together.
max = nn.MaxPool2d(2)

In [4]:
# Scratch transposed convolution (16 -> 8 channels) with the same 2x2 kernel / stride 2
# as the decoder's up-sampling layers, used below to check output shapes.
ct = nn.ConvTranspose2d(16, 16//2, kernel_size=(2, 2), stride=(2, 2))

In [10]:
# Apply the pooling layer (the module bound to `max` in an earlier cell, not the builtin)
# to an all-zero tensor of shape (16, 200, 200).
res = max(torch.zeros(16, 200, 200))

In [11]:
# Pooling halves the last two dimensions: the recorded output is torch.Size([16, 100, 100]).
res.shape

torch.Size([16, 100, 100])

In [5]:
# Apply the transposed convolution `ct` to an all-zero tensor of shape (16, 200, 200).
test = ct(torch.zeros(16, 200, 200))

In [6]:
# Channels are halved and the last two dimensions doubled: the recorded output is torch.Size([8, 400, 400]).
test.shape

torch.Size([8, 400, 400])