In [69]:
import torch
import torch.nn as nn

# Very primitive example e2e

In [92]:

from os import path

# Image dimensions
pixel_size = 28 # 28x28 image
patch_size = 7  # 7x7 patches (28/7 = 4 patches per side)
num_patches = (pixel_size // patch_size) ** 2  # 4*4 = 16 patches
embed_dim = 128
print(pixel_size, patch_size, embed_dim, num_patches)

def preprocess_data():

    # Randomly generate image tensor
    image = torch.randn(1, pixel_size, pixel_size)  # (batch_size, 28, 28)

    # 1. Patch embedding
    # Reshape image to patches using unfold
    # unfold(dim, size, step) - creates patches of size x size with step stride
    patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    print(f'{patches.shape} - patches') # (batch_size, 7, 7, 4, 4)

    # Reshape to (batch_size, num_patches, patch_size*patch_size)
    patches_flat = patches.reshape(1, num_patches, patch_size * patch_size)
    print(f'{patches_flat.shape} - patches_flat') # (batch_size, 16, 49)

    # 2. Linear Projection to Embed Patches
    patch_embedding = nn.Linear(patch_size * patch_size, embed_dim)
    embedded_patches = patch_embedding(patches_flat)
    print(f'{embedded_patches.shape} - embedded_patches') # (batch, 16, 128)

    # 3. Positional Embedding
    # Add positional embeddings
    position_embeddings = nn.Parameter(torch.randn(1, num_patches, embed_dim))
    embedded_patches_with_pos = embedded_patches + position_embeddings
    print(f'{embedded_patches_with_pos.shape} - embedded_patches_with_pos') # (batch, 16, 128)

    # 4. Class Token Addition
    class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
    embedded_patches_with_class = torch.cat([class_token, embedded_patches_with_pos], dim=1)
    print(f'{embedded_patches_with_class.shape} - embedded_patches_with_class') # (batch, 17, 128)

    return embedded_patches_with_class

# Encoder
def encoder(embedded_patches_with_class,embed_dim):
    W_q = nn.Linear(embed_dim, embed_dim)
    W_k = nn.Linear(embed_dim, embed_dim)
    W_v = nn.Linear(embed_dim, embed_dim)
    W_o = nn.Linear(embed_dim, embed_dim)
    print(W_q.weight.shape, W_k.weight.shape, W_v.weight.shape, W_o.weight.shape)

    K = embedded_patches_with_class @ W_k.weight.T # (batch_size, 17, 128)
    Q = embedded_patches_with_class @ W_q.weight.T # (batch_size, 17, 128)
    V = embedded_patches_with_class @ W_v.weight.T # (batch_size, 17, 128)
    A = Q @ K.transpose(-1, -2) # (batch_size, 17, 17)
    A = A / (embed_dim ** 0.5) # scale by sqrt(d_k) to avoid large values
    A = torch.softmax(A, dim=-1)  # Convert to probabilities
    H = A @ V # (batch_size, 17, 128) # Raw attention output
    H  = W_o(H) # (batch_size, 17, 128) # Output of the attention layer

    # print(f'{K.shape} - K')
    # print(f'{Q.shape} - Q')
    # print(f'{V.shape} - V')
    # print(f'{A.shape} - A')
    # print(f'{H.shape} - H')
    
    return H

def classification_head(H):
    # Linear classifier
    Linear_classifier = nn.Linear(embed_dim, 10)

    # Extract only the class token (first token)
    class_token_output = H[:, 0, :]  # Shape: (1, 128)
    print(f"Class token shape: {class_token_output.shape}")

    # Apply classifier only to class token
    logits = Linear_classifier(class_token_output)  # Shape: (1, 10)
    print(f"Logits shape: {logits.shape}")

    # Get prediction
    probability_dist = torch.softmax(logits, dim=1)
    print(f"Probability distribution shape: {probability_dist.shape}")
    print(f"Probability distribution: {probability_dist}")
    predicted_class = torch.argmax(probability_dist, dim=1)
    print(f"Predicted class: {predicted_class}")


embedded_patches_with_class = preprocess_data()
layers = 3
H = embedded_patches_with_class
print(f'{H} - H')
for i in range(layers):
    H = encoder(H,embed_dim)

print(f'{H} - H')

classification_head(H)



28 7 128 16
torch.Size([1, 4, 4, 7, 7]) - patches
torch.Size([1, 16, 49]) - patches_flat
torch.Size([1, 16, 128]) - embedded_patches
torch.Size([1, 16, 128]) - embedded_patches_with_pos
torch.Size([1, 17, 128]) - embedded_patches_with_class
tensor([[[ 0.7384, -0.0144,  0.6405,  ...,  0.4617,  0.3223, -1.2893],
         [ 0.3264,  0.1468, -0.8783,  ..., -0.2934, -2.0359,  1.0929],
         [-2.9512, -1.8823,  0.4007,  ..., -1.7324, -0.0732,  1.3046],
         ...,
         [ 0.0313, -0.0473,  0.4422,  ...,  0.4195, -0.5974, -1.4699],
         [-0.6585,  0.4810, -0.2090,  ...,  0.5938, -0.3574, -0.7519],
         [ 1.6199, -2.7024, -2.0224,  ...,  0.6610, -0.7109, -0.8421]]],
       grad_fn=<CatBackward0>) - H
torch.Size([128, 128]) torch.Size([128, 128]) torch.Size([128, 128]) torch.Size([128, 128])
torch.Size([128, 128]) torch.Size([128, 128]) torch.Size([128, 128]) torch.Size([128, 128])
torch.Size([128, 128]) torch.Size([128, 128]) torch.Size([128, 128]) torch.Size([128, 128])
tensor

# Structuring the code into PyTorch Modules

In [41]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size # 7
        self.embed_dim = embed_dim # 128
        self.num_patches = (image_size // patch_size) ** 2 # 16
        
        # Define layers and parameters
        self.patch_embedding = nn.Linear(patch_size * patch_size, embed_dim) # 49 -> 128
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) # 17, 128
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim)) # 1, 1, 128
    
    def forward(self, x):

        # x can be (1, 1, 28, 28) or (1, 1, 40, 40) where batch_size is dynamic (batch_size, channels, image_w, image_h)
        # 1. Patch tokenization
        x = x.squeeze(1) # (1, 1, 28, 28) -> (1, 28, 28) drops the channel dimension
        patches = x.unfold(1, self.patch_size, self.patch_size).unfold(2, self.patch_size, self.patch_size) # (1, 7, 7, 4, 4)

        patches_flat = patches.reshape(x.shape[0], self.num_patches, self.patch_size * self.patch_size) # (1, 16, 49)
        # print(f'{patches_flat.shape} - patches_flat')
        
        # 2. Linear projection (use class attribute)
        embedded_patches = self.patch_embedding(patches_flat) # (1, 1, 16, 128)
        
        # 3. Add positional embeddings (use class attribute)
        embedded_patches_with_pos = embedded_patches + self.position_embedding[:, 1:, :]  # Skip class token position # (1, 1, 17, 128)
        
        # 4. Add class token (use class attribute)
        embedded_patches_with_class = torch.cat([self.class_token.expand(x.shape[0], -1, -1), embedded_patches_with_pos], dim=1) # (1, 1, 17, 128)
        
        return embedded_patches_with_class


