# I. Swin Transformer

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

from torch import nn
from timm import create_model


class SwinTransformerModel(nn.Module):
    """Classifier built on timm's ``swin_large_patch4_window7_224.ms_in22k``.

    The backbone is created for one input channel (``in_chans=1``) and with
    its head sized to ``num_classes``.

    Args:
        num_classes: Number of outputs of the classification head.
        fine_tune: When False, all backbone parameters are frozen and only
            the parameters under ``self.swin.head`` remain trainable. When
            True, nothing is frozen.
        pretrained: Forwarded to ``create_model``. Defaults to True (the
            previous hard-coded value); pass False to build the architecture
            without loading pretrained weights.
    """

    def __init__(self, num_classes, fine_tune=False, pretrained=True):
        super().__init__()
        self.swin = create_model(
            'swin_large_patch4_window7_224.ms_in22k',
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=1,
        )

        if not fine_tune:
            # Freeze the whole network first, then re-enable only the head so
            # that it is the single trainable part.
            self.swin.requires_grad_(False)
            self.swin.head.requires_grad_(True)

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.swin(x)
    
    
    
# Smoke test: build the model, push one random batch through it and report
# its size. `num_classes`, `IMAGE_SIZE` and `display_params_flops` are not
# defined in this section and are expected to exist already.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinTransformerModel(num_classes, fine_tune=False)
model.to(device)

print()

# One random single-channel image of the configured size.
x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)

# Nothing is back-propagated here, so skip building the autograd graph.
with torch.no_grad():
    output = model(x)
print("Model output's shape:", output.shape)
print(output)
display_params_flops(model)

In [None]:
# Training setup: a fresh model on the selected device, plus loss and optimizer.
model = SwinTransformerModel(num_classes, fine_tune=False).to(device)

loss_function = nn.BCEWithLogitsLoss()
# NOTE(review): 0.01 is far larger than the 3e-5 used for the CoAtNet, VGG19
# and DenseNet setups in this file - confirm it is intentional.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

# II. DaViT-UNETR

In [None]:
import timm
import torch.nn as nn
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock


