## Vision Transformer (ViT)

In this assignment we're going to work with Vision Transformer. We will start to build our own vit model and train it on an image classification task.
The purpose of this homework is for you to get familar with ViT and get prepared for the final project.

In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# VIT Implementation

The vision transformer can be seperated into three parts, we will implement each part and combine them in the end.

For the implementation, feel free to experiment different kinds of setup, as long as you use attention as the main computation unit and the ViT can be train to perform the image classification task present later.
You can read about the ViT implement from other libary: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding is responsible for dividing the input image into non-overlapping patches and projecting them into a specified embedding dimension. It uses a 2D convolution layer with a kernel size and stride equal to the patch size. The output is a sequence of linear embeddings for each patch.

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
      super().__init__()
      self.embed_dim = embed_dim
      self.patch_layer = torch.nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                                   kernel_size=patch_size, stride=patch_size)
      num_patches = (image_size // patch_size) ** 2
      self.position = nn.Parameter(torch.randn(num_patches+ 1, embed_dim))

    def forward(self, x):
      x = self.patch_layer(x)
      x = x.flatten(2).transpose(1, 2)

      batch_size, _, _ = x.size()
      cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
      cls_token = cls_token.expand(batch_size, -1, -1).to(device)

      x = torch.cat([cls_token, x], dim = 1)
      x += self.position

      return x

## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, which is a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. This allows the model to capture different aspects of the input at different positions. The attention outputs are concatenated and linearly transformed back to the original embedding size.

In [None]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
      super().__init__()
      self.embed_dim = embed_dim
      self.num_heads = num_heads

      self.head_dim = embed_dim // num_heads

      self.Wq = torch.nn.Linear(embed_dim, embed_dim)
      self.Wk = torch.nn.Linear(embed_dim, embed_dim)
      self.Wv = torch.nn.Linear(embed_dim, embed_dim)
      self.Wo = torch.nn.Linear(embed_dim, embed_dim)


    def forward(self, x):
      batch_size, num_patches, embed_dim = x.shape

      queries = self.Wq(x)
      queries = queries.view(batch_size, num_patches, self.num_heads,
                             self.head_dim).permute(0, 2, 3, 1)

      keys = self.Wk(x)
      keys = keys.view(batch_size, num_patches, self.num_heads,
                       self.head_dim).permute(0, 2, 3, 1)

      values = self.Wv(x)
      values = values.view(batch_size, num_patches, self.num_heads,
                           self.head_dim).permute(0, 2, 3, 1)

      attn_scores = (queries @ keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
      attn_weights = F.softmax(attn_scores, dim=-1)

      output = attn_weights @ values
      output = output.permute(0, 3, 1, 2)
      output = output.contiguous().view(batch_size, num_patches, embed_dim)
      output = self.Wo(output)

      return output

## TransformerBlock
This class represents a single transformer layer. It includes a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP). Each sublayer is surrounded by residual connections.
You may also want to use layer normalization or other type of normalization.

In [None]:
class MLP(nn.Module):
    def __init__(self, embed_dim, mlp_dim, dropout):
        super().__init__()
        self.layer1 = nn.Linear(embed_dim, mlp_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(mlp_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layer1(x)
        x = self.gelu(x)
        x = self.layer2(x)
        x = self.dropout(x)
        return x

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
      super().__init__()
      self.attention = MultiHeadSelfAttention(embed_dim=embed_dim,
                                         num_heads=num_heads)
      self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
      self.mlp = MLP(embed_dim, mlp_dim, dropout=dropout)
      self.layer_norm2 = torch.nn.LayerNorm(embed_dim)

    def forward(self, x):
      attention = self.attention(self.layer_norm1(x))
      x = x + attention

      mlp_out = self.mlp(self.layer_norm2(x))
      x = x + mlp_out

      return x

## VisionTransformer:
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer to create patch embeddings from the input image. A special class token is added to the sequence, and positional embeddings are added to both the patch and class tokens. The sequence of patch embeddings is then passed through multiple TransformerBlock layers. The final output is the logits for all classes

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
      super().__init__()
      self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
      self.transformers = nn.ModuleList([
          TransformerBlock(embed_dim, num_heads, mlp_dim, dropout)
          for _ in range(num_layers)
      ])
      self.classifiers = torch.nn.Linear(embed_dim, num_classes)

    def forward(self, x):
      embedded = self.patch_embedding(x)
      for transformer in self.transformers:
        embedded = transformer(embedded)

      encoded = embedded.mean(dim=1)
      classified = self.classifiers(encoded)
      return classified

## Let's train the ViT!

We will train the vit to do the image classification with cifar100. Free free to change the optimizer and or add other tricks to improve the training

In [35]:
# Example usage:
image_size = 32 # TODO
patch_size = 4 # TODO
in_channels = 3 # TODO
embed_dim = 48 # TODO
num_heads = 4 # TODO
mlp_dim = 192 # TODO
num_layers = 4 # TODO
num_classes = 100 # TODO
dropout = 0.1 # TODO

batch_size = 128 # TODO

In [36]:
model = VisionTransformer(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout).to(device)
input_tensor = torch.randn(1, in_channels, image_size, image_size).to(device)
output = model(input_tensor)
print(output.shape)

torch.Size([1, 100])


In [37]:
# Load the CIFAR-100 dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [38]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)

In [39]:
# Train the model
from tqdm import tqdm

num_epochs = 150 # TODO
best_val_acc = 0
for epoch in range(num_epochs):
    model.train()
    for i, data in enumerate(tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # TODO Feel free to modify the training loop youself.

    # Validate the model
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc = 100 * correct / total
    print(f"Epoch: {epoch + 1}, Validation Accuracy: {val_acc:.2f}%")

    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")

Epoch 1/150: 100%|██████████| 391/391 [00:12<00:00, 31.99it/s]


Epoch: 1, Validation Accuracy: 12.79%


Epoch 2/150: 100%|██████████| 391/391 [00:12<00:00, 31.71it/s]


Epoch: 2, Validation Accuracy: 15.79%


Epoch 3/150: 100%|██████████| 391/391 [00:12<00:00, 32.52it/s]


Epoch: 3, Validation Accuracy: 21.13%


Epoch 4/150: 100%|██████████| 391/391 [00:12<00:00, 30.45it/s]


Epoch: 4, Validation Accuracy: 24.48%


Epoch 5/150: 100%|██████████| 391/391 [00:12<00:00, 30.70it/s]


Epoch: 5, Validation Accuracy: 27.66%


Epoch 6/150: 100%|██████████| 391/391 [00:12<00:00, 30.22it/s]


Epoch: 6, Validation Accuracy: 27.60%


Epoch 7/150: 100%|██████████| 391/391 [00:11<00:00, 32.67it/s]


Epoch: 7, Validation Accuracy: 31.23%


Epoch 8/150: 100%|██████████| 391/391 [00:12<00:00, 30.96it/s]


Epoch: 8, Validation Accuracy: 32.26%


Epoch 9/150: 100%|██████████| 391/391 [00:12<00:00, 31.10it/s]


Epoch: 9, Validation Accuracy: 33.32%


Epoch 10/150: 100%|██████████| 391/391 [00:12<00:00, 31.89it/s]


Epoch: 10, Validation Accuracy: 33.94%


Epoch 11/150: 100%|██████████| 391/391 [00:12<00:00, 31.01it/s]


Epoch: 11, Validation Accuracy: 35.14%


Epoch 12/150: 100%|██████████| 391/391 [00:12<00:00, 30.73it/s]


Epoch: 12, Validation Accuracy: 35.03%


Epoch 13/150: 100%|██████████| 391/391 [00:12<00:00, 30.71it/s]


Epoch: 13, Validation Accuracy: 35.75%


Epoch 14/150: 100%|██████████| 391/391 [00:13<00:00, 29.11it/s]


Epoch: 14, Validation Accuracy: 37.16%


Epoch 15/150: 100%|██████████| 391/391 [00:12<00:00, 30.43it/s]


Epoch: 15, Validation Accuracy: 37.39%


Epoch 16/150: 100%|██████████| 391/391 [00:13<00:00, 29.83it/s]


Epoch: 16, Validation Accuracy: 36.47%


Epoch 17/150: 100%|██████████| 391/391 [00:12<00:00, 31.81it/s]


Epoch: 17, Validation Accuracy: 35.84%


Epoch 18/150: 100%|██████████| 391/391 [00:12<00:00, 30.23it/s]


Epoch: 18, Validation Accuracy: 38.09%


Epoch 19/150: 100%|██████████| 391/391 [00:13<00:00, 29.19it/s]


Epoch: 19, Validation Accuracy: 39.36%


Epoch 20/150: 100%|██████████| 391/391 [00:13<00:00, 29.01it/s]


Epoch: 20, Validation Accuracy: 39.41%


Epoch 21/150: 100%|██████████| 391/391 [00:13<00:00, 29.98it/s]


Epoch: 21, Validation Accuracy: 40.10%


Epoch 22/150: 100%|██████████| 391/391 [00:13<00:00, 29.37it/s]


Epoch: 22, Validation Accuracy: 40.30%


Epoch 23/150: 100%|██████████| 391/391 [00:13<00:00, 29.80it/s]


Epoch: 23, Validation Accuracy: 40.86%


Epoch 24/150: 100%|██████████| 391/391 [00:13<00:00, 29.57it/s]


Epoch: 24, Validation Accuracy: 40.35%


Epoch 25/150: 100%|██████████| 391/391 [00:13<00:00, 28.99it/s]


Epoch: 25, Validation Accuracy: 42.41%


Epoch 26/150: 100%|██████████| 391/391 [00:13<00:00, 28.56it/s]


Epoch: 26, Validation Accuracy: 41.16%


Epoch 27/150: 100%|██████████| 391/391 [00:13<00:00, 29.51it/s]


Epoch: 27, Validation Accuracy: 40.94%


Epoch 28/150: 100%|██████████| 391/391 [00:13<00:00, 28.65it/s]


Epoch: 28, Validation Accuracy: 42.77%


Epoch 29/150: 100%|██████████| 391/391 [00:13<00:00, 28.43it/s]


Epoch: 29, Validation Accuracy: 41.99%


Epoch 30/150: 100%|██████████| 391/391 [00:13<00:00, 28.81it/s]


Epoch: 30, Validation Accuracy: 42.31%


Epoch 31/150: 100%|██████████| 391/391 [00:13<00:00, 29.43it/s]


Epoch: 31, Validation Accuracy: 42.75%


Epoch 32/150: 100%|██████████| 391/391 [00:13<00:00, 28.50it/s]


Epoch: 32, Validation Accuracy: 43.71%


Epoch 33/150: 100%|██████████| 391/391 [00:13<00:00, 28.88it/s]


Epoch: 33, Validation Accuracy: 43.19%


Epoch 34/150: 100%|██████████| 391/391 [00:13<00:00, 28.26it/s]


Epoch: 34, Validation Accuracy: 44.30%


Epoch 35/150: 100%|██████████| 391/391 [00:13<00:00, 28.89it/s]


Epoch: 35, Validation Accuracy: 44.62%


Epoch 36/150: 100%|██████████| 391/391 [00:13<00:00, 29.22it/s]


Epoch: 36, Validation Accuracy: 44.10%


Epoch 37/150: 100%|██████████| 391/391 [00:13<00:00, 29.43it/s]


Epoch: 37, Validation Accuracy: 44.13%


Epoch 38/150: 100%|██████████| 391/391 [00:13<00:00, 29.53it/s]


Epoch: 38, Validation Accuracy: 43.47%


Epoch 39/150: 100%|██████████| 391/391 [00:13<00:00, 28.74it/s]


Epoch: 39, Validation Accuracy: 44.86%


Epoch 40/150: 100%|██████████| 391/391 [00:13<00:00, 29.00it/s]


Epoch: 40, Validation Accuracy: 44.61%


Epoch 41/150: 100%|██████████| 391/391 [00:13<00:00, 28.02it/s]


Epoch: 41, Validation Accuracy: 43.83%


Epoch 42/150: 100%|██████████| 391/391 [00:13<00:00, 29.07it/s]


Epoch: 42, Validation Accuracy: 44.35%


Epoch 43/150: 100%|██████████| 391/391 [00:12<00:00, 32.13it/s]


Epoch: 43, Validation Accuracy: 44.81%


Epoch 44/150: 100%|██████████| 391/391 [00:12<00:00, 31.69it/s]


Epoch: 44, Validation Accuracy: 44.87%


Epoch 45/150: 100%|██████████| 391/391 [00:12<00:00, 30.28it/s]


Epoch: 45, Validation Accuracy: 45.09%


Epoch 46/150: 100%|██████████| 391/391 [00:13<00:00, 29.97it/s]


Epoch: 46, Validation Accuracy: 45.28%


Epoch 47/150: 100%|██████████| 391/391 [00:12<00:00, 30.61it/s]


Epoch: 47, Validation Accuracy: 44.98%


Epoch 48/150: 100%|██████████| 391/391 [00:12<00:00, 31.46it/s]


Epoch: 48, Validation Accuracy: 45.42%


Epoch 49/150: 100%|██████████| 391/391 [00:12<00:00, 31.53it/s]


Epoch: 49, Validation Accuracy: 45.63%


Epoch 50/150: 100%|██████████| 391/391 [00:12<00:00, 31.13it/s]


Epoch: 50, Validation Accuracy: 45.50%


Epoch 51/150: 100%|██████████| 391/391 [00:12<00:00, 30.89it/s]


Epoch: 51, Validation Accuracy: 44.84%


Epoch 52/150: 100%|██████████| 391/391 [00:12<00:00, 31.52it/s]


Epoch: 52, Validation Accuracy: 44.67%


Epoch 53/150: 100%|██████████| 391/391 [00:13<00:00, 29.40it/s]


Epoch: 53, Validation Accuracy: 44.49%


Epoch 54/150: 100%|██████████| 391/391 [00:13<00:00, 29.27it/s]


Epoch: 54, Validation Accuracy: 45.98%


Epoch 55/150: 100%|██████████| 391/391 [00:13<00:00, 30.04it/s]


Epoch: 55, Validation Accuracy: 45.44%


Epoch 56/150: 100%|██████████| 391/391 [00:12<00:00, 31.90it/s]


Epoch: 56, Validation Accuracy: 46.78%


Epoch 57/150: 100%|██████████| 391/391 [00:13<00:00, 29.69it/s]


Epoch: 57, Validation Accuracy: 46.75%


Epoch 58/150: 100%|██████████| 391/391 [00:13<00:00, 29.98it/s]


Epoch: 58, Validation Accuracy: 47.00%


Epoch 59/150: 100%|██████████| 391/391 [00:12<00:00, 32.42it/s]


Epoch: 59, Validation Accuracy: 45.80%


Epoch 60/150: 100%|██████████| 391/391 [00:11<00:00, 32.77it/s]


Epoch: 60, Validation Accuracy: 47.43%


Epoch 61/150: 100%|██████████| 391/391 [00:12<00:00, 31.36it/s]


Epoch: 61, Validation Accuracy: 46.96%


Epoch 62/150: 100%|██████████| 391/391 [00:11<00:00, 33.71it/s]


Epoch: 62, Validation Accuracy: 45.65%


Epoch 63/150: 100%|██████████| 391/391 [00:12<00:00, 31.17it/s]


Epoch: 63, Validation Accuracy: 46.46%


Epoch 64/150: 100%|██████████| 391/391 [00:12<00:00, 31.69it/s]


Epoch: 64, Validation Accuracy: 46.62%


Epoch 65/150: 100%|██████████| 391/391 [00:12<00:00, 32.16it/s]


Epoch: 65, Validation Accuracy: 48.09%


Epoch 66/150: 100%|██████████| 391/391 [00:12<00:00, 32.49it/s]


Epoch: 66, Validation Accuracy: 46.96%


Epoch 67/150: 100%|██████████| 391/391 [00:12<00:00, 32.58it/s]


Epoch: 67, Validation Accuracy: 45.82%


Epoch 68/150: 100%|██████████| 391/391 [00:12<00:00, 31.84it/s]


Epoch: 68, Validation Accuracy: 47.60%


Epoch 69/150: 100%|██████████| 391/391 [00:12<00:00, 30.85it/s]


Epoch: 69, Validation Accuracy: 47.38%


Epoch 70/150: 100%|██████████| 391/391 [00:12<00:00, 31.15it/s]


Epoch: 70, Validation Accuracy: 48.72%


Epoch 71/150: 100%|██████████| 391/391 [00:12<00:00, 32.16it/s]


Epoch: 71, Validation Accuracy: 47.93%


Epoch 72/150: 100%|██████████| 391/391 [00:12<00:00, 31.85it/s]


Epoch: 72, Validation Accuracy: 48.09%


Epoch 73/150: 100%|██████████| 391/391 [00:12<00:00, 30.82it/s]


Epoch: 73, Validation Accuracy: 47.40%


Epoch 74/150: 100%|██████████| 391/391 [00:12<00:00, 31.16it/s]


Epoch: 74, Validation Accuracy: 45.98%


Epoch 75/150: 100%|██████████| 391/391 [00:13<00:00, 29.97it/s]


Epoch: 75, Validation Accuracy: 48.38%


Epoch 76/150: 100%|██████████| 391/391 [00:12<00:00, 31.03it/s]


Epoch: 76, Validation Accuracy: 46.15%


Epoch 77/150: 100%|██████████| 391/391 [00:12<00:00, 30.78it/s]


Epoch: 77, Validation Accuracy: 46.78%


Epoch 78/150: 100%|██████████| 391/391 [00:12<00:00, 31.16it/s]


Epoch: 78, Validation Accuracy: 46.37%


Epoch 79/150: 100%|██████████| 391/391 [00:12<00:00, 30.35it/s]


Epoch: 79, Validation Accuracy: 48.00%


Epoch 80/150: 100%|██████████| 391/391 [00:12<00:00, 30.28it/s]


Epoch: 80, Validation Accuracy: 48.08%


Epoch 81/150: 100%|██████████| 391/391 [00:12<00:00, 30.69it/s]


Epoch: 81, Validation Accuracy: 47.66%


Epoch 82/150: 100%|██████████| 391/391 [00:12<00:00, 30.25it/s]


Epoch: 82, Validation Accuracy: 48.04%


Epoch 83/150: 100%|██████████| 391/391 [00:12<00:00, 31.63it/s]


Epoch: 83, Validation Accuracy: 46.96%


Epoch 84/150: 100%|██████████| 391/391 [00:11<00:00, 32.68it/s]


Epoch: 84, Validation Accuracy: 47.35%


Epoch 85/150: 100%|██████████| 391/391 [00:12<00:00, 30.76it/s]


Epoch: 85, Validation Accuracy: 47.02%


Epoch 86/150: 100%|██████████| 391/391 [00:12<00:00, 32.18it/s]


Epoch: 86, Validation Accuracy: 47.41%


Epoch 87/150: 100%|██████████| 391/391 [00:12<00:00, 32.52it/s]


Epoch: 87, Validation Accuracy: 48.22%


Epoch 88/150: 100%|██████████| 391/391 [00:12<00:00, 32.44it/s]


Epoch: 88, Validation Accuracy: 46.59%


Epoch 89/150: 100%|██████████| 391/391 [00:11<00:00, 32.96it/s]


Epoch: 89, Validation Accuracy: 47.97%


Epoch 90/150: 100%|██████████| 391/391 [00:12<00:00, 31.10it/s]


Epoch: 90, Validation Accuracy: 47.13%


Epoch 91/150: 100%|██████████| 391/391 [00:12<00:00, 32.10it/s]


Epoch: 91, Validation Accuracy: 48.53%


Epoch 92/150: 100%|██████████| 391/391 [00:12<00:00, 32.58it/s]


Epoch: 92, Validation Accuracy: 47.32%


Epoch 93/150: 100%|██████████| 391/391 [00:12<00:00, 31.90it/s]


Epoch: 93, Validation Accuracy: 46.74%


Epoch 94/150: 100%|██████████| 391/391 [00:11<00:00, 33.34it/s]


Epoch: 94, Validation Accuracy: 48.35%


Epoch 95/150: 100%|██████████| 391/391 [00:12<00:00, 32.37it/s]


Epoch: 95, Validation Accuracy: 49.00%


Epoch 96/150: 100%|██████████| 391/391 [00:11<00:00, 33.67it/s]


Epoch: 96, Validation Accuracy: 49.00%


Epoch 97/150: 100%|██████████| 391/391 [00:11<00:00, 32.74it/s]


Epoch: 97, Validation Accuracy: 48.48%


Epoch 98/150: 100%|██████████| 391/391 [00:11<00:00, 32.73it/s]


Epoch: 98, Validation Accuracy: 48.07%


Epoch 99/150: 100%|██████████| 391/391 [00:12<00:00, 32.36it/s]


Epoch: 99, Validation Accuracy: 47.97%


Epoch 100/150: 100%|██████████| 391/391 [00:11<00:00, 32.96it/s]


Epoch: 100, Validation Accuracy: 48.59%


Epoch 101/150: 100%|██████████| 391/391 [00:12<00:00, 31.74it/s]


Epoch: 101, Validation Accuracy: 48.48%


Epoch 102/150: 100%|██████████| 391/391 [00:12<00:00, 32.05it/s]


Epoch: 102, Validation Accuracy: 48.95%


Epoch 103/150: 100%|██████████| 391/391 [00:12<00:00, 31.73it/s]


Epoch: 103, Validation Accuracy: 48.87%


Epoch 104/150: 100%|██████████| 391/391 [00:12<00:00, 31.06it/s]


Epoch: 104, Validation Accuracy: 47.53%


Epoch 105/150: 100%|██████████| 391/391 [00:12<00:00, 32.30it/s]


Epoch: 105, Validation Accuracy: 49.46%


Epoch 106/150: 100%|██████████| 391/391 [00:11<00:00, 33.98it/s]


Epoch: 106, Validation Accuracy: 48.53%


Epoch 107/150: 100%|██████████| 391/391 [00:11<00:00, 33.24it/s]


Epoch: 107, Validation Accuracy: 48.17%


Epoch 108/150: 100%|██████████| 391/391 [00:11<00:00, 33.81it/s]


Epoch: 108, Validation Accuracy: 47.89%


Epoch 109/150: 100%|██████████| 391/391 [00:11<00:00, 33.39it/s]


Epoch: 109, Validation Accuracy: 49.36%


Epoch 110/150: 100%|██████████| 391/391 [00:12<00:00, 32.46it/s]


Epoch: 110, Validation Accuracy: 48.11%


Epoch 111/150: 100%|██████████| 391/391 [00:11<00:00, 32.75it/s]


Epoch: 111, Validation Accuracy: 49.21%


Epoch 112/150: 100%|██████████| 391/391 [00:11<00:00, 33.92it/s]


Epoch: 112, Validation Accuracy: 48.03%


Epoch 113/150: 100%|██████████| 391/391 [00:12<00:00, 31.66it/s]


Epoch: 113, Validation Accuracy: 49.53%


Epoch 114/150: 100%|██████████| 391/391 [00:11<00:00, 33.24it/s]


Epoch: 114, Validation Accuracy: 48.59%


Epoch 115/150: 100%|██████████| 391/391 [00:11<00:00, 33.17it/s]


Epoch: 115, Validation Accuracy: 49.77%


Epoch 116/150: 100%|██████████| 391/391 [00:12<00:00, 32.24it/s]


Epoch: 116, Validation Accuracy: 48.12%


Epoch 117/150: 100%|██████████| 391/391 [00:12<00:00, 32.40it/s]


Epoch: 117, Validation Accuracy: 48.09%


Epoch 118/150: 100%|██████████| 391/391 [00:12<00:00, 30.56it/s]


Epoch: 118, Validation Accuracy: 47.99%


Epoch 119/150: 100%|██████████| 391/391 [00:11<00:00, 32.89it/s]


Epoch: 119, Validation Accuracy: 49.13%


Epoch 120/150: 100%|██████████| 391/391 [00:12<00:00, 30.55it/s]


Epoch: 120, Validation Accuracy: 48.31%


Epoch 121/150: 100%|██████████| 391/391 [00:12<00:00, 32.45it/s]


Epoch: 121, Validation Accuracy: 49.23%


Epoch 122/150: 100%|██████████| 391/391 [00:11<00:00, 32.68it/s]


Epoch: 122, Validation Accuracy: 49.10%


Epoch 123/150: 100%|██████████| 391/391 [00:12<00:00, 32.46it/s]


Epoch: 123, Validation Accuracy: 50.03%


Epoch 124/150: 100%|██████████| 391/391 [00:11<00:00, 32.83it/s]


Epoch: 124, Validation Accuracy: 48.58%


Epoch 125/150: 100%|██████████| 391/391 [00:11<00:00, 32.89it/s]


Epoch: 125, Validation Accuracy: 49.47%


Epoch 126/150: 100%|██████████| 391/391 [00:11<00:00, 33.35it/s]


Epoch: 126, Validation Accuracy: 48.62%


Epoch 127/150: 100%|██████████| 391/391 [00:11<00:00, 33.90it/s]


Epoch: 127, Validation Accuracy: 48.88%


Epoch 128/150: 100%|██████████| 391/391 [00:12<00:00, 32.54it/s]


Epoch: 128, Validation Accuracy: 49.30%


Epoch 129/150: 100%|██████████| 391/391 [00:11<00:00, 33.84it/s]


Epoch: 129, Validation Accuracy: 49.61%


Epoch 130/150: 100%|██████████| 391/391 [00:11<00:00, 32.82it/s]


Epoch: 130, Validation Accuracy: 49.71%


Epoch 131/150: 100%|██████████| 391/391 [00:11<00:00, 33.34it/s]


Epoch: 131, Validation Accuracy: 48.79%


Epoch 132/150: 100%|██████████| 391/391 [00:11<00:00, 33.69it/s]


Epoch: 132, Validation Accuracy: 49.09%


Epoch 133/150: 100%|██████████| 391/391 [00:11<00:00, 33.04it/s]


Epoch: 133, Validation Accuracy: 49.19%


Epoch 134/150: 100%|██████████| 391/391 [00:11<00:00, 33.40it/s]


Epoch: 134, Validation Accuracy: 49.22%


Epoch 135/150: 100%|██████████| 391/391 [00:12<00:00, 31.70it/s]


Epoch: 135, Validation Accuracy: 48.41%


Epoch 136/150: 100%|██████████| 391/391 [00:12<00:00, 31.86it/s]


Epoch: 136, Validation Accuracy: 48.91%


Epoch 137/150: 100%|██████████| 391/391 [00:12<00:00, 31.00it/s]


Epoch: 137, Validation Accuracy: 49.66%


Epoch 138/150: 100%|██████████| 391/391 [00:12<00:00, 30.54it/s]


Epoch: 138, Validation Accuracy: 50.52%


Epoch 139/150: 100%|██████████| 391/391 [00:11<00:00, 32.67it/s]


Epoch: 139, Validation Accuracy: 49.37%


Epoch 140/150: 100%|██████████| 391/391 [00:12<00:00, 30.99it/s]


Epoch: 140, Validation Accuracy: 49.68%


Epoch 141/150: 100%|██████████| 391/391 [00:13<00:00, 29.02it/s]


Epoch: 141, Validation Accuracy: 48.72%


Epoch 142/150: 100%|██████████| 391/391 [00:13<00:00, 29.41it/s]


Epoch: 142, Validation Accuracy: 48.51%


Epoch 143/150: 100%|██████████| 391/391 [00:13<00:00, 29.64it/s]


Epoch: 143, Validation Accuracy: 49.94%


Epoch 144/150: 100%|██████████| 391/391 [00:13<00:00, 29.15it/s]


Epoch: 144, Validation Accuracy: 49.78%


Epoch 145/150: 100%|██████████| 391/391 [00:13<00:00, 29.75it/s]


Epoch: 145, Validation Accuracy: 50.57%


Epoch 146/150: 100%|██████████| 391/391 [00:13<00:00, 29.37it/s]


Epoch: 146, Validation Accuracy: 48.31%


Epoch 147/150: 100%|██████████| 391/391 [00:12<00:00, 30.13it/s]


Epoch: 147, Validation Accuracy: 48.89%


Epoch 148/150: 100%|██████████| 391/391 [00:12<00:00, 30.57it/s]


Epoch: 148, Validation Accuracy: 49.09%


Epoch 149/150: 100%|██████████| 391/391 [00:13<00:00, 30.03it/s]


Epoch: 149, Validation Accuracy: 51.31%


Epoch 150/150: 100%|██████████| 391/391 [00:13<00:00, 29.28it/s]


Epoch: 150, Validation Accuracy: 47.74%


Please submit your best_model.pth with this notebook. And report the best test results you get: Maximum accuracy is 51.31% if we utilize early-stopping.