image_size = 28
patch_size = 7
embed_dim = 128
batch_size = 1

# Create the module once
patch_embedding = PatchEmbedding(image_size=image_size, patch_size=patch_size, embed_dim=embed_dim)

# Use it multiple times with the same learned weights
image1 = torch.randn(batch_size, image_size, image_size)
image2 = torch.randn(batch_size, image_size, image_size)


patch_embedded_image1 = patch_embedding(image1)  # Uses learned weights
patch_embedded_image1 = patch_embedding(image2)  # Uses same learned weights

print(f'{patch_embedded_image1.shape} - patch_embedded_image1')
print(f'{patch_embedded_image1.shape} - patch_embedded_image1')

torch.Size([1, 17, 128]) - patch_embedded_image1
torch.Size([1, 17, 128]) - patch_embedded_image1


In [42]:
class Encoder(nn.Module):
    def __init__(self, embed_dim):
        super(Encoder, self).__init__()
        self.embed_dim = embed_dim
        
        # Define the linear layers (same as your function)
        self.W_q = nn.Linear(embed_dim, embed_dim) 
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)
        self.mlp_up = nn.Linear(embed_dim, 4*embed_dim) # optional
        self.mlp_down = nn.Linear(4*embed_dim, embed_dim) # optional

    def forward(self, embedded_patches_with_class):
        # Print weight shapes (same as your function)
        # print(self.W_q.weight.shape, self.W_k.weight.shape, self.W_v.weight.shape, self.W_o.weight.shape)

        # Compute Q, K, V (exactly the same logic)
        K = embedded_patches_with_class @ self.W_k.weight.T  # (batch_size, 17, 128)
        Q = embedded_patches_with_class @ self.W_q.weight.T  # (batch_size, 17, 128)
        V = embedded_patches_with_class @ self.W_v.weight.T  # (batch_size, 17, 128)
        
        # Compute attention scores (exactly the same logic)
        A = Q @ K.transpose(-1, -2)  # (batch_size, 17, 17)
        A = A / (self.embed_dim ** 0.5)  # scale by sqrt(d_k) to avoid large values
        A = torch.softmax(A, dim=-1)  # Convert to probabilities
        
        # Apply attention (exactly the same logic)
        H = A @ V  # (batch_size, 17, 128) # Raw attention output
        H = self.W_o(H)  # (batch_size, 17, 128) # Output of the attention layer

        # MLP - optional
        H_residual = self.mlp_up(H)
        H_residual = torch.relu(H_residual)
        H = H + self.mlp_down(H_residual)

        # Uncomment these if you want the same debug prints
        # print(f'{K.shape} - K')
        # print(f'{Q.shape} - Q')
        # print(f'{V.shape} - V')
        # print(f'{A.shape} - A')
        # print(f'{H.shape} - H')
        
        return H
    

