In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

## Template for custom Model

In [None]:
class MyModel(nn.Module):
    """Skeleton to copy when writing a custom model; it defines no layers yet."""

    def __init__(self, num_classes):
        """Set up the ``nn.Module`` base; ``num_classes`` is accepted but not used yet."""
        super(MyModel, self).__init__()

    def forward(self, x):
        """Placeholder forward pass: ignores ``x`` and returns ``None``."""
        return None

## Timm Library

In [76]:
class ConvNext(nn.Module):
    '''
    Thin wrapper around a timm ConvNeXt classifier.

    Variant names collected by the author (any of them can be passed as
    ``model_name``):
     'convnext_base',
     'convnext_base_384_in22ft1k',
     'convnext_base_in22ft1k',
     'convnext_base_in22k',
     'convnext_large',
     'convnext_large_384_in22ft1k',
     'convnext_large_in22ft1k',
     'convnext_large_in22k',
     'convnext_small',
     'convnext_tiny',
     'convnext_xlarge_384_in22ft1k',
     'convnext_xlarge_in22ft1k',
     'convnext_xlarge_in22k'

    Args:
        num_classes: forwarded to ``timm.create_model`` as ``num_classes``.
        model_name: timm model name to build. Defaults to 'convnext_large',
            which is what this class always built before the parameter existed.
        pretrained: forwarded to ``timm.create_model``. Defaults to True; pass
            False to build the network without requesting pretrained weights.
    '''
    def __init__(self, num_classes, model_name='convnext_large', pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

In [77]:
# Build the ConvNext wrapper with num_classes=10.
conv = ConvNext(10)

Downloading: "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth" to /opt/ml/.cache/torch/hub/checkpoints/convnext_large_1k_224_ema.pth


In [78]:
# Print the module tree of the wrapper to inspect the layers of the wrapped model.
print(conv)

ConvNext(
  (model): ConvNeXt(
    (stem): Sequential(
      (0): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((192,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=192, out_features=768, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.0, inplace=False)
              (fc2): Linear(in_features=768, out_features=192, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock(
            (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), 

In [5]:
# List the timm model names matching the wildcard '*coa*', filtered with pretrained=True.
# NOTE(review): despite the "densenet" variable name, this pattern returns the CoaT
# family (see the cell output), not DenseNet models.
all_densenet_models = timm.list_models('*coa*', pretrained = True)
all_densenet_models

['coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny']

In [65]:
class CoatLiteMini(nn.Module):
    """Wrapper around timm's 'coat_lite_mini' model built with ``pretrained=True``."""

    def __init__(self, num_classes):
        """Create the timm model, passing ``num_classes`` through to ``timm.create_model``."""
        super(CoatLiteMini, self).__init__()
        backbone = timm.create_model(
            'coat_lite_mini',
            pretrained=True,
            num_classes=num_classes,
        )
        self.model = backbone

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)


class Efficientnet_B4(nn.Module):
    """Wrapper around timm's 'efficientnet_b4' model built with ``pretrained=True``."""

    def __init__(self, num_classes):
        """Create the timm model, passing ``num_classes`` through to ``timm.create_model``."""
        super().__init__()
        self.model = timm.create_model(
            'efficientnet_b4',
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        out = self.model(x)
        return out

class Efficientnet_B0(nn.Module):
    """Wrapper around timm's 'efficientnet_b0' model built with ``pretrained=True``.

    NOTE: this class previously built 'efficientnet_b3a', which contradicted its
    name. Checkpoints saved from that older version hold weights for a different
    architecture and will not load into this one.

    Args:
        num_classes: forwarded to ``timm.create_model`` as ``num_classes``.
    """

    def __init__(self, num_classes):
        super(Efficientnet_B0, self).__init__()
        # The model name must match the class name: B0, not the b3a left over from a copy/paste.
        self.model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        x = self.model(x)
        return x

class VitBase(nn.Module):
    """Wrapper around timm's 'vit_base_patch16_224' model built with ``pretrained=True``.

    Args:
        num_classes: forwarded to ``timm.create_model`` as ``num_classes``.
    """

    def __init__(self, num_classes):
        # Zero-argument super(): the previous super(Efficientnet_B0, self) named an
        # unrelated class and raised TypeError as soon as VitBase was instantiated.
        super().__init__()
        self.model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        x = self.model(x)
        return x

class VitLarge(nn.Module):
    """Wrapper around timm's 'vit_large_patch16_224' model built with ``pretrained=True``."""

    def __init__(self, num_classes):
        """Create the timm model, passing ``num_classes`` through to ``timm.create_model``."""
        super().__init__()
        vit = timm.create_model('vit_large_patch16_224', pretrained=True, num_classes=num_classes)
        self.model = vit

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)
    
class SWSResnext(nn.Module):
    """Wrapper around timm's 'swsl_resnext50_32x4d' model built with ``pretrained=True``.

    This class used to be defined twice, verbatim, in the same cell; only one
    definition is kept.

    Args:
        num_classes: forwarded to ``timm.create_model`` as ``num_classes``.
    """

    def __init__(self, num_classes):
        # Zero-argument super(): the old call referenced the misspelled name
        # ``SWSRexnext`` and raised NameError on instantiation.
        super().__init__()
        self.model = timm.create_model('swsl_resnext50_32x4d', pretrained = True, num_classes = num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        # Call the model on the input; previously the module object itself was
        # returned and ``x`` was ignored.
        x = self.model(x)
        return x
    
class Mobilenet(nn.Module):
    """Wrapper around timm's 'mobilenetv2_100' model built with ``pretrained=True``."""

    def __init__(self, num_classes):
        """Create the timm model, passing ``num_classes`` through to ``timm.create_model``."""
        super().__init__()
        self.model = timm.create_model(
            'mobilenetv2_100',
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)
    
class SwinLarge(nn.Module):
    """Wrapper around timm's 'swin_large_patch4_window7_224' model built with ``pretrained=True``."""

    def __init__(self, num_classes):
        """Create the timm model, passing ``num_classes`` through to ``timm.create_model``."""
        super().__init__()
        swin = timm.create_model('swin_large_patch4_window7_224', pretrained=True, num_classes=num_classes)
        self.model = swin

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        out = self.model(x)
        return out

class CaiT(nn.Module):
    """Wrapper around timm's 'cait_s24_224' model built with ``pretrained=True``."""

    def __init__(self, num_classes):
        """Create the timm model, passing ``num_classes`` through to ``timm.create_model``."""
        super().__init__()
        self.model = timm.create_model(
            'cait_s24_224',
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

In [67]:
## timm models Test
# model = Efficientnet_B4(10)
# model.eval()
# model(torch.randn(1,3,224,224))

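## torchvision models

The classes below call `models.<name>(...)`. `models` is presumably `torchvision.models`, which is not imported in the cells shown above; import it (e.g. `from torchvision import models`) before running them.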

In [68]:
class DenseNet201(nn.Module):
    """DenseNet-201 wrapper whose final linear layer is replaced by a ``num_classes`` head.

    NOTE(review): ``models`` is not imported in the visible cells; presumably it is
    ``torchvision.models`` — confirm it is in scope before this cell runs.
    """

    def __init__(self, num_classes):
        """Build ``models.densenet201(pretrained=True)`` and swap its last layer."""
        super().__init__()
        self.model = models.densenet201(pretrained=True)
        self._change_last_layer(num_classes)

    def _change_last_layer(self, num_classes):
        """Replace the last registered submodule with a freshly initialised ``nn.Linear``.

        Raises:
            Exception: if the last submodule is not named 'classifier' or 'fc'.
        """
        # named_modules() yields (name, module) pairs; the final pair is treated as the head.
        last_name, _ = list(self.model.named_modules())[-1]
        if last_name not in ('classifier', 'fc'):
            raise Exception('last layer should be either fc or classifier with nn.Linear Module')

        previous_head = getattr(self.model, last_name)
        new_head = nn.Linear(previous_head.in_features, num_classes, bias=True)
        setattr(self.model, last_name, new_head)
        self._initialize_weights(new_head)

    def _initialize_weights(self, m):
        """Xavier-uniform weights and zero bias for ``nn.Linear``; Kaiming-uniform weights for ``nn.Conv2d``."""
        layer_is_linear = isinstance(m, nn.Linear)
        layer_is_conv = isinstance(m, nn.Conv2d)

        if layer_is_linear:
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

        if layer_is_conv:
            nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

In [44]:
def initialize_weights(m):
    """Re-initialise the parameters of a single layer in place.

    * ``nn.Linear``: Xavier-uniform weights; the bias, when present, is set to zero.
    * ``nn.Conv2d``: Kaiming-uniform weights with ``nonlinearity='relu'``; bias untouched.
    * Any other module is left unchanged.

    Args:
        m: the module to initialise.

    Returns:
        None.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # A layer built with bias=False has ``bias is None``; skip it instead of crashing.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
    
def change_last_layer(model, num_classes):
    """Replace the model's final linear layer with a fresh ``num_classes``-way head.

    The last entry of ``model.named_modules()`` is treated as the head. It must be
    a direct attribute named 'classifier' or 'fc'; it is replaced by a new
    ``nn.Linear`` with the same ``in_features`` and then passed to
    ``initialize_weights``.

    Args:
        model: module to modify in place.
        num_classes: ``out_features`` of the new head.

    Returns:
        The same ``model`` object, with its head replaced.

    Raises:
        Exception: if the last submodule is not named 'classifier' or 'fc'.
    """
    name_last_layer = list(model.named_modules())[-1][0]

    if name_last_layer not in ('classifier', 'fc'):
        # Was ``raise Exceptionception(...)``, which surfaced as a NameError and hid this message.
        raise Exception('last layer should be nn.Linear Module named as either fc or classifier')

    old_head = getattr(model, name_last_layer)
    new_head = nn.Linear(in_features=old_head.in_features,
                         out_features=num_classes, bias=True)
    setattr(model, name_last_layer, new_head)
    initialize_weights(new_head)
    return model

In [53]:
class DenseNet161(nn.Module):
    """Wraps ``models.densenet161(pretrained=True)`` after passing it through ``change_last_layer``."""

    def __init__(self, num_classes):
        """Build the backbone and hand it, with ``num_classes``, to ``change_last_layer``."""
        super().__init__()
        backbone = models.densenet161(pretrained=True)
        self.model = change_last_layer(backbone, num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

class DenseNet121(nn.Module):
    """Wraps ``models.densenet121(pretrained=True)`` after passing it through ``change_last_layer``."""

    def __init__(self, num_classes):
        """Build the backbone and hand it, with ``num_classes``, to ``change_last_layer``."""
        super().__init__()
        self.model = change_last_layer(models.densenet121(pretrained=True), num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        out = self.model(x)
        return out
    
class InceptionV3(nn.Module):
    """Wraps ``models.inception_v3(pretrained=True)`` after passing it through ``change_last_layer``.

    NOTE(review): only the layer that ``change_last_layer`` selects is replaced.
    Inception v3 is believed to also carry an auxiliary classifier that is used in
    training mode — confirm whether that head needs resizing too before training.
    """

    def __init__(self, num_classes):
        """Build the backbone and hand it, with ``num_classes``, to ``change_last_layer``."""
        super(InceptionV3, self).__init__()
        backbone = models.inception_v3(pretrained=True)
        self.model = change_last_layer(backbone, num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)
    
class Resnet152(nn.Module):
    """Wraps ``models.resnet152(pretrained=True)`` after passing it through ``change_last_layer``."""

    def __init__(self, num_classes):
        """Build the backbone and hand it, with ``num_classes``, to ``change_last_layer``."""
        super(Resnet152, self).__init__()
        resnet = models.resnet152(pretrained=True)
        resnet = change_last_layer(resnet, num_classes)
        self.model = resnet

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)
    
class ResNext(nn.Module):
    """Wraps ``models.resnext50_32x4d(pretrained=True)`` after passing it through ``change_last_layer``."""

    def __init__(self, num_classes):
        """Build the backbone and hand it, with ``num_classes``, to ``change_last_layer``."""
        super(ResNext, self).__init__()
        self.model = change_last_layer(
            models.resnext50_32x4d(pretrained=True),
            num_classes,
        )

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        out = self.model(x)
        return out
    

In [63]:
## torchvision models test
# model2 = ResNext(10)
# model2.eval()
# model2(torch.randn(1,3,224,224))