In [1]:
import timm
import torch
import numpy as np
import torch.nn as nn
# Similar to the CoAtNet paper, we build a hybrid model: convolutional (EfficientNet) stages first, followed by transformer (Swin) stages.

In [2]:
class Hybrid_embed(nn.Module):
    """Patch embedding that tokenizes images with a convolutional backbone.

    A timm feature-extraction model produces a feature map, a 1x1 convolution
    (followed by batch norm and SiLU) projects it to ``dims`` channels, and the
    spatial grid is flattened into a sequence of tokens.

    Args:
        feature_model: Model name handed to ``timm.create_model``.
        img_size: ``(height, width)`` of the dummy input used to probe the
            number of channels the backbone outputs.
        channels: Number of channels of the input images.
        efn_blocks: Index of the backbone feature stage to use; passed on as
            ``out_indices=[efn_blocks]``.
        dims: Channel width of the tokens returned by ``forward``.

    Raises:
        AttributeError: If the created backbone has no ``conv_stem`` attribute
            to replace.
    """

    def __init__(self, feature_model, img_size, channels, efn_blocks, dims):
        super().__init__()

        self.feature_extractor = timm.create_model(feature_model,
                                                   features_only=True,
                                                   out_indices=[efn_blocks])

        # Replace the stem with a stride-4 convolution. Its input width follows
        # `channels` and its output width follows the stem being replaced, so
        # the remaining backbone layers still receive the width they were built
        # for (40 for the default efficientnet_b3, as hard-coded previously).
        stem_width = self.feature_extractor.conv_stem.out_channels
        self.feature_extractor.conv_stem = nn.Conv2d(channels,
                                                     stem_width,
                                                     kernel_size=(3, 3),
                                                     stride=(4, 4),
                                                     padding=(1, 1),
                                                     bias=False)

        # NOTE Most reliable way of determining output dims is to run forward pass
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, channels, img_size[0], img_size[1])
                self.channel_output = self.feature_extractor(probe)[0].shape[1]
        finally:
            # Put the backbone back into whatever mode it was created in.
            self.feature_extractor.train(was_training)

        # 1x1 projection from the backbone width to the requested token width.
        self.embed_matcher = nn.Sequential(
            nn.Conv2d(self.channel_output, dims, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(dims, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.SiLU(inplace=True)
        )

    def forward(self, x):
        """Turn a batch of images into a ``(batch, tokens, dims)`` sequence."""
        x = self.feature_extractor(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.embed_matcher(x)
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        x = x.flatten(2).transpose(1, 2)
        return x

In [5]:
class Hybrid_swin_effnet(nn.Module):
    """Swin Transformer whose early stages are replaced by a CNN patch embedding.

    The first ``4 - swin_blocks`` Swin stages are swapped for ``nn.Identity`` and
    the patch embedding is replaced by a ``Hybrid_embed`` built on
    ``feature_model``, so the convolutional backbone supplies the tokens that
    the remaining Swin stages consume.

    Args:
        feature_model: timm model name for the convolutional backbone.
        img_size: ``(height, width)`` forwarded to ``Hybrid_embed``.
            NOTE(review): the Swin backbone is always created from the
            ``swin_tiny_patch4_window7_224`` configuration, so sizes other
            than 224x224 may not be compatible - confirm before use.
        channels: Number of input image channels.
        efn_blocks: Backbone feature stage index forwarded to ``Hybrid_embed``.
        swin_blocks: Number of trailing Swin stages to keep.
        no_classes: Output size of the classification head.

    Raises:
        AssertionError: If ``efn_blocks + swin_blocks`` is not 4.
    """

    def __init__(self, feature_model = "efficientnet_b3",img_size = (224,224), channels = 3, efn_blocks = 2, swin_blocks = 2, no_classes = 1):
        super().__init__()
        assert efn_blocks + swin_blocks == 4,f"The total no of blocks must be 4, instead {efn_blocks+swin_blocks} blocks provided "

        self.swin_backbone = timm.create_model("swin_tiny_patch4_window7_224")

        # Token width is the base embed_dim doubled once per skipped stage.
        skipped_stages = 4 - swin_blocks
        self.embeded_dim = self.swin_backbone.embed_dim * (2 ** skipped_stages)

        # Forward the constructor arguments instead of hard-coded defaults so
        # that feature_model / img_size / channels / efn_blocks take effect.
        self.swin_backbone.patch_embed = Hybrid_embed(feature_model = feature_model,
                                                      img_size = img_size,
                                                      channels = channels,
                                                      efn_blocks = efn_blocks,
                                                      dims = self.embeded_dim)

        # Setting the first few blocks of swin to Identity to match size
        for i in range(skipped_stages):
            self.swin_backbone.layers[i] = nn.Identity()

        # Setting the head as per our need
        self.swin_backbone.head = nn.Linear(self.swin_backbone.num_features, no_classes)

    def forward(self, image):
        """Run the hybrid model on a batch of images.

        Only a trailing singleton class dimension is squeezed, so a batch of
        one keeps its batch dimension. Assuming the backbone returns
        ``(batch, no_classes)``, the result is ``(batch,)`` when
        ``no_classes == 1`` and ``(batch, no_classes)`` otherwise.
        """
        return self.swin_backbone(image).squeeze(-1)

In [6]:
# Build the hybrid model with its default arguments.
hybrid = Hybrid_swin_effnet()
# Random batch of 5 three-channel 224x224 inputs, used as a smoke-test input.
sample = torch.randn(5, 3, 224, 224)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [7]:
# Forward pass on the random batch; only the output shape is inspected.
hybrid(sample).shape

torch.Size([5])