# Create encoder once
encoder = Encoder(embed_dim=128)

layers = 3
# Use it multiple times (same weights)
H = patch_embedded_image1
for i in range(layers):
    H = encoder(H)  # Uses the same learned weights each time

print(f'{H} - H')

tensor([[[ 0.0043,  0.0032, -0.0449,  ..., -0.0895,  0.0277, -0.0102],
         [ 0.0043,  0.0032, -0.0449,  ..., -0.0895,  0.0277, -0.0102],
         [ 0.0043,  0.0032, -0.0449,  ..., -0.0895,  0.0277, -0.0102],
         ...,
         [ 0.0043,  0.0032, -0.0449,  ..., -0.0895,  0.0277, -0.0102],
         [ 0.0043,  0.0032, -0.0449,  ..., -0.0895,  0.0277, -0.0102],
         [ 0.0043,  0.0032, -0.0449,  ..., -0.0895,  0.0277, -0.0102]]],
       grad_fn=<AddBackward0>) - H


In [46]:
class ClassificationHead(nn.Module):
    def __init__(self, embed_dim, num_classes):
        super(ClassificationHead, self).__init__()
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        
        # Define the linear classifier (same as your function)
        self.Linear_classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, H):
        # Extract only the class token (first token) - same logic
        class_token_output = H[:, 0, :]  # Shape: (batch_size, 128)
        # print(f"Class token shape: {class_token_output.shape}")

        # Apply classifier only to class token - same logic
        logits = self.Linear_classifier(class_token_output)  # Shape: (batch_size, 10)
        # print(f"Logits shape: {logits.shape}")

        # Get prediction - same logic
        probability_dist = torch.softmax(logits, dim=1)
        # print(f"Probability distribution shape: {probability_dist.shape}")
        # print(f"Probability distribution: {probability_dist}")
        # predicted_class = torch.argmax(probability_dist, dim=1)
        # print(f"Predicted class: {predicted_class}")
        
        return probability_dist

