In [1]:
from torchvision import datasets
import torch
import torch.nn as nn
import torchvision.transforms.v2 as transforms
import matplotlib.pyplot as plt
import vit
import numpy as np

%matplotlib inline

In [None]:
# Training device. GPU index 1 is the preferred accelerator; fall back to the
# first GPU and then to the CPU so the notebook still runs on machines that do
# not expose a second CUDA device.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [2]:
# --- Input geometry (forwarded to the model as image_size / channels) ---
IMAGE_SIZE = 32
CHANNELS = 3

# Patch edge length; used by the dataset transform (vit.image_to_patches) and
# passed to the model as patch_size.
# NOTE(review): with IMAGE_SIZE = 32 this presumably yields only a 2x2 grid of
# patches per image -- confirm that such a short sequence is intended.
PATCH_SIZE = 16

# --- Model size ---
# D_MODEL is passed as d_model and also sets the feed-forward width (4 * D_MODEL).
D_MODEL = 256
NUM_HEADS = 8    # passed as num_heads
NUM_LAYERS = 4   # passed as num_encoder_layer

# --- Optimisation schedule ---
BATCH_SIZE = 256      # batch size of both the train and the test DataLoader
NUM_STEPS = 5000      # target number of optimizer steps; the epoch count is derived from it
BATCHES_PER_STEP = 1  # batches whose gradients are accumulated before each optimizer step
EVAL_EVERY = 1000     # test accuracy is computed every EVAL_EVERY optimizer steps
BASE_LR = 3e-3        # learning rate handed to the optimizer
WEIGHT_DECAY = 0.03   # weight-decay coefficient handed to the optimizer

In [47]:
# Per-sample preprocessing: convert to an image tensor, rescale to float32
# (scale=True), normalise every channel with mean 0.5 / std 0.5, then hand the
# result to vit.image_to_patches and squeeze out singleton dimensions.
# NOTE(review): the exact layout produced by image_to_patches is defined in the
# vit module -- check it matches what the model expects with conv_embedding=True.
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(dtype=torch.float32, scale=True),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    transforms.Lambda(
        lambda image: vit.image_to_patches(image, PATCH_SIZE).squeeze()
    ),
])

# CIFAR-100 train and test splits. download=False: the files must already be
# present under ./data.
train_data, test_data = (
    datasets.CIFAR100(root='data', train=is_train, transform=transform, download=False)
    for is_train in (True, False)
)

# Only the training loader shuffles and uses worker processes; the test loader
# iterates in dataset order in the main process.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)

In [4]:
# Build the Vision Transformer from the configuration constants and move it to
# the training device. The training and evaluation code sends every batch to
# `device`, so the parameters must live there too; otherwise the forward pass
# fails with a device mismatch on any non-CPU device.
# (Assumes VisionTransformer is a torch.nn.Module, whose .to() returns the
# module itself.)
model = vit.VisionTransformer(
    d_model=D_MODEL,
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    channels=CHANNELS,
    num_heads=NUM_HEADS,
    d_ffn=4 * D_MODEL,
    num_encoder_layer=NUM_LAYERS,
    num_classes=100,
    p_dropout=0.1,
    conv_embedding=True,
    torch_attention=True,
).to(device)

# Display the parameter-name groups reported by the model.
model.get_wd_params()

