In [None]:
!pip install timm
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
import torch
import torch.nn as nn
import torch.nn.functional as F

Collecting timm
  Downloading timm-1.0.8-py3-none-any.whl.metadata (53 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->timm)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->timm)
  

In [None]:
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chan=3, embed_dim=768, multi_conv=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        if multi_conv:
            if patch_size[0] == 24:
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chan, embed_dim // 4, kernel_size=7, stride=3, padding=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
                )
            elif patch_size[0] == 12:
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chan, embed_dim // 4, kernel_size=7, stride=4, padding=3),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
                )
            elif patch_size[0] == 4:
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chan, embed_dim // 4, kernel_size=7, stride=2, padding=3),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
                )
            else:
                raise ValueError(f"Unsupported patch size {patch_size[0]}")
        else:
            self.proj = nn.Conv2d(in_chan, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, extra_padding=False):
        B, C, H, W = x.shape
        print(f"Input shape: {x.shape}")
        if extra_padding and (H % self.patch_size[0] != 0 or W % self.patch_size[1] != 0):
            p_l = (self.patch_size[1] - W % self.patch_size[1]) // 2
            p_r = (self.patch_size[1] - W % self.patch_size[1]) - p_l
            p_t = (self.patch_size[0] - H % self.patch_size[0]) // 2
            p_b = (self.patch_size[0] - H % self.patch_size[0]) - p_t
            x = F.pad(x, (p_l, p_r, p_t, p_b))
            print(f"Padded shape: {x.shape}")

        for i, layer in enumerate(self.proj):
            x = layer(x)

        return x

# to run the code with example random tensor
img_size = 226
patch_sizes = [24, 12, 4]
in_chan = 3
embed_dim = 768

for patch_size in patch_sizes:
    print(f"\nTesting with patch size: {patch_size}")
    model = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chan=in_chan, embed_dim=embed_dim, multi_conv=True)
    x = torch.randn(1, in_chan, img_size, img_size)  # Example input tensor
    out = model(x, extra_padding=True)
    print(f"Final output shape for patch size {patch_size}: {out.shape}")



Testing with patch size: 24
Input shape: torch.Size([1, 3, 226, 226])
Padded shape: torch.Size([1, 3, 240, 240])
Final output shape for patch size 24: torch.Size([1, 768, 40, 40])

Testing with patch size: 12
Input shape: torch.Size([1, 3, 226, 226])
Padded shape: torch.Size([1, 3, 228, 228])
Final output shape for patch size 12: torch.Size([1, 768, 29, 29])

Testing with patch size: 4
Input shape: torch.Size([1, 3, 226, 226])
Padded shape: torch.Size([1, 3, 228, 228])
Final output shape for patch size 4: torch.Size([1, 768, 57, 57])


In [None]:
class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., patch_size=16):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.patch_size = patch_size

    def forward(self, x):
        B, N, C = x.shape

        if self.patch_size == 24:
            q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  
        elif self.patch_size == 12:
            q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  
        else:
            q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  

        k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  
        v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  

        attn = (q @ k.transpose(-2, -1)) * self.scale  
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 1, C) 
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [None]:
class CrossAttentionBlock(nn.Module):
  def __init__(self,dim,num_heads,mlp_ratio=4.,qkv_bias=False,qk_scale = None, drop = 0. , attn_drop = 0.,
               drop_path = 0. , act_layer = nn.GELU , norm_layer = nn.LayerNorm , has_mlp = True , patch_size=16):
     super().__init__()
     self.norm1 = norm_layer(dim)
     self.attn = CrossAttention(dim ,num_heads = num_heads , qkv_bias = qkv_bias, qk_scale = qk_scale , attn_drop = attn_drop , proj_drop = drop , patch_size = patch_size)
     self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
     self.has_mlp = has_mlp
     if self.has_mlp:
      self.norm2 = norm_layer(dim)
      self.mlp = Mlp(in_features = dim ,hidden_features = mlp_hidden_dim , act_layer = act_layer , drop = drop)


  def forward(self , x):
    x = x[:, 0:1, ...]+self.drop_path(self.attn(self.norm(x)))
    if self.has_mlp:
      x= x+self.drop_path(self.mlp(self.norm2(x)))

    return x


In [3]:
class MultiScaleBlock(nn.Module):
  def __init__(self,dim,patches , depth , num_heads ,mlp_ratio , act_layer=nn.GELU, qkv_bias=False ,qk_scale = None, attn_drop = 0. , drop=0. ,norm_layer = nn.LayerNorm):
    super().__init__()

     #creating branches based on the embedding dim
    num_branches = len(dim)
    self.num_branches  = num_branches

    #transformer block for each branch
    self.blocks = nn.ModuleList()
    for d in range(num_branches):
      temp = []
      for i in range(depth[d]):
        temp.append(
            Block(dim = dim[d],num_heads = num_heads[d], mlp_ratio = mlp_ratio[d],
                 qkv_bias = qkv_bias,drop =drop,attn_drop= attn_drop,drop_path = drop_path[i],norm_layer = norm_layer)
        )
      if len(temp)!=0:
        self.blocks.append(nn.Sequential(*temp))
    if len(self.blocks) ==0:
      self.blocks=None

    #making sure that all the branches are of same size if so creats the projection layers
    self.proje = nn.ModuleList()
    for d in range(num_branches):
      temp = [norm_layer(dim[d]),act_layer(),nn.Linear(dim[d],dim[(d+1)%num_branches])]
    self.proje.append(nn.Sequential(*temp))



SyntaxError: incomplete input (<ipython-input-3-e87c552f97d9>, line 26)

In [None]:
import torch

img_size = 224
patch_sizes = [24, 12]
in_chan = 3
embed_dim = 768
num_heads = 8

for patch_size in patch_sizes:
    print(f"\nTesting with patch size: {patch_size}")

    patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chan=in_chan, embed_dim=embed_dim, multi_conv=True)

    x = torch.randn(1, in_chan, img_size, img_size)

    patch_embeddings = patch_embed(x, extra_padding=True)
    print(f"Patch embeddings shape: {patch_embeddings.shape}")

    # Reshape the output of PatchEmbed to match the expected input shape for CrossAttention
    # The output shape from PatchEmbed (B, embed_dim, H', W')
    #flatten the spatial dimensions to get the shape (B, N, C) for CrossAttention
    B, C, H_prime, W_prime = patch_embeddings.shape
    N = (H_prime * W_prime)  # Number of patches
    patch_embeddings = patch_embeddings.permute(0, 2, 3, 1).reshape(B, N, C)  

    cross_attention = CrossAttention(dim=embed_dim, num_heads=num_heads, patch_size=patch_size)

    attention_output = cross_attention(patch_embeddings)
    print(f"Attention output shape: {attention_output.shape}")


Testing with patch size: 24
Input shape: torch.Size([1, 3, 224, 224])
Padded shape: torch.Size([1, 3, 240, 240])
Patch embeddings shape: torch.Size([1, 768, 40, 40])
Attention output shape: torch.Size([1, 1, 768])

Testing with patch size: 12
Input shape: torch.Size([1, 3, 224, 224])
Padded shape: torch.Size([1, 3, 228, 228])
Patch embeddings shape: torch.Size([1, 768, 29, 29])
Attention output shape: torch.Size([1, 1, 768])
