In [35]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (224 // patch_size) ** 2  # Assuming input image size is 224x224
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        # Input x is an image tensor of shape (batch_size, channels, height, width)
        # Extract patches from the input image tensor using conv2d
        x = self.proj(x)
        # Reshape the extracted patches into a 2D tensor
        # (batch_size, num_patches, embedding_dim)
        x = x.flatten(2).transpose(1, 2)
        # Prepend the learnable class token to the patch embeddings
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        # Add positional embeddings to the patch and class embeddings
        x = x + self.pos_embed
        return x

In [36]:
from PIL import Image
from torchvision import transforms

# Load the image and resize it to 224x224
image = Image.open('/content/dog.jpeg').convert('RGB')
resized_image = image.resize((224, 224))

# Convert the resized image to a tensor
tensor_image = transforms.ToTensor()(resized_image)

# Add a batch dimension to the tensor
batched_tensor_image = tensor_image.unsqueeze(0)

# Initialize the PatchEmbedding module with patch_size=16, embed_dim=768
patch_embedding = PatchEmbedding(patch_size=16, embed_dim=768)

# Compute patch embeddings for the input image tensor
patch_embeddings = patch_embedding.forward(batched_tensor_image)

# Print the shape of the output tensor
print(patch_embeddings.shape)

torch.Size([1, 197, 768])


In [37]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout_rate=0.1):
        super().__init__()

        # Multi-head self-attention layer
        self.self_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout_rate)

        # Feedforward neural network layer
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout_rate)
        )

        # Layer normalization
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Compute multi-head self-attention
        x = x.transpose(0, 1) # (num_patches, batch_size, embed_dim)

        # Apply residual connection and layer normalization
        x = self.norm(x + x)

        # Apply self attn
        x, _ = self.self_attention(x, x, x)
        x = x.transpose(0, 1) # (batch_size, num_patches, embed_dim)

        # Apply residual connection and layer normalization
        x = self.norm(x + x)

        # Apply feedforward neural network
        x = self.feedforward(x)

        return x

class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, embed_dim, num_heads, dropout_rate=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, dropout_rate)
            for _ in range(num_layers)
        ])
        
    def forward(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x

class ClassificationHead(nn.Module):
    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.proj = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        # Compute the mean of the patch embeddings
        x = x.mean(dim=1)
        # Project the patch embeddings to the class logits
        x = self.proj(x)
        # Apply softmax activation to get class probabilities
        x = nn.functional.softmax(x, dim=-1)
        return x

In [38]:
encoder = TransformerEncoder(num_layers=3, num_heads=12, embed_dim=768)
classhead = ClassificationHead(embed_dim=768, num_classes=10)

encoded = encoder.forward(patch_embeddings)
final = classhead.forward(encoded)

print(final.shape)

torch.Size([1, 10])


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Set the device to run the model on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the transformation to be applied to the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load the CIFAR-10 dataset and apply the transformation
train_dataset = datasets.CIFAR10('data', train=True, download=False, transform=transform)

# Set the batch size and create a data loader for the training set
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define the Vision Transformer model
class VisionTransformer(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, embed_dim=768, num_layers=3, num_heads=12, num_classes=10):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim)
        self.transformer = TransformerEncoder(num_layers, embed_dim, num_heads)
        self.classification_head = ClassificationHead(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding.forward(x)
        x = self.transformer.forward(x)
        x = self.classification_head.forward(x)
        return x

# Create an instance of the Vision Transformer model
model = VisionTransformer().to(device)

# Define the optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=3e-5)
criterion = nn.CrossEntropyLoss()

# Define the number of epochs to train for
num_epochs = 10

# Training loop
for epoch in range(num_epochs):
    # Set the model to training mode
    model.train()

    # Loop over the batches in the training set
    for i, (images, labels) in enumerate(train_loader):
        # Move the data to the device
        images = images.to(device)
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model.forward(images)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization step
        loss.backward()
        optimizer.step()

        # Print training status every 10 batches
        if (i + 1) % 10 == 0:
          print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

Epoch [1/10], Step [10/782], Loss: 2.2916
Epoch [1/10], Step [20/782], Loss: 2.2746
Epoch [1/10], Step [30/782], Loss: 2.2419
Epoch [1/10], Step [40/782], Loss: 2.2335
Epoch [1/10], Step [50/782], Loss: 2.2548
Epoch [1/10], Step [60/782], Loss: 2.2256
Epoch [1/10], Step [70/782], Loss: 2.2465
Epoch [1/10], Step [80/782], Loss: 2.2247
Epoch [1/10], Step [90/782], Loss: 2.2296
Epoch [1/10], Step [100/782], Loss: 2.2542
Epoch [1/10], Step [110/782], Loss: 2.2397
Epoch [1/10], Step [120/782], Loss: 2.2794
Epoch [1/10], Step [130/782], Loss: 2.1918
Epoch [1/10], Step [140/782], Loss: 2.2422
Epoch [1/10], Step [150/782], Loss: 2.2834
Epoch [1/10], Step [160/782], Loss: 2.2662
Epoch [1/10], Step [170/782], Loss: 2.2454
Epoch [1/10], Step [180/782], Loss: 2.2562
Epoch [1/10], Step [190/782], Loss: 2.1463
Epoch [1/10], Step [200/782], Loss: 2.1478
Epoch [1/10], Step [210/782], Loss: 2.2507
Epoch [1/10], Step [220/782], Loss: 2.2676
Epoch [1/10], Step [230/782], Loss: 2.2475
Epoch [1/10], Step [