In [90]:
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms


class PatchEmbed(nn.Module):
    """Split a square image into non-overlapping patches and embed each one.

    A single convolution whose kernel size and stride both equal
    ``patch_size`` maps every ``patch_size x patch_size`` patch to one
    ``embed_dim``-dimensional vector.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch. Must be positive and
            divide ``img_size`` exactly.
        in_chans: Number of input channels given to the convolution.
        embed_dim: Number of output channels, i.e. the size of each patch
            embedding.

    Raises:
        ValueError: If ``patch_size`` is not positive or ``img_size`` is not
            divisible by ``patch_size``.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=10):
        super().__init__()

        # Validate before deriving anything from the sizes. Previously a zero
        # patch size escaped as a ZeroDivisionError from the patch-count
        # arithmetic, and a negative one slipped through the modulo check.
        if patch_size <= 0:
            raise ValueError("Patch size must be a positive integer.")
        if img_size % patch_size != 0:
            raise ValueError("Image size must be divisible by patch size.")

        self.img_size = img_size
        print(f"IMAGE_SIZE----{self.img_size}")

        self.patch_size = patch_size
        print(f"PATCH SIZE------{self.patch_size}")

        # Patches per side, squared. Note: the label printed below says
        # "N_PARAMETERS" but the value is the number of patches.
        self.n_patches = (img_size // patch_size) ** 2
        print(F"N_PARAMETERS----{self.n_patches}")

        # kernel_size == stride, so each patch is seen exactly once.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        print(f"PROJECT LAYER----- {self.proj}")

    def forward(self, x):
        """Embed a batch of images as a sequence of patch vectors.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``H == W == img_size``.

        Returns:
            Tensor of shape ``(B, n_patches, embed_dim)``.

        Raises:
            ValueError: If ``x`` is not 4-D or its spatial size differs from
                ``img_size``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"Expected a 4-D (B, C, H, W) input, got a tensor with {x.dim()} dimension(s)."
            )
        H, W = x.shape[-2:]
        print(f"FORWARD LAYER---- {H, W}")

        # Check for valid input dimensions
        if H != self.img_size or W != self.img_size:
            raise ValueError(f"Input image size ({H}x{W}) doesn't match model's img_size ({self.img_size}).")

        x = self.proj(x)  # (B, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        print(f"PROJECTION AFTER FORWARD PASS {x.shape}")
        x = x.flatten(2)  # merge the two spatial axes: (B, embed_dim, n_patches)
        print(f"AFTER FLATTENING {x.shape}")
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        print(f"AFTER TRANSPOSE {x.shape}")
        print("\n\n")

        return x


In [91]:
class PositionalEncoding(nn.Module):
    """Prepend a learnable class token and add learnable position embeddings.

    Args:
        n_patches: Number of patch tokens per sample (class token excluded).
        embed_dim: Size of every token vector.
    """

    def __init__(self, n_patches, embed_dim):
        super().__init__()

        # One position vector per token; the extra slot belongs to the class
        # token that forward() puts in front of the patches.
        position_shape = (1, n_patches + 1, embed_dim)
        self.positional_embedding = nn.Parameter(torch.randn(position_shape))
        print(f"postional embedding {self.positional_embedding.shape}")

        # A single class token, shared by every sample in a batch.
        token_shape = (1, 1, embed_dim)
        self.class_token = nn.Parameter(torch.randn(token_shape))
        print(f"class token {self.class_token.shape}")

    def forward(self, x):
        """Return ``x`` with the class token prepended and positions added.

        Args:
            x: Patch tokens, expected as ``(batch, n_patches, embed_dim)``.

        Returns:
            Tensor with one more token along dim 1 than ``x``.
        """
        n_samples = x.size(0)
        print(f"Batch Size {n_samples}")

        # Broadcast the single class token across the batch (no copy is made).
        batch_cls = self.class_token.expand(n_samples, -1, -1)
        print(f"Class token {batch_cls.shape}")

        # Class token goes first along the token axis.
        tokens = torch.cat((batch_cls, x), dim=1)
        print(f"Concatinate class roken {tokens.shape}")

        encoded = tokens + self.positional_embedding
        print(encoded.shape)
        return encoded