class DaViT_UnetR_Modelv2(nn.Module):
    """DaViT feature extractor with a UNETR-style decoder and a classifier.

    ``davit_base.msft_in1k`` is created with ``features_only=True`` and
    ``in_chans=1``; the first four feature maps it returns are consumed by
    ``forward``. MONAI UNETR blocks combine them with a convolutional
    encoding of the raw input, and a small conv stack plus MLP turns the
    decoded map into ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the final linear layer.
        pretrained: Forwarded to ``timm.create_model``.
        fine_tune: When False, every DaViT parameter is frozen. The UNETR
            blocks, conv stack and classifier are never frozen here.
    """

    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super().__init__()

        self.davit = timm.create_model(
            'davit_base.msft_in1k',
            pretrained=pretrained,
            features_only=True,
            in_chans=1,
        )

        if not fine_tune:
            self.davit.requires_grad_(False)

        spatial_dims = 2
        in_channels = 1  # single channel; matches in_chans=1 given to DaViT
        feature_size = 128
        norm_name = "instance"
        # presumably the channel count of DaViT's first feature map - TODO confirm
        hidden_size = 128
        res_block = True
        conv_block = False

        # Spatial size and channel count of the map handed to the classifier.
        pooled_size = (7, 7)
        head_channels = 50

        # Keyword arguments shared by the three skip-connection blocks and by
        # the four decoder blocks respectively.
        skip_kwargs = dict(
            spatial_dims=spatial_dims,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        up_kwargs = dict(
            spatial_dims=spatial_dims,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )

        # Encodes the raw input image; used as the last skip connection.
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Skip-connection blocks applied to DaViT feature maps 0, 1 and 2.
        self.encoder2 = UnetrPrUpBlock(
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            **skip_kwargs,
        )
        self.encoder3 = UnetrPrUpBlock(
            in_channels=hidden_size * 2,
            out_channels=feature_size * 4,
            num_layer=1,
            **skip_kwargs,
        )
        self.encoder4 = UnetrPrUpBlock(
            in_channels=hidden_size * 4,
            out_channels=feature_size * 8,
            num_layer=0,
            **skip_kwargs,
        )
        # Decoder blocks; each takes the running decoded map and one skip.
        self.decoder5 = UnetrUpBlock(
            in_channels=hidden_size * 8,
            out_channels=feature_size * 8,
            **up_kwargs,
        )
        self.decoder4 = UnetrUpBlock(
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            **up_kwargs,
        )
        self.decoder3 = UnetrUpBlock(
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            **up_kwargs,
        )
        self.decoder2 = UnetrUpBlock(
            in_channels=feature_size * 2,
            out_channels=feature_size,
            **up_kwargs,
        )

        self.conv = nn.Sequential(
            nn.Conv2d(feature_size, 78, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(78, head_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Fixes the spatial size seen by the classifier so the Linear
            # layer below no longer depends on the input resolution. It is
            # parameter-free and appended last, so existing state_dict keys
            # are unchanged. Presumably a no-op for the resolution the
            # original 2450-wide layer was sized for - TODO confirm.
            nn.AdaptiveAvgPool2d(pooled_size),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 50 * 7 * 7 == 2450, the previously hard-coded input width.
            nn.Linear(head_channels * pooled_size[0] * pooled_size[1], 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x_in):
        """Run DaViT, decode with skip connections and classify.

        Args:
            x_in: Input batch, passed both to DaViT and to ``encoder1``.

        Returns:
            The output of ``self.classifier``.
        """
        hidden_states_out = self.davit(x_in)

        # Skip connections: raw input plus the first three DaViT feature maps.
        enc1 = self.encoder1(x_in)
        enc2 = self.encoder2(hidden_states_out[0])
        enc3 = self.encoder3(hidden_states_out[1])
        enc4 = self.encoder4(hidden_states_out[2])

        # The fourth feature map seeds the decoder path.
        dec4 = hidden_states_out[3]
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)

        conv_out = self.conv(out)
        return self.classifier(conv_out)

In [None]:
# Smoke test: build the model, push one random batch through it and report
# its size. `num_classes`, `IMAGE_SIZE` and `display_params_flops` are not
# defined in this section and are expected to exist already.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DaViT_UnetR_Modelv2(num_classes, fine_tune=False)
model.to(device)

print()

# One random single-channel image of the configured size.
x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)

# Nothing is back-propagated here, so skip building the autograd graph.
with torch.no_grad():
    output = model(x)
print("Model output's shape:", output.shape)
print(output)
display_params_flops(model)

In [None]:
# Training setup: a fresh model on the selected device, plus loss and optimizer.
model = DaViT_UnetR_Modelv2(num_classes, fine_tune=False).to(device)

loss_function = nn.BCEWithLogitsLoss()
# NOTE(review): 0.01 is far larger than the 3e-5 used for the CoAtNet, VGG19
# and DenseNet setups in this file - confirm it is intentional.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

# III. CoAtNet

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

from torch import nn
from timm import create_model


class CoatNetModel(nn.Module):
    """Classifier built on timm's ``coatnet_3_rw_224.sw_in12k``.

    The backbone is created for one input channel (``in_chans=1``) and with
    its head sized to ``num_classes``.

    Args:
        num_classes: Number of outputs of the classification head.
        fine_tune: When False, all backbone parameters are frozen and only
            the parameters under ``self.coatnet.head`` remain trainable.
            When True, nothing is frozen.
        pretrained: Forwarded to ``create_model``. Defaults to True (the
            previous hard-coded value); pass False to build the architecture
            without loading pretrained weights.
    """

    def __init__(self, num_classes, fine_tune=False, pretrained=True):
        super().__init__()
        self.coatnet = create_model(
            'timm/coatnet_3_rw_224.sw_in12k',
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=1,
        )

        if not fine_tune:
            # Freeze the whole network first, then re-enable only the head so
            # that it is the single trainable part.
            self.coatnet.requires_grad_(False)
            self.coatnet.head.requires_grad_(True)

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.coatnet(x)
    
    
    
# Smoke test: build the model, push one random batch through it and report
# its size. `num_classes`, `IMAGE_SIZE` and `display_params_flops` are not
# defined in this section and are expected to exist already.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CoatNetModel(num_classes, fine_tune=False)
model.to(device)

print()

# One random single-channel image of the configured size.
x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)

# Nothing is back-propagated here, so skip building the autograd graph.
with torch.no_grad():
    output = model(x)
print("Model output's shape:", output.shape)
print(output)
display_params_flops(model)

In [None]:
# Training setup: a fresh model on the selected device, plus loss and optimizer.
model = CoatNetModel(num_classes, fine_tune=False).to(device)

loss_function = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-5)