({'encoder.0.ffn_layer.model.0.weight',
  'encoder.0.ffn_layer.model.3.weight',
  'encoder.0.self_attention.in_proj_weight',
  'encoder.0.self_attention.out_proj.weight',
  'encoder.1.ffn_layer.model.0.weight',
  'encoder.1.ffn_layer.model.3.weight',
  'encoder.1.self_attention.in_proj_weight',
  'encoder.1.self_attention.out_proj.weight',
  'encoder.2.ffn_layer.model.0.weight',
  'encoder.2.ffn_layer.model.3.weight',
  'encoder.2.self_attention.in_proj_weight',
  'encoder.2.self_attention.out_proj.weight',
  'encoder.3.ffn_layer.model.0.weight',
  'encoder.3.ffn_layer.model.3.weight',
  'encoder.3.self_attention.in_proj_weight',
  'encoder.3.self_attention.out_proj.weight',
  'mlp_head.weight',
  'patch_embedding.conv.weight'},
 {'class_token',
  'encoder.0.ffn_layer.model.0.bias',
  'encoder.0.ffn_layer.model.3.bias',
  'encoder.0.ffn_norm.bias',
  'encoder.0.ffn_norm.weight',
  'encoder.0.self_att_norm.bias',
  'encoder.0.self_att_norm.weight',
  'encoder.0.self_attention.in_proj_bi

In [None]:

# Classification loss on the raw model outputs.
criterion = nn.CrossEntropyLoss()

# Apply weight decay only to the parameters the model nominates for it.
# get_wd_params() returns name sets; the first one is taken to be the "decay"
# set (it lists the attention / feed-forward / head / patch-embedding weights,
# while biases, norm parameters and the class token appear in the other set).
# NOTE(review): confirm against vit.VisionTransformer.get_wd_params.
decay_names = set(model.get_wd_params()[0])
named_params = dict(model.named_parameters())

# Fail loudly if the names do not line up, instead of silently training
# without any weight decay.
unknown_names = decay_names - named_params.keys()
if unknown_names:
    raise ValueError(
        f'get_wd_params() returned unknown parameter names: {sorted(unknown_names)}'
    )

param_groups = [
    {
        'params': [p for name, p in named_params.items() if name in decay_names],
        'weight_decay': WEIGHT_DECAY,
    },
    {
        # Everything not nominated for decay is left unregularised.
        'params': [p for name, p in named_params.items() if name not in decay_names],
        'weight_decay': 0.0,
    },
]

# AdamW applies decoupled weight decay; Adam's weight_decay argument is an L2
# penalty added to the gradient, which the adaptive step then rescales.
optimizer = torch.optim.AdamW(param_groups, lr=BASE_LR, betas=(0.9, 0.999))

# No learning-rate schedule; the training loop skips scheduler stepping while
# this is None.
lr_scheduler = None

In [None]:
def calc_train_loss(batch: list[torch.Tensor]) -> torch.Tensor:
    """Compute the training loss for one batch.

    Switches the global ``model`` to training mode, moves the batch to the
    global ``device`` and evaluates the global ``criterion`` on the model
    output.

    Args:
        batch: Two-element sequence ``(images, labels)`` as yielded by the
            training DataLoader.

    Returns:
        The loss tensor, still attached to the autograd graph so the caller
        can call ``backward()`` on it.
    """
    model.train()

    images, targets = batch
    images = images.to(device)
    targets = targets.to(device)

    return criterion(model(images), targets)

def eval_model() -> float:
    """Measure the accuracy of the global ``model`` on ``test_loader``.

    Runs with gradient tracking disabled and with the model in evaluation
    mode. The model is left in evaluation mode afterwards.

    Returns:
        Fraction of test samples for which ``model.predict`` equals the target
        label (a value between 0 and 1).
    """
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        model.eval()
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            guesses = model.predict(images)
            n_seen += labels.size(0)
            # .item() turns the per-batch hit count into a Python int.
            n_correct += (guesses == labels).sum().item()

    return n_correct / n_seen

# Training driver: runs exactly NUM_STEPS optimizer steps (each accumulating
# gradients over BATCHES_PER_STEP batches), records the mean training loss
# every `log_every` steps and the test accuracy every EVAL_EVERY steps.

loss_curve = []      # mean training loss per logging window
accuracy_curve = []  # test accuracy at each evaluation point

log_every = 100  # optimizer steps between two training-loss log lines

# Epochs needed to see NUM_STEPS * BATCHES_PER_STEP batches (ceiling division).
# The loop below stops at NUM_STEPS, so the last epoch may be partial.
total_batches = NUM_STEPS * BATCHES_PER_STEP
num_epochs = -(-total_batches // len(train_loader))

print('=========================================')
print('Starting Traning')
print(f'  -Using device: {device}')
print(f'  -Number of parameters: {sum(p.numel() for p in model.parameters())}')
print(f'  -Number of steps: {NUM_STEPS}')
print(f'  -Number of epochs: {num_epochs}')
print(f'  -Batch size: {train_loader.batch_size}')
print('=========================================')
print()


step = 0          # optimizer steps performed so far
batches_seen = 0  # batches processed so far, counted across epochs
loss_history = []
optimizer.zero_grad()
for epoch in range(num_epochs):
    for batch in train_loader:

        loss = calc_train_loss(batch)

        # Scale so the accumulated gradient is the mean over the batches that
        # make up one optimizer step.
        (loss / BATCHES_PER_STEP).backward()

        # Keep the detached tensor (no .item()) to avoid a host/device sync on
        # every batch.
        loss_history.append(loss.detach())

        # Count batches globally rather than per epoch so an accumulation
        # window never carries leftover gradients across an epoch boundary
        # with the wrong scale.
        batches_seen += 1
        if batches_seen % BATCHES_PER_STEP != 0:
            continue

        # Perform step
        step += 1
        optimizer.step()
        optimizer.zero_grad()
        if lr_scheduler is not None:
            lr_scheduler.step()

        if step % log_every == 0:
            # Track train loss every few steps
            average_loss = torch.stack(loss_history).mean().item()
            loss_curve.append(average_loss)
            loss_history = []
            print(f'Epoch: {epoch + 1:3d}/{num_epochs:3d}, Step {step:5d}/{NUM_STEPS}, Loss: {average_loss:.4f}')

        # After eval_every steps, calc accuracy
        if step % EVAL_EVERY == 0:
            accuracy = eval_model()
            accuracy_curve.append(accuracy)
            print('=========================================')
            print(f'Epoch: {epoch + 1:3d}/{num_epochs:3d}, Step: {step:5d}/{NUM_STEPS}, Accuracy: {accuracy * 100:.4f} %')
            print('=========================================')

        # Stop as soon as the step budget is used up, not at the end of the epoch.
        if step >= NUM_STEPS:
            break

    if step >= NUM_STEPS:
        break
