## ViT Architecture

In [1]:
import torch
import math
from torch import nn, Tensor
from einops import rearrange, repeat


### Utilities

In [2]:
class RearrangeLayer(nn.Module):
    def __init__(self, pattern, **dims):
        """
        Custom layer for einops.rearrange.
        
        Args:
            pattern (str): Rearrangement pattern.
        """
        super().__init__()
        self.pattern = pattern
        self.dims = dims

    def forward(self, x):
        return rearrange(x, self.pattern, **self.dims)

## Patch Embedding

**Positional Encoding:**

In [3]:


class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, max_len: int = 5000):
        """
        Sinusoidal Positional Encoding Module.

        Args:
            emb_size (int): The size of the embedding dimension.
            max_len (int): The maximum length of the sequence.
        """
        super(PositionalEncoding, self).__init__()

        # Create a positional encoding matrix of shape (max_len, emb_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # Shape: (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, emb_size, 2).float() * (-math.log(10000.0) / emb_size)
        )  # Scaling factor for even indices

        # Compute sinusoidal values for even indices and cosines for odd indices
        pe = torch.zeros(max_len, emb_size)
        pe[:, 0::2] = torch.sin(position * div_term)  # Apply sine to even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Apply cosine to odd indices

        # Add a batch dimension and register as a buffer (non-trainable parameter)
        self.register_buffer("positional_encoding", pe.unsqueeze(0))  # Shape: (1, max_len, emb_size)

    def forward(self, x: Tensor) -> Tensor:
        """
        Add positional encoding to the input tensor.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, emb_size).

        Returns:
            Tensor: Positional encoded tensor of the same shape as input.
        """
        seq_len = x.size(1)
        return x + self.positional_encoding[:, :seq_len, :]
    

    

**Patch Embedding with Conv2d:**

In [4]:


