<a href="https://colab.research.google.com/github/mishabar410/ML-2022/blob/main/TransAttUnet/Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install deeplake
import deeplake

In [None]:
# https://datasets.activeloop.ai/docs/ml/datasets/nih-chest-x-ray-dataset/
# https://datasets.activeloop.ai/docs/ml/datasets/chest-x-ray-image-dataset/
# https://datasets.activeloop.ai/docs/ml/datasets/glas-dataset/
# https://www.kaggle.com/datasets/nodoubttome/skin-cancer9-classesisic

In [21]:
# Install the Kaggle CLI (%pip targets the running kernel's environment).
%pip install -q kaggle
from google.colab import files

# Upload kaggle.json, then install it where the Kaggle CLI reads it.
files.upload()
# -p: do not fail when the directory already exists (e.g. on a re-run).
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Restrict the token file to the owner.
!chmod 600 ~/.kaggle/kaggle.json

Saving kaggle.json to kaggle.json


In [22]:
from google.colab import files

# Bind the result instead of leaving `files.upload()` as the cell's last
# expression. Jupyter would otherwise display the returned dict
# (file name -> contents) in the saved output. That would leak secrets such as
# the kaggle.json API token if it were uploaded here.
uploaded_files = files.upload()

{}

In [23]:
# Download the Kaggle dataset `nodoubttome/skin-cancer9-classesisic` and unzip it
# into /content/sample_data/. Relies on ~/.kaggle/kaggle.json being installed by
# the credentials cell above.
!kaggle datasets download nodoubttome/skin-cancer9-classesisic -p /content/sample_data/ --unzip

Downloading skin-cancer9-classesisic.zip to /content/sample_data
 97% 761M/786M [00:04<00:00, 148MB/s]
100% 786M/786M [00:04<00:00, 166MB/s]