In [92]:
# Combining Patch Embedding and Positional Encoding
class VisionTransformer(nn.Module):
    """Input stage of a Vision Transformer: patch embedding + positional encoding.

    Args:
        img_size: Side length of the square input images.
        patch_size: Side length of each square patch.
        in_chans: Number of image channels.
        embed_dim: Size of each token embedding.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=10):
        super().__init__()

        # Images -> sequence of patch embeddings.
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        # Class token + learnable positions. The token count is read from the
        # patch embedder, which computes (img_size // patch_size) ** 2.
        self.pos_encoding = PositionalEncoding(self.patch_embed.n_patches, embed_dim)

    def forward(self, x):
        """Turn a batch of images into position-encoded tokens.

        Args:
            x: Image batch accepted by ``PatchEmbed.forward``.

        Returns:
            The output of ``PositionalEncoding.forward`` applied to the patch
            embeddings; further transformer layers would consume this.
        """
        patches = self.patch_embed(x)
        return self.pos_encoding(patches)


In [93]:
# Hyperparameters
img_size = 224
patch_size = 4
embed_dim = 10

# Create a random input image tensor
input_image = torch.randn(4, 3, img_size, img_size)  # 4 images in batch

# Initialize Vision Transformer
vit = VisionTransformer(
    img_size=img_size,
    patch_size=patch_size,
    embed_dim=embed_dim
)

# Forward pass
output = vit(input_image)

# Sanity check instead of eyeballing printed shapes: one token per patch plus
# the prepended class token, each of width embed_dim.
assert output.shape == (
    input_image.shape[0],
    (img_size // patch_size) ** 2 + 1,
    embed_dim,
), f"unexpected output shape {tuple(output.shape)}"

IMAGE_SIZE----224
PATCH SIZE------4
N_PARAMETERS----3136
PROJECT LAYER----- Conv2d(3, 10, kernel_size=(4, 4), stride=(4, 4))
postional embedding torch.Size([1, 3137, 10])
class token torch.Size([1, 1, 10])
FORWARD LAYER---- (224, 224)
PROJECTION AFTER FORWARD PASS torch.Size([4, 10, 56, 56])
AFTER FLATTENING torch.Size([4, 10, 3136])
AFTER TRANSPOSE torch.Size([4, 3136, 10])



Batch Size 4
Class token torch.Size([4, 1, 10])
Concatinate class roken torch.Size([4, 3137, 10])
torch.Size([4, 3137, 10])


In [99]:
import torch
import torch.nn as nn

class SimplePositionalEncoding(nn.Module):
    """Verbose teaching version of a learnable positional encoding.

    Prepends a learnable class token to the input tokens and adds a learnable
    position embedding, printing the full tensors at every step.

    Args:
        n_tokens: Number of input tokens per sample (class token excluded).
        embed_dim: Size of every token vector.
    """

    def __init__(self, n_tokens, embed_dim):
        super().__init__()

        # n_tokens + 1 rows: the additional one is for the class token.
        table_shape = (1, n_tokens + 1, embed_dim)
        self.positional_embeddings = nn.Parameter(torch.randn(table_shape))
        print(f"POSITIONAL EMBEDDINGS {self.positional_embeddings} \n")

        self.class_token = nn.Parameter(torch.randn((1, 1, embed_dim)))
        print(f"CLASS TOKEN {self.class_token} \n")

    def forward(self, x):
        """Return ``x`` with the class token prepended and positions added.

        Args:
            x: Input tokens, expected as ``(batch, n_tokens, embed_dim)``.

        Returns:
            Tensor with one more token along dim 1 than ``x``.
        """
        n_samples = x.size(0)
        print(f"BATCH SIZE {n_samples} \n")

        # Repeat the class token once per sample in the batch.
        batch_cls = self.class_token.expand(n_samples, -1, -1)
        print(f"CLASS TOKEN {batch_cls} \n")

        # Class token first, then the original tokens.
        with_cls = torch.cat((batch_cls, x), dim=1)
        print(f"CONCAT USING TORCH {with_cls}")

        encoded = with_cls + self.positional_embeddings
        return encoded

# Example usage: run SimplePositionalEncoding on a tiny random batch so the
# full tensors it prints stay small enough to read.
# NOTE(review): this cell rebinds the notebook-level names `embed_dim` and `x`;
# `embed_dim` is set to 10 in the hyperparameter cell above, so re-running that
# code after this cell without re-running its assignments would pick up 3.
n_tokens = 4  # Example: 4 patches
embed_dim = 3  # Example embedding dimension
# Random stand-in for patch embeddings. It is drawn before the module is
# constructed, so the order of random draws is: input, then module parameters.
x = torch.randn(2, n_tokens, embed_dim)  # Batch size of 2
# NOTE(review): the label reads "PATCH SIZE", but the value printed is the
# input token tensor itself, not a patch size.
print(f"PATCH SIZE {x} \n")
pos_enc = SimplePositionalEncoding(n_tokens, embed_dim)
encoded_x = pos_enc(x)
print(f"ENCODED X {encoded_x} \n")
# One token longer than the input along dim 1 because of the class token:
# (2, n_tokens + 1, embed_dim).
print(encoded_x.shape)

PATCH SIZE tensor([[[ 1.0334, -0.6626,  0.2378],
         [ 0.0790, -0.3910,  0.2376],
         [ 0.6799,  1.2311, -0.0737],
         [ 0.5852,  0.8770,  1.2304]],

        [[-1.2126, -2.8163,  1.8042],
         [ 0.7125,  0.4960,  0.5178],
         [ 0.2315,  2.6094, -0.2609],
         [-0.0798,  1.3035,  0.0351]]]) 

POSITIONAL EMBEDDINGS Parameter containing:
tensor([[[ 0.6438, -1.9903, -0.3504],
         [ 0.0607,  0.7025,  0.7990],
         [-0.5644,  0.1264, -0.1117],
         [ 0.2809, -1.9468,  0.0443],
         [ 0.8796,  1.5137, -0.0539]]], requires_grad=True) 

CLASS TOKEN Parameter containing:
tensor([[[ 1.0111,  0.3674, -0.5016]]], requires_grad=True) 

BATCH SIZE 2 

CLASS TOKEN tensor([[[ 1.0111,  0.3674, -0.5016]],

        [[ 1.0111,  0.3674, -0.5016]]], grad_fn=<ExpandBackward0>) 

CONCAT USING TORCH tensor([[[ 1.0111,  0.3674, -0.5016],
         [ 1.0334, -0.6626,  0.2378],
         [ 0.0790, -0.3910,  0.2376],
         [ 0.6799,  1.2311, -0.0737],
         [ 0.5852,