class PatchEmbedding(nn.Module):
    """
    
    Uses a COnv2 layer with stride=patch size to create the patch embeddings
    
    Args:
        in_channels (int): Number of channels in the input image.
        patch_size (int): The size of the patch to extract.
        embed_size (int): The size of the embedding dimension.
        img_size (int): SIze of the input iamge.
    """
    def __init__(self, in_channels: int, patch_size: int, embed_size: int, img_size: int):
        self.patch_size=patch_size
        super().__init__()
        self.embed = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, 
                      out_channels=embed_size, 
                      kernel_size=patch_size, 
                      stride=patch_size),
            # Rearrange (batch, embedding_dimensions, height, width) to (batch, height*width, embeding_dims)
            RearrangeLayer('b e h w -> b (h w) e'),
        )

        # class token
        #self.cls_token = nn.Parameter(torch.rand(1, 1, embed_size))

        # positional encoding
        self.pos_embed = PositionalEncoding(embed_size, (img_size // patch_size)**2)

    def forward(self, x: Tensor) -> Tensor:
        """
        Creates patch embedding of input iamge

        Args:
            x (Tensor): Input tensor of shape (batch_size, in_channel, image_size, image_size).

        Returns:
            Tensor: Conv2d with positional embedding output of shape ().
        """
        # get batch size
        b, _, _, _, = x.shape

        # forward pass through conv2d
        x = self.embed(x)
        print(f'Conv2d output shape: {x.shape}')

        # add positional encoding to each vector
        x = self.pos_embed(x)
        print(f'Conv2d with positional embedding output shape: {x.shape}')

        
        return x



### Test Patch Embedding Output

In [5]:
model = PatchEmbedding(in_channels=3, patch_size=16, embed_size=768, img_size=224)
model.eval()

# Dummy input image (batch_size=2, channels=3, height=224, width=224)
dummy_input = torch.randn(4, 3, 224, 224)
embeddings = model(dummy_input)

print("Embeddings shape:", embeddings.shape)  # Expected: (batch_size, num_patches, embed_dim)

Conv2d output shape: torch.Size([4, 196, 768])
Conv2d with positional embedding output shape: torch.Size([4, 196, 768])
Embeddings shape: torch.Size([4, 196, 768])


## CNN Embedding

In [6]:
from torchvision.models import resnet50, ResNet50_Weights

class CNNFeatureEmbedder(nn.Module):
    def __init__(self, cnn_backbone=None, embed_dim=768, position_embedding=True):
        """
        CNN-based embedding module for hybrid architectures.

        Args:
            cnn_backbone (nn.Module): CNN model for feature extraction. Default: ResNet-50.
            embed_dim (int): Target embedding dimension for the Transformer.
            position_embedding (bool): Whether to add positional embeddings to the sequence.
        """
        super().__init__()

        # CNN Backbone
        if cnn_backbone is None:
            cnn_backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            self.feature_extractor = nn.Sequential(*list(cnn_backbone.children())[:-2])  # Up to final conv layer
        else:
            self.feature_extractor = cnn_backbone

        # Patch embedding projection
        self.patch_embed = nn.Linear(cnn_backbone.fc.in_features, embed_dim)

        # Positional embedding (to be computed dynamically based on feature map size)
        self.position_embedding = position_embedding
        self.embed_dim = embed_dim

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for the CNN-based embedding module.

        Args:
            x (Tensor): Input image tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor: Sequence of patch embeddings of shape (batch_size, num_patches, embed_dim).
        """
        # Extract CNN feature maps
        print("\nCNN")
        features = self.feature_extractor(x)  # Shape: (batch_size, C, H, W)
        print(f"Features shape: {features.shape}")

        # Flatten spatial dimensions
        batch_size, C, H, W = features.shape
        print(f"Feature maps: {H}x{W}")
        num_patches = H * W
        print(f"Number of patches: {num_patches}")
        features = features.permute(0, 2, 3, 1).reshape(batch_size, num_patches, C)  # Shape: (batch_size, num_patches, C)
        print(f"Flattened features shape: {features.shape}")

        # Project to Transformer embedding dimension
        embeddings = self.patch_embed(features)  # Shape: (batch_size, num_patches, embed_dim)
        print(f"Featues shape after linear projection to embed_dim={self.embed_dim}: {embeddings.shape}")
        # Add positional embedding
        if self.position_embedding:
            pos_embed = PositionalEncoding(self.embed_dim, max_len=num_patches)
            embeddings = pos_embed(embeddings)  # Shape: (batch_size, num_patches, embed_dim)

        return embeddings

### Test Embedding Output Shape

In [7]:
model = CNNFeatureEmbedder(embed_dim=768)
model.eval()

dummy_input = torch.randn(4, 3, 224, 224)
embeddings = model(dummy_input)

print("Embeddings shape:", embeddings.shape)  
print(f"Expected: (batch_size=2, num_patches=49, embed_dim=768)")




CNN
Features shape: torch.Size([4, 2048, 7, 7])
Feature maps: 7x7
Number of patches: 49
Flattened features shape: torch.Size([4, 49, 2048])
Featues shape after linear projection to embed_dim=768: torch.Size([4, 49, 768])
Embeddings shape: torch.Size([4, 49, 768])
Expected: (batch_size=2, num_patches=49, embed_dim=768)


## Transformer Block

In [8]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, 4*embed_dim),
            nn.GELU(),
            nn.Linear(4*embed_dim, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        print(f"\nTransformerBlock")
        residual1 = x # first residual
        print(f'residual1 shape: {residual1.shape}')

        x = self.layer_norm(x) # first layer norm
        print(f'self attention input shape: {x.shape}')
        x, _ = self.mha(x, x, x) # MultiHeaded Attention
        print(f'self attention output shape: {x.shape}')
        x = self.dropout(x) + residual1 # First dropout layer w/ skip connection

        residual2 = x
        print(f'residual2 shape: {residual1.shape}')
        x = self.layer_norm(x) # Second Layernorm
        print(f'mlp input shape: {x.shape}')
        x = self.feed_forward(x) # MLP with GeLU
        print(f'mlp output shape: {x.shape}')
        x = self.dropout(x) + residual2  # Second Dropout w/ skip connection  
        
        return x



In [9]:
def calc_output(input_size):
    stride = 2
    padding = 1
    kernel_size = 4
    output = (input_size-1) * stride - 2 * padding + kernel_size
    print(f"Size = {output}")
    return output

input = 7
while input <= 224:
    input = calc_output(input)

Size = 14
Size = 28
Size = 56
Size = 112
Size = 224
Size = 448


In [10]:

class SimpleDecoder(nn.Module):
    def __init__(self, in_channels=768):
        """
        Upsample ResNet feature map back to input size using transpose convolutions.

        Output Size = (Input Size - 1) x (Stride - 2) x (Padding + Kernel Size)


        Args:
            in_channels (int): Number of input channels from ResNet feature map.
            target_size (tuple): Target spatial size (height, width) to upsample to.
        """
        super().__init__()
        
        # Transpose convolution layers
        self.upsample = nn.Sequential(
            RearrangeLayer('b (h w) c -> b c h w', h=7, w=7), # reshape (batch, size, embed_dim) -> (batch, channel, height, width)
            nn.ConvTranspose2d(in_channels, 512, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),          # 14 -> 28
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),          # 28 -> 56
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),           # 56 -> 112
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),             # 112 -> 224
        )

    def forward(self, x):
        """
        Forward pass to upsample feature map to input size.
        Args:
            x (Tensor): Input feature map from ResNet of shape (batch_size, in_channels, height, width).
        
        Returns:
            Tensor: Upsampled feature map of shape (batch_size, 3, target_size[0], target_size[1]).
        """
        print(f"\nTransposeCNN input: {x.shape}")
        output =  self.upsample(x)
        print(f"TransposeCNN output: {output.shape}")
        return output


### Test Decoder Output

In [11]:
input = torch.randn(4, 49, 768) 

decoded = SimpleDecoder(in_channels=768)
decoded_image = decoded(input)  

print("Feature map shape:", input.shape)
print("Decoded image shape:", decoded_image.shape)


TransposeCNN input: torch.Size([4, 49, 768])
TransposeCNN output: torch.Size([4, 3, 224, 224])
Feature map shape: torch.Size([4, 49, 768])
Decoded image shape: torch.Size([4, 3, 224, 224])


## Transformer Architecture

In [12]:
class VisionTransformer(nn.Module):
  def __init__(self, num_layers, img_size, embed_dim, patch_size, num_head, cnn_embedding=True):
    super().__init__()
    # patch embedding

    if cnn_embedding == True:
      #Input embedding
      self.patch_emb = CNNFeatureEmbedder(embed_dim=embed_dim)
      #Output decoder
      self.output_layer = SimpleDecoder(in_channels=embed_dim)
    else:
      # Input embedding
      self.patch_emb = PatchEmbedding(in_channels=3, patch_size=patch_size, img_size=img_size, embed_size=embed_dim)
      # Output reshaping
      self.output_layer = nn.Sequential(RearrangeLayer('b (h w) (patch_c ph pw) -> b patch_c (h ph) (w pw)', 
                         h=14, w=14, patch_c=3, ph=16, pw=16))

    # Transformer layers
    self.trans_encoder = nn.Sequential(*[TransformerEncoderBlock(embed_dim, num_head) for layer in range(num_layers)])

    

  def forward(self, x): # input: [b, c, h, w]
    print(f"input shape: {x.shape}")
    # Get patch embeddings w/ positional encoding
    patch_embeddings = self.patch_emb(x) # patch_embeddings: [b, (h*w), e]
    print(f'patch_embeddings shape: {patch_embeddings.shape}')

    # Transformer encoding layers
    x = self.trans_encoder(patch_embeddings) # x: [b, (h*w), e]
    print(f'transformer output shape: {x.shape}')
    
    # reshape embedding into image
    output_img = self.output_layer(x) #output_img: [b, c, h, w]

    return output_img


In [13]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_layers = 2
embed_dim = 768
num_head = 12
patch_size=16
img_size=224
model = VisionTransformer( num_layers=num_layers,
                            img_size=img_size,
                            embed_dim=embed_dim,
                            patch_size=patch_size,
                            num_head=num_head
                            ).to(device)

In [14]:
model

VisionTransformer(
  (patch_emb): CNNFeatureEmbedder(
    (feature_extractor): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (rel

## Check Outputs

In [15]:
# Test input
batch_size = 4
test_input = torch.rand(batch_size, 3, img_size, img_size).to(device)

# Forward pass
try:
    print("Model Input shape:", test_input.shape)
    output = model(test_input)
    print("\nModel  sOutput shape:", output.shape)
except Exception as e:
    print(f"An error occurred during the forward pass: {e}")

Model Input shape: torch.Size([4, 3, 224, 224])
input shape: torch.Size([4, 3, 224, 224])

CNN
Features shape: torch.Size([4, 2048, 7, 7])
Feature maps: 7x7
Number of patches: 49
Flattened features shape: torch.Size([4, 49, 2048])
Featues shape after linear projection to embed_dim=768: torch.Size([4, 49, 768])
patch_embeddings shape: torch.Size([4, 49, 768])

TransformerBlock
residual1 shape: torch.Size([4, 49, 768])
self attention input shape: torch.Size([4, 49, 768])
self attention output shape: torch.Size([4, 49, 768])
residual2 shape: torch.Size([4, 49, 768])
mlp input shape: torch.Size([4, 49, 768])
mlp output shape: torch.Size([4, 49, 768])

TransformerBlock
residual1 shape: torch.Size([4, 49, 768])
self attention input shape: torch.Size([4, 49, 768])
self attention output shape: torch.Size([4, 49, 768])
residual2 shape: torch.Size([4, 49, 768])
mlp input shape: torch.Size([4, 49, 768])
mlp output shape: torch.Size([4, 49, 768])
transformer output shape: torch.Size([4, 49, 768])


In [4]:
import logging

class CNNFeatureEmbedder(nn.Module):
    """
    CNN-based embedding module for hybrid architectures.

    Args:
        cnn_backbone (nn.Module): CNN model for feature extraction. Default: ResNet-50.
        embed_dim (int): Target embedding dimension for the Transformer.
        position_embedding (bool): Whether to add positional embeddings to the sequence.
    """
    def __init__(self, pretrained=True, embed_dim=768, position_embedding=True):
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.info(f"Initializing CNNFeatureEmbedder with embed_dim={embed_dim}, position_embedding={position_embedding}")
        
        if pretrained:
            cnn_backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            self.feature_extractor = nn.Sequential(*list(cnn_backbone.children())[:-2])
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
        else:
            cnn_backbone = resnet50()
            self.feature_extractor = nn.Sequential(*list(cnn_backbone.children())[:-2])

        self.patch_embed = nn.Linear(cnn_backbone.fc.in_features, embed_dim)
        self.position_embedding = position_embedding
        self.embed_dim = embed_dim

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for the CNN-based embedding module.

        Args:
            x (Tensor): Input image tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor: Sequence of patch embeddings of shape (batch_size, num_patches, embed_dim).
        """
        self.logger.debug(f"Input image tensor shape: {x.shape}")
        features = self.feature_extractor(x)
        self.logger.debug(f"Features shape: {features.shape}")
        
        batch_size, C, H, W = features.shape
        self.logger.debug(f"Feature map dims: {H}x{W}")
        num_patches = H * W
        self.logger.debug(f"Number of patches: {num_patches}")
        
        features = features.permute(0, 2, 3, 1).reshape(batch_size, num_patches, C)
        self.logger.debug(f"Flattened features shape: {features.shape}")

        embeddings = self.patch_embed(features)
        self.logger.debug(f"Embedding shape after projection: {embeddings.shape}")
        
        if self.position_embedding:
            pos_embed = PositionalEncoding(self.embed_dim, max_len=num_patches)
            embeddings = pos_embed(embeddings)
            
        return embeddings

In [None]:
from torchvision.models import resnet50, ResNet50_Weights

model = CNNFeatureEmbedder(pretrained=False, position_embedding=False, embed_dim=3)
model.eval()

dummy_input = torch.randn(4, 3, 224, 224)
embeddings = model(dummy_input)

print("Embeddings shape:", embeddings.shape)  
print(f"Expected: (batch_size=2, num_patches=49, embed_dim=768)")




Embeddings shape: torch.Size([4, 49, 768])
Expected: (batch_size=2, num_patches=49, embed_dim=768)
