In [11]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A strided convolution (kernel size == stride == ``patch_size``) projects
    every patch to an ``embed_dim`` vector, after which a learned positional
    embedding is added to each patch position.

    Args:
        in_channels: Number of channels of the input image.
        patch_size: Side length, in pixels, of each square patch.
        embed_dim: Size of the embedding produced for each patch.
        img_size: Expected input resolution, either a single int for square
            images or a ``(height, width)`` pair. Defaults to 224, which
            reproduces the previously hard-coded behaviour.

    Attributes:
        num_patches: Number of patches per image, i.e. the sequence length
            of the output.
    """

    def __init__(self, in_channels=3, patch_size=16, embed_dim=768, img_size=224):
        super().__init__()
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        img_height, img_width = img_size
        if img_height < patch_size or img_width < patch_size:
            raise ValueError(
                f"img_size {(img_height, img_width)} is smaller than "
                f"patch_size {patch_size}"
            )

        self.patch_size = patch_size
        self.img_size = (img_height, img_width)
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # The strided convolution emits one vector per full patch; partial
        # patches at the right/bottom border are dropped, hence floor division.
        self.grid_size = (img_height // patch_size, img_width // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor of shape (batch_size, num_patches, embed_dim).

        Raises:
            ValueError: If the input yields a different number of patches
                than the positional embedding was built for.
        """
        # (batch_size, embed_dim, grid_h, grid_w)
        x = self.proj(x)
        # (batch_size, num_patches, embed_dim)
        x = x.flatten(2).transpose(1, 2)
        # Check explicitly: a single-patch input would otherwise broadcast
        # silently against the whole positional table and return a wrongly
        # shaped result instead of failing.
        if x.shape[1] != self.num_patches:
            raise ValueError(
                f"Input produced {x.shape[1]} patches but the positional "
                f"embedding expects {self.num_patches}; construct "
                f"PatchEmbedding with a matching img_size."
            )
        return x + self.pos_embed

In [12]:
from PIL import Image
from torchvision import transforms

# Load the image, force three RGB channels, and resize it to 224x224.
# The context manager closes the underlying file once the pixel data has
# been copied by convert().
with Image.open('data/dog.jpeg') as image_file:
    image = image_file.convert('RGB')
resized_image = image.resize((224, 224))

# Convert the resized image to a tensor
tensor_image = transforms.ToTensor()(resized_image)

# Add a batch dimension to the tensor
batched_tensor_image = tensor_image.unsqueeze(0)

# Initialize the PatchEmbedding module with patch_size=16, embed_dim=768
patch_embedding = PatchEmbedding(patch_size=16, embed_dim=768)

# Compute patch embeddings for the input image tensor.
# Call the module itself rather than .forward() so that any registered
# hooks are run as well.
patch_embeddings = patch_embedding(batched_tensor_image)

# Print the shape of the output tensor
print(patch_embeddings.shape)

torch.Size([1, 196, 768])


In [13]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block.

    Applies multi-head self-attention followed by a position-wise
    feed-forward network. Each sublayer is wrapped in a residual connection
    and followed by its own layer normalisation.

    Args:
        embed_dim: Size of each token embedding.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout_rate: Dropout probability used inside the attention layer
            and at the end of the feed-forward network.
    """

    def __init__(self, embed_dim, num_heads, dropout_rate=0.1):
        super().__init__()

        # Multi-head self-attention layer. batch_first=True lets the layer
        # consume (batch_size, num_patches, embed_dim) directly, so no
        # transposes are needed in forward().
        self.self_attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout_rate, batch_first=True
        )

        # Feedforward neural network layer
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout_rate)
        )

        # One layer normalisation per sublayer, so each learns its own
        # scale and shift.
        self.norm = nn.LayerNorm(embed_dim)
        self.ffn_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run the block.

        Args:
            x: Tensor of shape (batch_size, num_patches, embed_dim).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Self-attention sublayer. Keep the input in `x` so that it can be
        # added back as the residual; the attention weights are not used, so
        # skip computing them.
        attn_out, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.norm(x + attn_out)

        # Feed-forward sublayer with its own residual connection.
        x = self.ffn_norm(x + self.feedforward(x))

        return x

In [14]:
class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` TransformerBlock modules applied in order.

    Args:
        num_layers: How many blocks to stack.
        embed_dim: Embedding size passed to every block.
        num_heads: Number of attention heads passed to every block.
        dropout_rate: Dropout probability passed to every block.
    """

    def __init__(self, num_layers, embed_dim, num_heads, dropout_rate=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(TransformerBlock(embed_dim, num_heads, dropout_rate))

    def forward(self, x):
        """Feed ``x`` through each block in turn and return the final output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

class ClassificationHead(nn.Module):
    """Mean-pool a token sequence and map it to class probabilities.

    Args:
        embed_dim: Size of each incoming token embedding.
        num_classes: Number of output classes.
    """

    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.proj = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return softmax-normalised class probabilities.

        The input is averaged over dimension 1, projected linearly to
        ``num_classes`` values, and normalised with a softmax over the last
        dimension.
        """
        pooled = x.mean(dim=1)
        logits = self.proj(pooled)
        return logits.softmax(dim=-1)

In [15]:
encoder = TransformerEncoder(num_layers=3, num_heads=12, embed_dim=768)
classhead = ClassificationHead(embed_dim=768, num_classes=10)

# This cell only runs inference: switch off dropout so the result is
# deterministic for a given set of weights.
encoder.eval()
classhead.eval()

# Call the modules themselves rather than .forward() so that any registered
# hooks are run as well, and skip autograd bookkeeping since nothing is
# trained here.
with torch.no_grad():
    encoded = encoder(patch_embeddings)
    final = classhead(encoded)

In [17]:
# Display the shape of the class-probability tensor; the recorded output
# below is (1, 10), i.e. one image and ten classes.
final.shape

torch.Size([1, 10])