# IV. VGG19

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class VGG19Model(nn.Module):
    """torchvision VGG19 classifier fed through a 1-to-3 channel adapter.

    A 3x3 convolution maps the single-channel input to the three channels
    passed to VGG19, and the last classifier layer is replaced by a linear
    layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the replaced final layer.
        fine_tune: When False, the VGG19 parameters are frozen except for
            its ``classifier`` block. The channel adapter and the new final
            layer are always trainable.
        pretrained: Load ImageNet weights when True (the previous hard-coded
            behaviour); build an untrained VGG19 when False.
    """

    def __init__(self, num_classes, fine_tune=False, pretrained=True):
        super().__init__()

        # Adapter from 1 input channel to 3 channels; padding=1 with a 3x3
        # kernel and stride 1 keeps the spatial size.
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)

        # `pretrained=` is deprecated in recent torchvision in favour of the
        # `weights=` enum; fall back to the old keyword where the enum is
        # not available.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            self.vgg19 = models.vgg19(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            self.vgg19 = models.vgg19(pretrained=pretrained)

        if not fine_tune:
            # Freeze everything, then re-enable the fully connected block.
            self.vgg19.requires_grad_(False)
            self.vgg19.classifier.requires_grad_(True)

        # Swap the final layer for one sized to this task; a freshly created
        # layer is trainable regardless of the freezing above.
        in_features = self.vgg19.classifier[6].in_features
        self.vgg19.classifier[6] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Apply the channel adapter, then VGG19, and return its output."""
        return self.vgg19(self.conv(x))
    
# Smoke test: build the model, push one random batch through it and report
# its size. `num_classes`, `IMAGE_SIZE` and `display_params_flops` are not
# defined in this section and are expected to exist already.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG19Model(num_classes, fine_tune=False)
model.to(device)

print()

# One random single-channel image of the configured size.
x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)

# Nothing is back-propagated here, so skip building the autograd graph.
with torch.no_grad():
    output = model(x)
print("Model output's shape:", output.shape)
print(output)
display_params_flops(model)

In [None]:
# Training setup: a fresh model on the selected device, plus loss and optimizer.
model = VGG19Model(num_classes, fine_tune=False).to(device)

loss_function = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-5)

# V. DenseNet

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

from torch import nn
from timm import create_model

class DenseNetModel(nn.Module):
    """Classifier built on timm's ``densenet121.tv_in1k``.

    The backbone is created for one input channel (``in_chans=1``) and with
    its classifier sized to ``num_classes``.

    Args:
        num_classes: Number of outputs of the classifier.
        pretrained: Forwarded to ``create_model``.
        fine_tune: When False, all backbone parameters are frozen and only
            the head (``global_pool``, ``head_drop`` and ``classifier``)
            is left trainable. When True, nothing is frozen.
    """

    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super().__init__()
        self.densenet = create_model(
            'densenet121.tv_in1k',
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=1,
        )

        if not fine_tune:
            self.densenet.requires_grad_(False)

            # These head parts may be absent on some timm versions and
            # presumably hold no parameters (pooling / dropout), so treat
            # them as optional instead of failing on a missing attribute.
            for optional_part in ("global_pool", "head_drop"):
                module = getattr(self.densenet, optional_part, None)
                if isinstance(module, nn.Module):
                    module.requires_grad_(True)

            # The classifier is the part that must stay trainable; a missing
            # attribute here should still fail loudly.
            self.densenet.classifier.requires_grad_(True)

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.densenet(x)

# Smoke test: build the model, push one random batch through it and report
# its size. `num_classes`, `IMAGE_SIZE` and `display_params_flops` are not
# defined in this section and are expected to exist already.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DenseNetModel(num_classes, fine_tune=False)
model.to(device)

print()

# One random single-channel image of the configured size.
x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)

# Nothing is back-propagated here, so skip building the autograd graph.
with torch.no_grad():
    output = model(x)
print("Model output's shape:", output.shape)
print(output)
display_params_flops(model)

In [None]:
# Training setup: a fresh model on the selected device, plus loss and optimizer.
model = DenseNetModel(num_classes, fine_tune=False).to(device)

loss_function = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-5)