In [1]:
import csv
import time
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.profiler import profile, ProfilerActivity
from torch.utils.data import DataLoader, Subset
from torchvision.models.vision_transformer import vit_b_16
from torchvision.models.vision_transformer import VisionTransformer

In [2]:
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.is_available())

2.9.0a0+gitcbe1a35
12.4
None
True


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
BATCH_SIZE = 64
NUM_EPOCHS = 200
LEARNING_RATE = 3e-4
NUM_WORKERS = 4
IMAGE_SIZE = 32
PATCH_SIZE = 4
HEAD_SIZE = 100
NUM_TRAINING_IMAGES = 45000
NUM_VAL_IMAGES = 5000
NUM_TESTING_IMAGES = 10000

VIT_MODEL = "CUSTOM"
# VIT_MODEL = "BUILT-IN"

# OPTIMIZER = "MUON"
OPTIMIZER = "ADAMW"

cuda


In [4]:
if(VIT_MODEL == "CUSTOM"):
  model = VisionTransformer(
      image_size=IMAGE_SIZE,
      patch_size=PATCH_SIZE,
      num_layers=8,
      hidden_dim=384,
      num_heads=6,
      mlp_dim=1536,
      num_classes=100,
    )

else:
  # model = vit_b_16(pretrained=False)
  model = vit_b_16(weights="IMAGENET1K_V1")

In [5]:
cifar100_mean = [0.5071, 0.4867, 0.4408]
cifar100_std  = [0.2675, 0.2565, 0.2761]

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar100_mean, std=cifar100_std),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar100_mean, std=cifar100_std),
])

In [6]:
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)

train_subset, val_subset = torch.utils.data.random_split(
    trainset, [NUM_TRAINING_IMAGES, NUM_VAL_IMAGES],
    generator=torch.Generator().manual_seed(42)
)

trainloader = DataLoader(train_subset, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=NUM_WORKERS)
valloader = DataLoader(val_subset, batch_size=BATCH_SIZE,
                                           shuffle=False, num_workers=NUM_WORKERS)

testset = datasets.CIFAR100(root='./data', train=False,
                                        download=True, transform=test_transform)
test_subset = Subset(testset, range(NUM_TESTING_IMAGES))
testloader = DataLoader(test_subset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

dataiter = iter(trainloader)
images, labels = next(dataiter)
print(images.shape)
print(labels[:5])

torch.Size([64, 3, 32, 32])
tensor([62, 13, 52, 94, 64])


In [7]:
model.heads.head = nn.Linear(model.heads.head.in_features, HEAD_SIZE)
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# Initial code uses single optimizer only. Can be uncommented if needed.
# match OPTIMIZER:
#   case "MUON":
#     optimizer = optim.Muon(model.parameters(), lr=LEARNING_RATE)
#   case default:
#     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [8]:
optimizer_adamw = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)

In [None]:
# csv_file = open("./optimizer_kernels.csv", "w", newline="")
# writer = csv.writer(csv_file)
# writer.writerow(["Iteration", "Kernel", "CUDA time (us)", "CPU time (us)", "Calls"])
scheduler_adamw = optim.lr_scheduler.CosineAnnealingLR(optimizer_adamw, T_max=NUM_EPOCHS)

best_loss = float('inf')
patience = 5
trigger_times = 0

iteration = 0
start_time = time.time()
for epoch in range(NUM_EPOCHS):
    print("Epoch {}/{}".format(epoch + 1, NUM_EPOCHS))
    
    model.train()
    running_loss = 0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer_adamw.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # with profile(
        #     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        #     record_shapes=False,
        #     profile_memory=False
        # ) as prof:
        optimizer_adamw.step()

        # for item in prof.key_averages():
        #     writer.writerow([
        #         iteration,
        #         item.key,
        #         item.cuda_time,
        #         item.cpu_time,
        #         item.count
        #     ])
        # csv_file.flush()

        running_loss += loss.item()

    avg_train_loss = running_loss / len(trainloader)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in valloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(valloader)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    
    scheduler_adamw.step()
    
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        trigger_times = 0
    else:
        trigger_times += 1
        print(f"No improvement for {trigger_times} epoch(s)")
        
    if trigger_times >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break
end_time = time.time()
runtime = end_time - start_time
print(f"Start Time: {start_time}")
print(f"End Time: {end_time}")
print(f"Runtime: {runtime:.6f} seconds")

# csv_file.close()

Epoch 1/200
Epoch [1/200] Train Loss: 3.9425 | Val Loss: 3.5769
Epoch 2/200
Epoch [2/200] Train Loss: 3.4358 | Val Loss: 3.2845
Epoch 3/200
Epoch [3/200] Train Loss: 3.1606 | Val Loss: 3.0641
Epoch 4/200


In [None]:
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

acc = 100 * correct / total
print(f"Test Accuracy: {acc:.2f}%")