# Training TinyViT

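Let's train TinyViT: a small Vision Transformer trained from scratch on CIFAR-10, then evaluated on the validation split.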

In [1]:
import torch
import torch.nn as nn
import torch.optim
import torch.utils.data

import numpy as np
import os
import time

from import_shelf import shelf
from shelf.models.transformer import VisionTransformer
from shelf.dataloaders.cifar import get_CIFAR10_dataset
from shelf.trainers import train, validate

# Restrict CUDA to a single device for this process. The value looks like an
# NVIDIA MIG instance UUID and is specific to the machine this notebook was
# run on -- replace it before running elsewhere.
os.environ.update({"CUDA_VISIBLE_DEVICES": "MIG-60fed909-9539-55f4-9bab-e99df995d4a0"})


In [2]:
### HYPERPARAMS ###

# --- Optimisation ---
EPOCHS, BATCH_SIZE = 500, 512
LEARNING_RATE = 1e-4  # starting learning rate handed to the optimizer

# --- Input geometry ---
# presumably 32 / 4 = 8 non-overlapping patches per side -- confirm in VisionTransformer
IMAGE_SIZE, PATCH_SIZE = 32, 4

# --- Transformer shape ---
DIM_HIDDEN, DIM_MLP = 256, 256
DEPTH, NUM_HEADS = 4, 6

# --- Regularisation ---
DROPOUT = EMB_DROPOUT = 0.1

# --- Task ---
NUM_CLASSES = 10

# --- Runtime ---
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Destination of the trained weights (written after the last epoch).
PATH_MODEL = './saves/train_tinyvit/model.pth'


In [3]:
### DATA LOADING ###

# The project helper is given BATCH_SIZE and returns a (train, validation)
# pair of loaders, unpacked here in that order.
(train_loader, val_loader) = get_CIFAR10_dataset(batch_size=BATCH_SIZE)

# Human-readable names for the 10 labels -- presumably indexed by the dataset's
# integer class id (TODO confirm); not referenced by the other visible cells.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


Files already downloaded and verified


In [4]:
### MODEL ###

# Every architectural knob is taken from the hyperparameter constants so that
# editing a constant is guaranteed to reach the model. (The dropout rates were
# previously hard-coded to 0.1 here, leaving DROPOUT / EMB_DROPOUT unused.)
model = VisionTransformer(
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    dim=DIM_HIDDEN,
    depth=DEPTH,
    heads=NUM_HEADS,
    mlp_dim=DIM_MLP,
    dropout=DROPOUT,
    emb_dropout=EMB_DROPOUT,
    num_classes=NUM_CLASSES,
).to(DEVICE)

# Count only the parameters that will receive gradients.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Number of parameters: {num_params}")


Number of parameters: 2136842


In [5]:
### OTHERS ###

# Loss for the NUM_CLASSES-way classification head.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters, starting at LEARNING_RATE.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Cosine decay of the learning rate down to eta_min over EPOCHS scheduler
# steps; the training loop steps the scheduler once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=5e-6,
)


In [6]:

### TRAINING ###

# Create the checkpoint directory *before* the long training run, so that a
# missing folder cannot make torch.save fail after the final epoch and throw
# away the trained weights.
checkpoint_dir = os.path.dirname(PATH_MODEL)
if checkpoint_dir:  # dirname is '' when PATH_MODEL has no directory component
    os.makedirs(checkpoint_dir, exist_ok=True)

# Per-epoch history, one entry appended per epoch.
train_losses = []
train_accuracies = []

val_losses = []
val_accuracies = []

for epoch in range(EPOCHS):
    start_time = time.time()

    # Both helpers are unpacked as (accuracy, loss); accuracy is treated as a
    # fraction and scaled by 100 only for display.
    train_acc, train_loss = train(train_loader, model, criterion, optimizer, epoch, verbose=False)
    val_acc, val_loss = validate(val_loader, model, criterion, epoch, verbose=False)

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # The LR is read before scheduler.step(), so the printed value is the one
    # that was in effect during this epoch.
    print(
        f"Epoch {epoch+1:3d}/{EPOCHS}, "
        f"LR: {scheduler.get_last_lr()[0]:.4e} | "
        f"Train Loss: {train_loss:.4f}, "
        f"Train Acc: {train_acc * 100:.2f}%, "
        f"Val Loss: {val_loss:.4f}, "
        f"Val Acc: {val_acc*100:.2f}% | "
        f"Time: {time.time() - start_time:.3f}s"
    )

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()


# Persist only the weights (state_dict), not the whole module object.
torch.save(model.state_dict(), PATH_MODEL)

print(f"Model saved to {PATH_MODEL}")

Epoch   1/500, LR: 1.0000e-04 | Train Loss: 2.0698, Train Acc: 23.02%, Val Loss: 1.8503, Val Acc: 32.87% | Time: 15.901s
Epoch   2/500, LR: 9.9999e-05 | Train Loss: 1.8066, Train Acc: 33.95%, Val Loss: 1.6350, Val Acc: 41.69% | Time: 15.044s
Epoch   3/500, LR: 9.9996e-05 | Train Loss: 1.6628, Train Acc: 39.23%, Val Loss: 1.5191, Val Acc: 46.24% | Time: 15.057s
Epoch   4/500, LR: 9.9992e-05 | Train Loss: 1.5818, Train Acc: 42.60%, Val Loss: 1.4654, Val Acc: 48.09% | Time: 15.228s
Epoch   5/500, LR: 9.9985e-05 | Train Loss: 1.5301, Train Acc: 44.46%, Val Loss: 1.4233, Val Acc: 49.60% | Time: 14.967s
Epoch   6/500, LR: 9.9977e-05 | Train Loss: 1.4921, Train Acc: 45.90%, Val Loss: 1.3788, Val Acc: 51.38% | Time: 17.130s
Epoch   7/500, LR: 9.9966e-05 | Train Loss: 1.4644, Train Acc: 46.95%, Val Loss: 1.3669, Val Acc: 51.49% | Time: 19.047s
Epoch   8/500, LR: 9.9954e-05 | Train Loss: 1.4343, Train Acc: 48.08%, Val Loss: 1.3170, Val Acc: 53.18% | Time: 18.905s
Epoch   9/500, LR: 9.9940e-05 | 

In [7]:
### VALIDATION ###

# Reload the saved weights into the existing model. map_location=DEVICE remaps
# the stored tensors onto the device in use, so a checkpoint written on a GPU
# still loads when this notebook runs on a CPU-only machine.
model_state_dict = torch.load(PATH_MODEL, map_location=DEVICE)
model.load_state_dict(model_state_dict)

# Same call shape as in the training loop, with EPOCHS passed in the epoch
# slot; verbose is left at the helper's default here.
val_acc, val_loss = validate(val_loader, model, criterion, EPOCHS)

print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc*100:.2f}%")

                                                                                               

Validation Loss: 0.7509, Validation Accuracy: 78.19%


