In [59]:
import time
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.profiler import profile, ProfilerActivity
from torch.utils.data import DataLoader, Subset
from torchvision.models.vision_transformer import vit_b_16
from torchvision.models.vision_transformer import VisionTransformer

In [60]:
# Report the PyTorch build, the CUDA and cuDNN versions it exposes, and
# whether torch can see a CUDA device -- one value per line.
print(
    torch.__version__,
    torch.version.cuda,
    torch.backends.cudnn.version(),
    torch.cuda.is_available(),
    sep="\n",
)

2.9.0a0+gitcbe1a35
12.4
None
True


In [61]:
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Seed torch's global RNG before anything random happens (weight init,
# DataLoader shuffling, random augmentations) so that a fresh
# "Restart & Run All" starts from the same random state every time.
SEED = 42
torch.manual_seed(SEED)

# --- Training hyper-parameters ---
BATCH_SIZE = 64        # samples per DataLoader batch
NUM_EPOCHS = 100       # maximum epochs; also the cosine schedule's T_max
LEARNING_RATE = 3e-4   # initial learning rate handed to the optimizer
NUM_WORKERS = 4        # DataLoader worker processes

# --- Model / data geometry ---
IMAGE_SIZE = 32        # side length images are resized to and the ViT is built for
PATCH_SIZE = 8         # ViT patch side length
HEAD_SIZE = 100        # output width of the classification head
NUM_TRAINING_IMAGES = 50000  # number of leading training samples to use
NUM_TESTING_IMAGES = 10000   # number of leading test samples to use

# Which model the model-construction cell builds: "CUSTOM" creates a
# VisionTransformer from scratch, any other value loads torchvision's vit_b_16.
VIT_MODEL = "CUSTOM"
# VIT_MODEL = "BUILT-IN"

# Optimizer selector; only consulted by the (currently commented-out)
# optimizer switch further down the notebook.
# OPTIMIZER = "MUON"
OPTIMIZER = "ADAMW"

cuda


In [62]:
# Build the Vision Transformer selected by VIT_MODEL.
if VIT_MODEL == "CUSTOM":
    # ViT built from scratch for IMAGE_SIZE x IMAGE_SIZE inputs. The class
    # count comes from HEAD_SIZE so the model and the config cannot drift.
    model = VisionTransformer(
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        num_layers=8,
        hidden_dim=384,
        num_heads=6,
        mlp_dim=1536,
        num_classes=HEAD_SIZE,
    )
else:
    # torchvision ViT-B/16 with ImageNet weights.
    # (Pass weights=None instead to start from random initialisation.)
    model = vit_b_16(weights="IMAGENET1K_V1")

    # The built-in model checks the input resolution in its forward pass.
    # Fail here with an actionable message instead of deep inside the first
    # training step when the data pipeline produces a different size.
    if model.image_size != IMAGE_SIZE:
        raise ValueError(
            f"vit_b_16 expects {model.image_size}x{model.image_size} inputs "
            f"but IMAGE_SIZE is {IMAGE_SIZE}; set IMAGE_SIZE to "
            f"{model.image_size} or use VIT_MODEL = 'CUSTOM'."
        )

In [63]:
# Per-channel mean / std passed to Normalize for both pipelines.
cifar100_mean = [0.5071, 0.4867, 0.4408]
cifar100_std  = [0.2675, 0.2565, 0.2761]

# Training pipeline: resize, then random augmentation (padded crop, horizontal
# flip, colour jitter), then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    # Crop back to IMAGE_SIZE (not a literal 32) so train and test images keep
    # the same size when IMAGE_SIZE is changed.
    transforms.RandomCrop(IMAGE_SIZE, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar100_mean, std=cifar100_std),
])

# Evaluation pipeline: same resize and normalisation, no random augmentation.
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar100_mean, std=cifar100_std),
])

In [64]:
# Training data: CIFAR-100 train split (downloaded to ./data if missing),
# restricted to its first NUM_TRAINING_IMAGES samples and served shuffled.
trainset = datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
train_subset = Subset(trainset, range(NUM_TRAINING_IMAGES))
trainloader = DataLoader(
    train_subset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Test data: CIFAR-100 test split, first NUM_TESTING_IMAGES samples, in order.
testset = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)
test_subset = Subset(testset, range(NUM_TESTING_IMAGES))
testloader = DataLoader(
    test_subset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# Sanity check: pull one training batch and show its shape and a few labels.
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(images.shape, labels[:5], sep="\n")

torch.Size([64, 3, 32, 32])
tensor([34, 17, 82, 26, 20])


In [65]:
# Loss used for training: cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()

# Replace the classification head with a fresh linear layer that keeps the
# old head's input width and emits HEAD_SIZE logits, then move the whole
# model to the selected device.
model.heads.head = nn.Linear(
    in_features=model.heads.head.in_features,
    out_features=HEAD_SIZE,
)
model = model.to(device)

# Optimizer selection via the OPTIMIZER constant is disabled: the notebook
# currently trains with a single AdamW optimizer created in the next cell.
# Uncomment to switch on OPTIMIZER instead.
# match OPTIMIZER:
#   case "MUON":
#     optimizer = optim.Muon(model.parameters(), lr=LEARNING_RATE)
#   case default:
#     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [66]:
# AdamW over every model parameter, starting at LEARNING_RATE.
optimizer_adamw = optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.01,
)

In [None]:
# Cosine learning-rate decay over the full run; stepped once per epoch below.
scheduler_adamw = optim.lr_scheduler.CosineAnnealingLR(optimizer_adamw, T_max=NUM_EPOCHS)

# Early-stopping state: stop after `patience` consecutive epochs in which the
# average *training* loss fails to beat the best value seen so far.
best_loss = float('inf')
patience = 5
trigger_times = 0

iteration = 0  # optimizer steps taken across all epochs
start_time = time.time()
for epoch in range(NUM_EPOCHS):
    print("Epoch {}/{}".format(epoch + 1, NUM_EPOCHS))
    model.train()
    total_loss = 0.0   # sum of per-sample losses seen this epoch
    num_samples = 0    # samples seen this epoch
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer_adamw.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer_adamw.step()

        iteration += 1

        # `criterion` returns the batch *mean*. Weight it by the batch size so
        # the epoch figure is a true per-sample mean; otherwise a short final
        # batch counts as much as a full one.
        batch_count = labels.size(0)
        total_loss += loss.item() * batch_count
        num_samples += batch_count

    # Advance the cosine schedule once per epoch, after the optimizer steps.
    scheduler_adamw.step()

    avg_loss = total_loss / num_samples
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Loss: {avg_loss:.4f}")

    if avg_loss < best_loss:
        # New best: remember it and reset the patience counter.
        best_loss = avg_loss
        trigger_times = 0
    else:
        trigger_times += 1
        print(f"No improvement for {trigger_times} epoch(s)")

    if trigger_times >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

# Wall-clock timing of the whole training run (seconds since the epoch).
end_time = time.time()
runtime = end_time - start_time
print(f"Start Time: {start_time}")
print(f"End Time: {end_time}")
print(f"Runtime: {runtime:.6f} seconds")

Epoch 1/100
Epoch [1/100] Loss: 3.9706
Epoch 2/100
Epoch [2/100] Loss: 3.5693
Epoch 3/100
Epoch [3/100] Loss: 3.3632
Epoch 4/100


In [None]:
# Top-1 accuracy on the test loader, with the model in eval mode and
# gradient tracking switched off.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit in each row.
        preds = outputs.argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.shape[0]

acc = 100 * correct / total
print(f"Test Accuracy: {acc:.2f}%")