# Create classification head
classification_head = ClassificationHead(embed_dim=128, num_classes=10)

# Use it (same as your function)
predicted_class = torch.argmax(classification_head(H), dim=1)  # Uses the same learned weights
print(f'{predicted_class} - predicted_class')

tensor([9]) - predicted_class


In [47]:

class Transformer(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim, num_layers, num_classes):
        super(Transformer, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size=image_size, patch_size=patch_size, embed_dim=embed_dim)
        
        # Create multiple encoder layers
        self.encoder_layers = nn.ModuleList([
            Encoder(embed_dim=embed_dim) for _ in range(num_layers)
        ])
        
        self.classification_head = ClassificationHead(embed_dim=embed_dim, num_classes=num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        
        # Pass through each encoder layer
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
        
        x = self.classification_head(x)
        return x

image_size = 28
patch_size = 7
embed_dim = 128
batch_size = 5
num_encoder_layers = 3
num_classes = 10

# Create the complete model with 3 layers
model = Transformer(
    image_size=image_size, 
    patch_size=patch_size, 
    embed_dim=embed_dim, 
    num_layers=num_encoder_layers, 
    num_classes=num_classes
)

# Single forward pass handles all layers
images = torch.randn(batch_size, image_size, image_size)  # batch of 4 images
predictions = model(images)  # Automatically goes through all 3 layers
print(f'{predictions} - predictions')

tensor([[0.0946, 0.1023, 0.1049, 0.0923, 0.0909, 0.0905, 0.1124, 0.1001, 0.1096,
         0.1024],
        [0.0945, 0.1022, 0.1044, 0.0919, 0.0910, 0.0906, 0.1126, 0.1001, 0.1101,
         0.1025],
        [0.0945, 0.1023, 0.1045, 0.0923, 0.0908, 0.0907, 0.1123, 0.1002, 0.1097,
         0.1028],
        [0.0948, 0.1019, 0.1047, 0.0921, 0.0910, 0.0909, 0.1120, 0.1001, 0.1100,
         0.1026],
        [0.0942, 0.1021, 0.1044, 0.0920, 0.0905, 0.0908, 0.1122, 0.1004, 0.1107,
         0.1027]], grad_fn=<SoftmaxBackward0>) - predictions


In [48]:
# training a ViT model
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import os

# 1. Load the data
# 2. Define the model
# 3. Define the loss function
# 4. Define the optimizer
# 5. Train the model
# 6. Evaluate the model

# 1. Load the data

def visualize_image(dataset):
    import matplotlib.pyplot as plt

    # Get a sample image
    image, label = dataset[0]

    # Convert to numpy and remove channel dimension
    image_np = image.squeeze().numpy()  # Shape: (28, 28)

    print(f"Image shape: {image_np.shape}")
    print(f"Label: {label}")
    print(f"Pixel values (first 5x5):\n{image_np[:5, :5]}")

    # Visualize the image
    plt.imshow(image_np, cmap='gray')
    plt.title(f'Digit: {label}')
    plt.show()

def get_data_loaders(batch_size=64, data_dir=None):
    """
    Create data loaders for training and testing MNIST dataset.
    
    Args:
        batch_size (int): Batch size for training and testing
        data_dir (str): Directory to store the dataset
        
    Returns:
        tuple: (train_loader, test_loader)
    """
    if data_dir is None:
        data_dir = os.path.join("./data")
    
    # Define transformations
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
    ])
    
    # Create data directory if it doesn't exist
    os.makedirs(data_dir, exist_ok=True)
    
    # Download and load the training data
    train_dataset = datasets.MNIST(
        root=data_dir, 
        train=True, 
        download=True, 
        transform=transform
    )
    
    # Download and load the test data
    test_dataset = datasets.MNIST(
        root=data_dir, 
        train=False, 
        download=True, 
        transform=transform
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=4
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=4
    )
    
    return train_loader, test_loader