In [None]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3) -> BatchNorm2d -> ReLU`` stages.

    ``padding=1`` keeps the spatial size unchanged. This lets encoder feature
    maps be concatenated with same-sized decoder feature maps in the U-Net.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bnorm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to ``(B, out_channels, H, W)``."""
        hidden = self.relu(self.bnorm1(self.conv1(x)))
        # The second stage must consume the first stage's output. Feeding the raw
        # input fails whenever in_channels != out_channels.
        return self.relu(self.bnorm2(self.conv2(hidden)))

class Down(nn.Module):
    """Encoder step: halve the spatial size with max-pooling, then apply ``DoubleConv``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        # A 2x2 window with stride 2 pools to exactly (H // 2, W // 2), which the
        # decoder's x2 upsampling can invert. A 3x3/stride-2 window would give
        # (H - 3) // 2 + 1 instead.
        self.mpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Max-pool ``x`` to ``(H // 2, W // 2)``, then apply ``DoubleConv``."""
        return self.conv(self.mpool(x))

class Up(nn.Module):
    """Decoder step: upsample by 2, concatenate with a skip tensor, apply ``DoubleConv``.

    Args:
        in_channels: channels of the concatenation, i.e. channels of ``x1`` plus
            channels of ``x2`` in ``forward``.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        # nn.Module must be initialised before sub-modules can be registered.
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        # align_corners=True matches the F.interpolate calls used elsewhere in this notebook.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it spatially with ``x2`` and fuse them.

        Args:
            x1: lower-resolution feature map ``(B, C1, h, w)``.
            x2: skip feature map ``(B, C2, H, W)``; ``C1 + C2`` must equal
                ``in_channels``.

        Returns:
            ``DoubleConv`` applied to ``cat([upsample(x1), x2], dim=1)``.
        """
        x1 = self.upsample(x1)
        # Odd encoder sizes are floored by pooling, so the upsampled map can be
        # smaller than the skip tensor. Zero-pad so torch.cat sees equal sizes.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                            diff_y // 2, diff_y - diff_y // 2])
        x1 = torch.cat([x1, x2], 1)
        return self.conv(x1)

In [None]:
class GSA(nn.Module):
    """Global spatial attention with a learnable residual gate.

    For an input ``x`` of shape ``(B, C, H, W)`` every spatial position attends
    to every other position:

        M = M_conv(x) -> (B, HW, C')     N = N_conv(x) -> (B, C', HW)
        A = softmax(M @ N, dim=-1)        -> (B, HW, HW)
        out = W_conv(x) @ A^T             -> (B, C, HW) -> (B, C, H, W)
        return gamma * out + x

    ``gamma`` starts at zero, so the module is the identity at initialisation.

    Args:
        in_channels: ``C``, channels of the input feature map.
        reduction: channel reduction of the M/N projections,
            ``C' = max(1, C // reduction)`` (default 8).
    """

    def __init__(self, in_channels, reduction=8):
        super().__init__()
        reduced = max(1, in_channels // reduction)
        # 1x1 projections: they mix channels only, never neighbouring pixels.
        self.M_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.N_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.W_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        batch_size, channels, height, width = x.size()
        M = self.M_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, HW, C')
        N = self.N_conv(x).view(batch_size, -1, height * width)                   # (B, C', HW)
        W = self.W_conv(x).view(batch_size, -1, height * width)                   # (B, C, HW)
        # Position-to-position affinities, normalised over the attended positions.
        B = F.softmax(torch.bmm(M, N), dim=-1)                                    # (B, HW, HW)
        result = torch.bmm(W, B.permute(0, 2, 1)).view(batch_size, channels, height, width)

        return self.gamma * result + x

In [None]:
class Position_Encoding(nn.Module):
    """Learned 2-D positional encoding built from separate row and column tables.

    For an input of shape ``(B, *, H, W)`` the output has shape
    ``(B, 2 * num_pos_feats, H, W)``. At position ``(h, w)``:

    * channels ``[:num_pos_feats]`` hold ``col_embed[w]``;
    * channels ``[num_pos_feats:]`` hold ``row_embed[h]``.

    Args:
        num_pos_feats: length of each embedding vector.
        len_embedding: rows in each table. The input's ``H`` and ``W`` must not
            exceed it.
    """

    def __init__(self, num_pos_feats=256, len_embedding=32):
        super().__init__()
        self.row_embed = nn.Embedding(len_embedding, num_pos_feats)
        self.col_embed = nn.Embedding(len_embedding, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw both embedding tables from U(0, 1)."""
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list):
        """Return the encoding for ``tensor_list``.

        Only the shape and device of ``tensor_list`` are used.

        Raises:
            ValueError: if ``H`` or ``W`` exceeds the embedding table length.
        """
        x = tensor_list
        h, w = x.shape[-2:]
        if h > self.row_embed.num_embeddings or w > self.col_embed.num_embeddings:
            raise ValueError(
                f"spatial size {h}x{w} exceeds the positional table length "
                f"{self.row_embed.num_embeddings}"
            )
        cols = torch.arange(w, device=x.device)
        rows = torch.arange(h, device=x.device)

        x_emb = self.col_embed(cols)  # (W, F)
        y_emb = self.row_embed(rows)  # (H, F)

        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),  # (H, W, F), varies along W
            y_emb.unsqueeze(1).repeat(1, w, 1),  # (H, W, F), varies along H
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)

        return pos

class Scalar_dot_product_attention(nn.Module):
    """Parameter-free scaled dot-product self-attention over spatial positions.

    The ``H * W`` positions of a ``(B, C, H, W)`` map are the tokens, each with a
    ``C``-dimensional feature vector, and ``Q = K = V``. The attention weights
    are ``softmax(Q K^T / sqrt(coef))`` over key positions, followed by dropout.
    The output has the same shape as the input.

    Args:
        coef: dimension used for the ``sqrt`` scaling. Pass the channel count
            ``C``, which is the dimension the dot products are taken over.
        dropout: dropout probability on the attention weights (default 0.1).
    """

    def __init__(self, coef, dropout=0.1):
        super().__init__()
        self.d_k = coef ** 0.5  # sqrt(d_k): the softmax temperature
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        # reshape (not view) so non-contiguous inputs are accepted too.
        tokens = x.reshape(batch_size, channels, height * width).transpose(1, 2)  # (B, HW, C)
        q = tokens
        k = tokens.transpose(1, 2)  # (B, C, HW)
        v = tokens
        # (B, HW, HW): the dot products run over the C channels, matching sqrt(coef).
        attention = F.softmax(torch.matmul(q / self.d_k, k), -1)
        attention = self.dropout(attention)
        out = torch.matmul(attention, v)  # (B, HW, C)
        return out.transpose(1, 2).reshape(batch_size, channels, height, width)

In [None]:
class Unet(nn.Module):
    """TransAttUNet-style U-Net with bottleneck attention and multi-scale decoder fusion.

    Encoder:
        ``first_conv`` plus four ``Down`` stages, with
        64 -> 128 -> 256 -> 512 -> 512 channels.

    Bottleneck:
        ``GSA`` on the encoder output, summed with scaled dot-product
        self-attention. The self-attention is applied to the encoder output plus
        a learned positional encoding.

    Decoder:
        four ``Up`` stages. Each decoder output is concatenated with the
        preceding (coarser) level, bilinearly resized to match, before it enters
        the next ``Up``. The final concatenation goes through a 1x1 head.

    The bottleneck's spatial size must fit ``Position_Encoding``'s table
    length (32 by default).

    Args:
        number_of_classes: channels of the output logits.
        in_channels: channels of the input image (default 3).
    """

    def __init__(self, number_of_classes, in_channels=3):
        super().__init__()
        # Encoder widths. The bottleneck stays at 512 channels to match GSA(512),
        # Scalar_dot_product_attention(512) and the 2 * 256 channels produced by
        # Position_Encoding(256).
        enc1, enc2, enc3, enc4, bottleneck = 64, 128, 256, 512, 512
        # Decoder output widths.
        dec1, dec2, dec3, dec4 = 512, 256, 128, 64

        self.first_conv = DoubleConv(in_channels, enc1)

        self.down1 = Down(enc1, enc2)
        self.down2 = Down(enc2, enc3)
        self.down3 = Down(enc3, enc4)
        self.down4 = Down(enc4, bottleneck)

        # Each Up sees cat(upsampled input, skip). From up2 on, the upsampled
        # input is itself cat(decoder output, resized coarser level).
        self.up1 = Up(bottleneck + enc4, dec1)
        self.up2 = Up(dec1 + bottleneck + enc3, dec2)
        self.up3 = Up(dec2 + dec1 + enc2, dec3)
        self.up4 = Up(dec3 + dec2 + enc1, dec4)
        # 1x1 head: maps channels to classes without changing spatial size.
        self.out = nn.Conv2d(dec4 + dec3, number_of_classes, kernel_size=1)

        self.pos = Position_Encoding(bottleneck // 2)
        self.gsa = GSA(bottleneck)
        self.prod = Scalar_dot_product_attention(bottleneck)

    @staticmethod
    def _resize_like(src, ref):
        """Bilinearly resize ``src`` to the spatial size of ``ref``."""
        return F.interpolate(src, size=ref.shape[2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return per-pixel class logits with ``number_of_classes`` channels."""
        x1 = self.first_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x5_gsa = self.gsa(x5)
        # Out-of-place add: x5 is saved by autograd for the backward pass of the
        # layers around it, so an in-place `+=` would make backward() fail.
        x5_prod = self.prod(x5 + self.pos(x5))
        x5 = x5_gsa + x5_prod

        x6 = self.up1(x5, x4)
        x6_cat = torch.cat((x6, self._resize_like(x5, x6)), 1)

        x7 = self.up2(x6_cat, x3)
        x7_cat = torch.cat((x7, self._resize_like(x6, x7)), 1)

        x8 = self.up3(x7_cat, x2)
        x8_cat = torch.cat((x8, self._resize_like(x7, x8)), 1)

        x9 = self.up4(x8_cat, x1)
        x9_cat = torch.cat((x9, self._resize_like(x8, x9)), 1)

        return self.out(x9_cat)