## Model
Definition and demonstration

In [10]:
import torch
import pandas as pd
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import random
import timeit
from tqdm import tqdm

In [11]:
# ---------------------------------------------------------------------------
# Hyper-parameters for the MNIST-sized demo model (28x28, single channel).
# ---------------------------------------------------------------------------
RANDOM_SEED = 42
BATCH_SIZE = 32
EPOCHS = 40
LEARNING_RATE = 1e-4
NUM_CLASSES = 10
PATCH_SIZE = 4
IMG_SIZE = 28
IN_CHANNELS = 1
NUM_HEADS = 8
DROPOUT = 0.001
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION = "gelu"
NUM_ENCODERS = 4

# Derived sizes. Token width equals the flattened patch volume: 1 * 4 * 4 = 16.
EMBED_DIM = IN_CHANNELS * PATCH_SIZE * PATCH_SIZE
# Non-overlapping patches per image: 7 * 7 = 49.
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) * (IMG_SIZE // PATCH_SIZE)

# Seed every RNG in play (Python, NumPy, PyTorch CPU and CUDA) for reproducible runs.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
# Request deterministic cuDNN algorithms and turn off cuDNN autotuning.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Fixed the patch embedding so it works for images with more than one channel (e.g. RGB): the flattened patch length is `patch_size**2 * in_channels`.

In [12]:
# Patch embedding that works for any number of input channels
# (flattened patch length = patch_size**2 * in_channels).
from einops import rearrange
from einops.layers.torch import Rearrange


class PatchEmbedding(nn.Module):
    """Turn a batch of images into a token sequence with a learnable CLS token in front.

    Each non-overlapping ``patch_size x patch_size`` patch is flattened (channels last),
    then passed through LayerNorm -> Linear(to ``embed_dim``) -> LayerNorm. A learnable CLS
    token is prepended, learnable position embeddings are added, and dropout is applied.

    Args:
        embed_dim: width of every output token.
        patch_size: side length of a square patch, in pixels.
        num_patches: patches per image. Sets the position-embedding length
            (``num_patches + 1`` including the CLS slot). It must equal the real patch count
            of the inputs, or the addition in ``forward`` will not broadcast.
        dropout: dropout probability applied to the final token sequence.
        in_channels: number of image channels.
    """

    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()
        # Credit: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py
        flat_patch_len = in_channels * patch_size * patch_size
        to_patches = Rearrange(
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size
        )
        self.patcher = nn.Sequential(
            to_patches,
            nn.LayerNorm(flat_patch_len),
            nn.Linear(flat_patch_len, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # Learnable tokens. nn.Parameter already defaults to requires_grad=True.
        # Keep this creation order (patcher, cls_token, position_embeddings): it fixes how
        # the seeded RNG is consumed during initialisation.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map images laid out as (B, C, H, W) to tokens of shape (B, h*w + 1, embed_dim)."""
        batch_size = x.shape[0]
        cls = self.cls_token.expand(batch_size, -1, -1)  # (1, 1, D) -> (B, 1, D)
        patches = self.patcher(x)                        # (B, h*w, D)
        tokens = torch.cat([cls, patches], dim=1)        # (B, h*w + 1, D), CLS first
        tokens = tokens + self.position_embeddings       # (1, N + 1, D) broadcasts over batch
        return self.dropout(tokens)


In [13]:
# Shape check: images (512, IN_CHANNELS, IMG_SIZE, IMG_SIZE) should become
# tokens (512, NUM_PATCHES + 1, EMBED_DIM); the extra token is the CLS token.
# 512 is an arbitrary demo batch size, not BATCH_SIZE.
model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)
x = torch.randn(512, IN_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    print(model(x).shape)

torch.Size([512, 50, 16])


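Changed the architecture for a stronger model: a pre-norm Transformer encoder (`norm_first=True`, feed-forward width `4 * embed_dim`), followed by a LayerNorm + linear classification head applied to the CLS token only.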

In [14]:
class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding -> pre-norm Transformer encoder -> CLS head.

    Args:
        num_patches: patches per image; must equal ``(img_size // patch_size) ** 2``.
        img_size: side length of the square input images (used to validate ``num_patches``).
        num_classes: number of output logits.
        patch_size: side length of a square patch, in pixels.
        embed_dim: token width; must be divisible by ``num_heads``.
        num_encoders: number of stacked Transformer encoder layers.
        num_heads: attention heads per encoder layer.
        dropout: dropout probability used in the embedding and encoder layers.
        activation: feed-forward activation name passed to ``nn.TransformerEncoderLayer``.
        in_channels: number of image channels.

    Raises:
        ValueError: if ``img_size`` is not divisible by ``patch_size`` or ``num_patches``
            does not match the patch grid implied by ``img_size`` and ``patch_size``.
    """

    def __init__(self, num_patches, img_size, num_classes, patch_size, embed_dim, num_encoders, num_heads, dropout, activation, in_channels):
        super().__init__()
        # Validate the geometry here, before any submodule is built. Otherwise a mismatch
        # only surfaces in forward(), as an opaque shape error when the position embeddings
        # are added.
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        expected_patches = (img_size // patch_size) ** 2
        if num_patches != expected_patches:
            raise ValueError(
                f"num_patches ({num_patches}) does not match (img_size // patch_size) ** 2 "
                f"= {expected_patches}"
            )

        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        # The nested-tensor fast path is unavailable with norm_first=True, and PyTorch warns
        # when it is requested. Disable it explicitly; outputs are unchanged.
        self.encoder_blocks = nn.TransformerEncoder(
            encoder_layer, num_layers=num_encoders, enable_nested_tensor=False
        )

        # Pre-norm encoder output is not normalised, so the head starts with a LayerNorm.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x):
        """Return logits of shape (B, num_classes) for a batch of images ``x``."""
        tokens = self.embeddings_block(x)
        tokens = self.encoder_blocks(tokens)
        return self.mlp_head(tokens[:, 0, :])  # classify from the CLS token only

In [15]:
# Shape check: the full ViT maps images (512, IN_CHANNELS, IMG_SIZE, IMG_SIZE)
# to logits (512, NUM_CLASSES). 512 is an arbitrary demo batch size, not BATCH_SIZE.
model = ViT(NUM_PATCHES, IMG_SIZE, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)
x = torch.randn(512, IN_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    print(model(x).shape)  # (512, NUM_CLASSES)

torch.Size([512, 10])




## Train model on a downstream task
Vì tập MNIST quá đơn giản nên em lấy tập CIFAR-100. (MNIST is too easy, so CIFAR-100 is used instead.)

In [19]:
# CIFAR-100 configuration. Overrides the MNIST demo settings and builds the model to train.
RANDOM_SEED = 42
BATCH_SIZE = 512
EPOCHS = 40
LEARNING_RATE = 1e-4
NUM_CLASSES = 100  # CIFAR-100 has 100 classes
PATCH_SIZE = 4
IMG_SIZE = 32  # CIFAR-100 images are 32x32
IN_CHANNELS = 3  # CIFAR-100 images have 3 color channels
NUM_HEADS = 12
DROPOUT = 0.1
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION = "gelu"
NUM_ENCODERS = 12
# Token width is decoupled from the flattened patch length (PATCH_SIZE**2 * IN_CHANNELS = 48).
# The patch embedding's Linear layer projects 48 -> EMBED_DIM.
EMBED_DIM = 360
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2  # 64

# Fail fast on inconsistent settings instead of deep inside model construction or forward.
assert IMG_SIZE % PATCH_SIZE == 0, "IMG_SIZE must be divisible by PATCH_SIZE"
assert EMBED_DIM % NUM_HEADS == 0, "EMBED_DIM must be divisible by NUM_HEADS"

# Re-seed so model initialisation and everything after it is reproducible from this cell.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model = ViT(NUM_PATCHES, IMG_SIZE, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)

# Smoke test: one random batch shaped like real inputs, driven by the config above.
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)
with torch.no_grad():  # avoid holding activations of a 12-layer pass for a shape check
    print(model(x).shape)  # (BATCH_SIZE, NUM_CLASSES)

torch.Size([512, 100])


In [20]:
# Shape check for the CIFAR-100 patch embedding: expect (512, NUM_PATCHES + 1, EMBED_DIM).
# NOTE(review): the saved output shows 144 in the last dimension, which does not match the
# EMBED_DIM of 360 set for CIFAR-100. It was likely produced by an earlier config; re-run
# this cell to refresh it.
patcher = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)
x = torch.randn(512, IN_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    print(patcher(x).shape)

torch.Size([512, 65, 144])


In [21]:
# Homework (BTVN): train the ViT. MNIST is too easy, so CIFAR-100 is used instead.
import torch
from torchvision.datasets import CIFAR100
from torch.utils.data import Subset

# Load the CIFAR-100 train and test splits. ToTensor runs per sample when an item is accessed.
full_train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Hold out 20% of the training images for validation. Split *indices*, not the dataset.
# Passing the Dataset object makes train_test_split index every sample up front, which
# decodes and converts all images into an in-memory Python list. Splitting np.arange with the
# same random_state selects the same samples in the same order, and Subset keeps loading lazy.
train_indices, val_indices = train_test_split(
    np.arange(len(full_train_dataset)), test_size=0.2, random_state=RANDOM_SEED
)
train_dataset = Subset(full_train_dataset, train_indices)
val_dataset = Subset(full_train_dataset, val_indices)

# Create data loaders; only the training set is reshuffled each epoch.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# prompt: train model on train_loader then evaluate on test_loader. using visualization tools like tensorboard or wandb

%pip install wandb -qqq
import wandb

FLAG = 1

# # Initialize wandb
if FLAG:
  wandb.login()
  wandb.init(project="ViT-CIFAR100", entity="letangphuquy4-vietnam-korea-university-of-information-an")  # Replace "your_wandb_username" with your actual wandb username
  wandb.config.update({"learning_rate": LEARNING_RATE, "epochs": EPOCHS, "batch_size": BATCH_SIZE,
                      "num_classes": NUM_CLASSES, "patch_size": PATCH_SIZE, "img_size": IMG_SIZE,
                      "in_channels": IN_CHANNELS, "num_heads": NUM_HEADS, "dropout": DROPOUT,
                      "adam_weight_decay": ADAM_WEIGHT_DECAY, "adam_betas": ADAM_BETAS,
                      "activation": ACTIVATION, "num_encoders": NUM_ENCODERS, "embed_dim": EMBED_DIM,
                      "num_patches": NUM_PATCHES})

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS, weight_decay=ADAM_WEIGHT_DECAY)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Training")):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        train_total += target.size(0)
        train_correct += (predicted == target).sum().item()

    train_accuracy = 100 * train_correct / train_total
    train_loss /= len(train_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            val_total += target.size(0)
            val_correct += (predicted == target).sum().item()

    val_accuracy = 100 * val_correct / val_total
    val_loss /= len(val_loader)

    # Log metrics to wandb
    if FLAG:
      wandb.log({"train_loss": train_loss, "train_accuracy": train_accuracy,
                "val_loss": val_loss, "val_accuracy": val_accuracy})

    print(f'Epoch [{epoch+1}/{EPOCHS}], '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

# Evaluation on test set
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        loss = criterion(outputs, target)
        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        test_total += target.size(0)
        test_correct += (predicted == target).sum().item()

test_accuracy = 100 * test_correct / test_total
test_loss /= len(test_loader)

print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')

if FLAG:
  # Log test metrics to wandb
  wandb.log({"test_loss": test_loss, "test_accuracy": test_accuracy})
  # Finish the wandb run
  wandb.finish()

Epoch 1/40 - Training: 100%|██████████| 1250/1250 [01:27<00:00, 14.36it/s]


Epoch [1/40], Train Loss: 4.3339, Train Acc: 3.97%, Val Loss: 3.9881, Val Acc: 8.47%


Epoch 2/40 - Training: 100%|██████████| 1250/1250 [01:29<00:00, 14.03it/s]


Epoch [2/40], Train Loss: 3.8305, Train Acc: 10.71%, Val Loss: 3.6522, Val Acc: 13.54%


Epoch 3/40 - Training: 100%|██████████| 1250/1250 [01:30<00:00, 13.88it/s]


Epoch [3/40], Train Loss: 3.5328, Train Acc: 16.25%, Val Loss: 3.4093, Val Acc: 18.17%


Epoch 4/40 - Training: 100%|██████████| 1250/1250 [01:33<00:00, 13.40it/s]


Epoch [4/40], Train Loss: 3.2944, Train Acc: 20.50%, Val Loss: 3.2580, Val Acc: 20.57%


Epoch 5/40 - Training: 100%|██████████| 1250/1250 [01:30<00:00, 13.88it/s]


Epoch [5/40], Train Loss: 3.1057, Train Acc: 23.95%, Val Loss: 3.1137, Val Acc: 23.05%


Epoch 6/40 - Training: 100%|██████████| 1250/1250 [01:29<00:00, 13.93it/s]


Epoch [6/40], Train Loss: 2.9483, Train Acc: 27.10%, Val Loss: 2.9956, Val Acc: 25.95%


Epoch 7/40 - Training: 100%|██████████| 1250/1250 [01:29<00:00, 13.93it/s]


Epoch [7/40], Train Loss: 2.8064, Train Acc: 29.49%, Val Loss: 2.9347, Val Acc: 27.10%


Epoch 8/40 - Training:  99%|█████████▉| 1236/1250 [01:28<00:01, 13.93it/s]