train_loader, test_loader = get_data_loaders(batch_size=64, data_dir=None)
print(f'{train_loader} - train_loader')
print(f'{test_loader} - test_loader')

<torch.utils.data.dataloader.DataLoader object at 0x12fa14e80> - train_loader
<torch.utils.data.dataloader.DataLoader object at 0x128353d00> - test_loader


In [49]:
# 2. Define the model
image_size = 28
patch_size = 7
embed_dim = 128
batch_size = 5
num_encoder_layers = 3
num_classes = 10

# Create the complete model with 3 layers
model = Transformer(
    image_size=image_size, 
    patch_size=patch_size, 
    embed_dim=embed_dim, 
    num_layers=num_encoder_layers, 
    num_classes=num_classes
)

In [50]:
# 3. Define the loss function and optimizer
import torch.nn as nn
import torch.optim as optim

lr = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [51]:
# 4. Train the model
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os

def train_model(model, train_loader, test_loader, device, epochs=10, learning_rate=0.001, 
                checkpoint_dir="./checkpoints"):
    """
    Train the MNIST model.
    
    Args:
        model (nn.Module): The neural network model
        train_loader (DataLoader): DataLoader for training data
        test_loader (DataLoader): DataLoader for test data
        device (torch.device): Device to train on (CPU or GPU)
        epochs (int): Number of training epochs
        learning_rate (float): Learning rate for the optimizer
        checkpoint_dir (str): Directory to save model checkpoints
        
    Returns:
        model: Trained model
    """
    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Training loop
    best_accuracy = 0.0
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for batch_idx, (data, target) in enumerate(progress_bar):
            # Move data to device
            data, target = data.to(device), target.to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            output = model(data)
            
            # Calculate loss
            loss = criterion(output, target)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Update running loss
            running_loss += loss.item()
            
            # Update progress bar
            progress_bar.set_postfix({"loss": running_loss / (batch_idx + 1)})
        
        # Evaluation phase
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            progress_bar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} [Test]")
            for data, target in progress_bar:
                # Move data to device
                data, target = data.to(device), target.to(device)
                
                # Forward pass
                output = model(data)
                
                # Get predictions
                _, predicted = torch.max(output.data, 1)
                
                # Update counters
                total += target.size(0)
                correct += (predicted == target).sum().item()
                
                # Update progress bar
                accuracy = 100 * correct / total
                progress_bar.set_postfix({"accuracy": accuracy})
        
        # Calculate accuracy
        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs} - Test Accuracy: {accuracy:.2f}%")
        
        # Save model if it's the best so far
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            checkpoint_path = os.path.join(checkpoint_dir, "mnist_model.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, checkpoint_path)
            print(f"New best model saved with accuracy: {accuracy:.2f}%")
        
        # Save checkpoint for every epoch
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'accuracy': accuracy,
        }, checkpoint_path)
    
    print(f"Training completed. Best accuracy: {best_accuracy:.2f}%")
    return model


# Check for CUDA availability
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_model(model, train_loader, test_loader, device, epochs=10, learning_rate=0.001)

Epoch 1/10 [Train]: 100%|██████████| 938/938 [00:18<00:00, 50.37it/s, loss=2.32]
Epoch 1/10 [Test]: 100%|██████████| 157/157 [00:02<00:00, 54.21it/s, accuracy=12.4] 


