In [1]:
import csv
import time
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.profiler import profile, ProfilerActivity
from torch.utils.data import DataLoader, Subset
from torchvision.models.vision_transformer import vit_b_16
from torchvision.models.vision_transformer import VisionTransformer

In [2]:
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.is_available())

2.9.0a0+gitcbe1a35
12.4
None
True


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
BATCH_SIZE = 64
NUM_EPOCHS = 5
LEARNING_RATE = 3e-4
NUM_WORKERS = 4
IMAGE_SIZE = 224
HEAD_SIZE = 100
NUM_TRAINING_IMAGES = 2000
NUM_TESTING_IMAGES = 200

VIT_MODEL = "CUSTOM"
# VIT_MODEL = "BUILT-IN"

# OPTIMIZER = "MUON"
OPTIMIZER = "ADAMW"

cuda


In [4]:
if(VIT_MODEL == "CUSTOM"):
  model = VisionTransformer(
      image_size=IMAGE_SIZE,
      patch_size=16,
      num_layers=8,
      hidden_dim=768,
      num_heads=12,
      mlp_dim=3072,
      num_classes=100,
    )
else:
  model = vit_b_16(pretrained=False)

In [5]:
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

trainset = datasets.CIFAR100(root='./data', train=True,
                                         download=True, transform=transform)
train_subset = Subset(trainset, range(NUM_TRAINING_IMAGES))
trainloader = DataLoader(train_subset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

testset = datasets.CIFAR100(root='./data', train=False,
                                        download=True, transform=transform)
test_subset = Subset(testset, range(NUM_TESTING_IMAGES))
testloader = DataLoader(test_subset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

dataiter = iter(trainloader)
images, labels = next(dataiter)
print(images.shape)
print(labels[:5])

torch.Size([64, 3, 224, 224])
tensor([68, 11, 45,  4, 65])


In [6]:
model.heads.head = nn.Linear(model.heads.head.in_features, HEAD_SIZE)
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# Initial code uses single optimizer only. Can be uncommented if needed.
# match OPTIMIZER:
#   case "MUON":
#     optimizer = optim.Muon(model.parameters(), lr=LEARNING_RATE)
#   case default:
#     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [7]:
params_2d = []
params_other = []

for name, param in model.named_parameters():
    if param.requires_grad:
        if param.ndim == 2:
            params_2d.append(param)
        else:
            params_other.append(param)

optimizer_muon = optim.Muon(params_2d, lr=LEARNING_RATE)
optimizer_adamw = optim.AdamW(params_other, lr=LEARNING_RATE)

In [8]:
csv_file = open("./optimizer_kernels.csv", "w", newline="")
writer = csv.writer(csv_file)
writer.writerow(["Iteration", "Kernel", "CUDA time (us)", "CPU time (us)", "Calls"])

iteration = 0
start_time = time.time()
for epoch in range(NUM_EPOCHS):
    print("Epoch {}/{}".format(epoch + 1, NUM_EPOCHS))
    model.train()
    total_loss = 0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer_muon.zero_grad()
        optimizer_adamw.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=False,
            profile_memory=False
        ) as prof:
            optimizer_muon.step()
        optimizer_adamw.step()

        iteration += 1

        for item in prof.key_averages():
            writer.writerow([
                iteration,
                item.key,
                item.cuda_time,
                item.cpu_time,
                item.count
            ])
        csv_file.flush()

        total_loss += loss.item()

    avg_loss = total_loss / len(trainloader)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Loss: {avg_loss:.4f}")
end_time = time.time()
runtime = end_time - start_time
print(f"Start Time: {start_time}")
print(f"End Time: {end_time}")
print(f"Runtime: {runtime:.6f} seconds")

csv_file.close()

Epoch 1/5


  item.cuda_time,


Epoch [1/5] Loss: 4.6621
Epoch 2/5
Epoch [2/5] Loss: 4.5396
Epoch 3/5
Epoch [3/5] Loss: 4.4350
Epoch 4/5
Epoch [4/5] Loss: 4.2973
Epoch 5/5
Epoch [5/5] Loss: 4.1679
Start Time: 1761766762.6345208
End Time: 1761766862.781294
Runtime: 100.146773 seconds


In [9]:
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

acc = 100 * correct / total
print(f"Test Accuracy: {acc:.2f}%")

Test Accuracy: 5.00%