Epoch 1/10 - Test Accuracy: 12.37%
New best model saved with accuracy: 12.37%


Epoch 2/10 [Train]: 100%|██████████| 938/938 [00:18<00:00, 50.47it/s, loss=2.34]
Epoch 2/10 [Test]: 100%|██████████| 157/157 [00:02<00:00, 53.04it/s, accuracy=12.4] 


Epoch 2/10 - Test Accuracy: 12.37%


Epoch 3/10 [Train]: 100%|██████████| 938/938 [00:18<00:00, 50.38it/s, loss=2.34]
Epoch 3/10 [Test]: 100%|██████████| 157/157 [00:02<00:00, 53.80it/s, accuracy=12.4] 


Epoch 3/10 - Test Accuracy: 12.37%


Epoch 4/10 [Train]: 100%|██████████| 938/938 [00:19<00:00, 48.07it/s, loss=2.34]
Epoch 4/10 [Test]: 100%|██████████| 157/157 [00:03<00:00, 44.82it/s, accuracy=12.4] 


Epoch 4/10 - Test Accuracy: 12.37%


Epoch 5/10 [Train]: 100%|██████████| 938/938 [00:19<00:00, 48.94it/s, loss=2.34]
Epoch 5/10 [Test]: 100%|██████████| 157/157 [00:02<00:00, 53.59it/s, accuracy=12.4] 


Epoch 5/10 - Test Accuracy: 12.37%


Epoch 6/10 [Train]: 100%|██████████| 938/938 [00:19<00:00, 48.74it/s, loss=2.34]
Epoch 6/10 [Test]: 100%|██████████| 157/157 [00:03<00:00, 50.79it/s, accuracy=12.4] 


Epoch 6/10 - Test Accuracy: 12.37%


Epoch 7/10 [Train]: 100%|██████████| 938/938 [00:19<00:00, 49.15it/s, loss=2.34]
Epoch 7/10 [Test]: 100%|██████████| 157/157 [00:03<00:00, 50.42it/s, accuracy=12.4] 


Epoch 7/10 - Test Accuracy: 12.37%


Epoch 8/10 [Train]: 100%|██████████| 938/938 [00:18<00:00, 49.80it/s, loss=2.34]
Epoch 8/10 [Test]: 100%|██████████| 157/157 [00:02<00:00, 52.88it/s, accuracy=12.4] 


Epoch 8/10 - Test Accuracy: 12.37%


Epoch 9/10 [Train]: 100%|██████████| 938/938 [00:19<00:00, 49.24it/s, loss=2.34]
Epoch 9/10 [Test]: 100%|██████████| 157/157 [00:03<00:00, 51.85it/s, accuracy=12.4] 


Epoch 9/10 - Test Accuracy: 12.37%


Epoch 10/10 [Train]: 100%|██████████| 938/938 [00:19<00:00, 48.95it/s, loss=2.34]
Epoch 10/10 [Test]: 100%|██████████| 157/157 [00:02<00:00, 52.65it/s, accuracy=12.4] 

Epoch 10/10 - Test Accuracy: 12.37%
Training completed. Best accuracy: 12.37%





Transformer(
  (patch_embedding): PatchEmbedding(
    (patch_embedding): Linear(in_features=49, out_features=128, bias=True)
  )
  (encoder_layers): ModuleList(
    (0-2): 3 x Encoder(
      (W_q): Linear(in_features=128, out_features=128, bias=True)
      (W_k): Linear(in_features=128, out_features=128, bias=True)
      (W_v): Linear(in_features=128, out_features=128, bias=True)
      (W_o): Linear(in_features=128, out_features=128, bias=True)
      (mlp_up): Linear(in_features=128, out_features=512, bias=True)
      (mlp_down): Linear(in_features=512, out_features=128, bias=True)
    )
  )
  (classification_head): ClassificationHead(
    (Linear_classifier): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [None]:
# 5. Evaluate the model

# 6. Save the model