In [1]:
from torchvision.datasets import CIFAR100
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision.models.vision_transformer import VisionTransformer
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

In [2]:
# Sanity check: True means PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [3]:
# --- Data / loader configuration ---
BATCH_SIZE: int = 256  # samples per batch for both the train and test loaders
NUM_WORKERS: int = 2   # DataLoader worker processes
IMAGE_SIZE: int = 64   # images are cropped/resized to IMAGE_SIZE x IMAGE_SIZE

In [4]:
# Per-channel (RGB) normalization constants, presumably the CIFAR-100
# training-set mean / std; applied identically to train and test images.
mean = (0.5071, 0.4867, 0.4408)
std = (0.2675, 0.2565, 0.2761)

# Training pipeline: random crop covering 80-100% of the image area, resized to
# IMAGE_SIZE, then a random horizontal flip, tensor conversion and normalization.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Evaluation pipeline: deterministic resize + center crop, then the same
# tensor conversion and normalization as training.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=IMAGE_SIZE),
        transforms.CenterCrop(size=IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# download=True fetches the dataset into ./data when it is not already there.
train = CIFAR100(root="./data", train=True, download=True, transform=train_transforms)
test = CIFAR100(root="./data", train=False, download=True, transform=test_transforms)

In [5]:
# Both loaders share these settings. DataLoader rejects persistent_workers=True
# when num_workers == 0, and pinned host memory only speeds up copies to a CUDA
# device, so both flags are derived from the configuration instead of being
# hard-coded to True.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=NUM_WORKERS > 0,
)

# Shuffle only the training data; the test order stays fixed.
train_loader = DataLoader(train, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test, shuffle=False, **loader_kwargs)

In [6]:
# Use the GPU when one is visible, otherwise fall back to CPU. The autocast and
# GradScaler calls in this notebook are already gated on device == "cuda", so
# mixed precision is simply disabled on the CPU path.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
def accuracy_top1(logits, targets):
  """Return the fraction of rows whose arg-max logit equals the target label.

  Args:
    logits: per-class scores; the class axis is dim 1.
    targets: integer class labels, one per row of `logits`.

  Returns:
    Top-1 accuracy of the batch as a Python float.
  """
  hits = logits.argmax(dim=1).eq(targets)
  return hits.float().mean().item()


@torch.no_grad()
def evaluate(model, loader, criterion):
  """Compute sample-weighted mean loss and top-1 accuracy of `model` over `loader`.

  Switches the model to eval mode and runs without autograd. Batches are moved to
  the notebook-global `device`.

  Returns:
    (mean_loss, mean_accuracy) averaged over every sample in `loader`.
  """
  model.eval()
  loss_sum = 0.0
  correct_sum = 0.0
  seen = 0
  for inputs, labels in loader:
    inputs = inputs.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    outputs = model(inputs)
    batch_loss = criterion(outputs, labels)
    # Weight per-batch means by batch size so a short final batch is not over-counted.
    batch = inputs.size(0)
    loss_sum += batch_loss.item() * batch
    correct_sum += accuracy_top1(outputs, labels) * batch
    seen += batch
  return loss_sum / seen, correct_sum / seen

# Baseline: AdamW on all parameters

In [8]:
# --- AdamW baseline hyperparameters ---
EPOCHS: int = 20        # training epochs; also the cosine schedule length (T_max)
LR_ADAMW: float = 2e-4  # initial (peak) learning rate before cosine decay
WD_ADAMW: float = 0.05  # AdamW (decoupled) weight decay

In [9]:
# Seed the global RNGs (CPU and CUDA) immediately before construction so the
# initial weights are reproducible. Using the same seed in every model cell gives
# each optimizer run an identical starting point, so the comparison is not
# confounded by different initializations.
torch.manual_seed(0)
model = VisionTransformer(
  image_size=IMAGE_SIZE,
  patch_size=8,  # 64 / 8 = 8 patches per side -> 64 tokens
  num_layers=12,
  num_heads=6,
  hidden_dim=384,
  mlp_dim=1536,
  dropout=0.1,
  attention_dropout=0.1,
  num_classes=100,
).to(device)

In [10]:
# Baseline: one AdamW optimizer over every model parameter, with a cosine
# learning-rate schedule whose length is EPOCHS scheduler steps.
optimizer = optim.AdamW(model.parameters(), lr=LR_ADAMW, weight_decay=WD_ADAMW)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Loss scaling is only needed for the fp16 autocast path, which is CUDA-only here.
scaler = torch.amp.GradScaler(enabled=device == "cuda")
criterion = nn.CrossEntropyLoss()

In [11]:
def train_one_epoch_adamw(model, loader, optimizer, scaler, criterion):
  """Train `model` for one pass over `loader` with a single optimizer.

  The forward pass runs under fp16 autocast when the notebook-global `device` is
  "cuda"; gradients go through `scaler` (GradScaler) before the optimizer step.

  Returns:
    (mean_loss, mean_top1_accuracy) weighted by batch size over the epoch.
  """
  model.train()
  use_amp = device == "cuda"
  running_loss = 0.0
  running_correct = 0.0
  num_seen = 0

  for inputs, labels in loader:
    inputs = inputs.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    optimizer.zero_grad(set_to_none=True)

    with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
      outputs = model(inputs)
      batch_loss = criterion(outputs, labels)

    # GradScaler protocol: scale -> backward -> step (unscales, skips on inf/NaN) -> update.
    scaler.scale(batch_loss).backward()
    scaler.step(optimizer)
    scaler.update()

    batch_size = inputs.size(0)
    running_loss += batch_loss.item() * batch_size
    running_correct += accuracy_top1(outputs, labels) * batch_size
    num_seen += batch_size

  return running_loss / num_seen, running_correct / num_seen

In [12]:
best_acc = 0.0
patience = 5     # stop after this many consecutive epochs without improvement
min_delta = 0.0  # test-accuracy gain that counts as an improvement (0.0 = any gain)
bad_epochs = 0

# NOTE(review): early stopping and the saved "best" checkpoint are selected on
# test_loader, which wraps CIFAR100(train=False), so best_acc is optimistically
# biased. Hold out a validation split from the training data if this number is
# meant as an unbiased estimate.
for epoch in range(1, EPOCHS + 1):
  train_loss, train_acc = train_one_epoch_adamw(
      model,
      train_loader,
      optimizer,
      scaler,
      criterion,
  )
  test_loss, test_acc = evaluate(model, test_loader, criterion)
  scheduler.step()  # one cosine-schedule step per epoch (T_max=EPOCHS)

  print(f"Epoch {epoch} | "
        f"Train Loss {train_loss:.4f}, Accuracy {train_acc:.4f} | "
        f"Test Loss {test_loss:.4f}, Accuracy {test_acc:.4f}")

  if test_acc > best_acc + min_delta:
    best_acc = test_acc
    bad_epochs = 0
    torch.save(model.state_dict(), "vit_cifar100_best_adamw.pt")
  else:
    bad_epochs += 1
    if bad_epochs >= patience:
      print(f"Early stopping at epoch {epoch}. Best Test Accuracy: {best_acc:.4f}")
      break

# `model` now holds the last-epoch weights; the best ones are in the .pt file.
print("Best Test Accuracy:", best_acc)

Epoch 1 | Train Loss 4.0831, Accuracy 0.0710 | Test Loss 3.8104, Accuracy 0.1079
Epoch 2 | Train Loss 3.6246, Accuracy 0.1345 | Test Loss 3.5032, Accuracy 0.1583
Epoch 3 | Train Loss 3.3302, Accuracy 0.1890 | Test Loss 3.2359, Accuracy 0.2054
Epoch 4 | Train Loss 3.0717, Accuracy 0.2354 | Test Loss 2.9889, Accuracy 0.2453
Epoch 5 | Train Loss 2.8623, Accuracy 0.2731 | Test Loss 2.8160, Accuracy 0.2867
Epoch 6 | Train Loss 2.7059, Accuracy 0.3057 | Test Loss 2.6766, Accuracy 0.3154
Epoch 7 | Train Loss 2.5726, Accuracy 0.3339 | Test Loss 2.5676, Accuracy 0.3374
Epoch 8 | Train Loss 2.4503, Accuracy 0.3549 | Test Loss 2.4383, Accuracy 0.3585
Epoch 9 | Train Loss 2.3329, Accuracy 0.3825 | Test Loss 2.3701, Accuracy 0.3815
Epoch 10 | Train Loss 2.2301, Accuracy 0.4046 | Test Loss 2.3095, Accuracy 0.3964
Epoch 11 | Train Loss 2.1396, Accuracy 0.4241 | Test Loss 2.2120, Accuracy 0.4183
Epoch 12 | Train Loss 2.0406, Accuracy 0.4471 | Test Loss 2.1567, Accuracy 0.4323
Epoch 13 | Train Loss 1.9

# Muon (2-D non-head weights) + AdamW (all other parameters)

In [13]:
# --- Muon + AdamW hybrid hyperparameters ---
# NOTE: this cell rebinds EPOCHS / LR_ADAMW / WD_ADAMW from the baseline section.
# LR_ADAMW is 3e-4 here vs 2e-4 in the baseline, so the AdamW-managed parameters
# do not share a learning rate across the two runs. Re-running earlier cells
# after this one will also pick up these values.
EPOCHS: int = 20
LR_ADAMW: float = 3e-4   # learning rate for the AdamW parameter group
WD_ADAMW: float = 0.05   # weight decay for the AdamW parameter group
LR_MUON: float = 3e-4    # learning rate for the Muon parameter group
WD_MUON: float = 0.05    # weight decay for the Muon parameter group
MOM_MUON: float = 0.95   # Muon momentum coefficient

In [14]:
# Seed the global RNGs (CPU and CUDA) immediately before construction so the
# initial weights are reproducible. Using the same seed in every model cell gives
# each optimizer run an identical starting point, so the comparison is not
# confounded by different initializations.
torch.manual_seed(0)
model = VisionTransformer(
  image_size=IMAGE_SIZE,
  patch_size=8,  # 64 / 8 = 8 patches per side -> 64 tokens
  num_layers=12,
  num_heads=6,
  hidden_dim=384,
  mlp_dim=1536,
  dropout=0.1,
  attention_dropout=0.1,
  num_classes=100,
).to(device)

In [15]:
# Route trainable parameters to two optimizers. Muon gets 2-D weight matrices
# outside the classification head. AdamW gets everything else: non-2-D tensors
# (e.g. biases and norm weights) plus the head, whose parameters live under the
# `heads.` module prefix.
muon_params = []
adamw_params = []
for param_name, param in model.named_parameters():
  if not param.requires_grad:
    continue
  # Match the head by module prefix rather than by substring, so an unrelated
  # parameter whose name merely contains "heads" is not misrouted.
  in_head = param_name.startswith("heads.")
  if param.ndim == 2 and not in_head:
    muon_params.append(param)
  else:
    adamw_params.append(param)

opt_muon = optim.Muon(
  muon_params,
  lr=LR_MUON,
  weight_decay=WD_MUON,
  momentum=MOM_MUON,
  nesterov=True,
  # Per the torch.optim.Muon docs, this rescales the update to match AdamW's RMS,
  # so LR_MUON can be set on an AdamW-like scale.
  adjust_lr_fn="match_rms_adamw",
)
opt_adamw = optim.AdamW(
  adamw_params,
  lr=LR_ADAMW,
  weight_decay=WD_ADAMW,
)

# Both groups follow the same cosine schedule, stepped once per epoch.
sch_muon = optim.lr_scheduler.CosineAnnealingLR(opt_muon, T_max=EPOCHS)
sch_other = optim.lr_scheduler.CosineAnnealingLR(opt_adamw, T_max=EPOCHS)

scaler = torch.amp.GradScaler(enabled=(device == "cuda"))
criterion = nn.CrossEntropyLoss()

In [16]:
def train_one_epoch_muon(model, loader, opt_muon, opt_other, scaler, criterion):
  """Train `model` for one epoch, stepping the Muon and AdamW optimizers together.

  The forward pass runs under fp16 autocast when the notebook-global `device` is
  "cuda". A single shared GradScaler serves both optimizers.

  Returns:
    (mean_loss, mean_top1_accuracy) weighted by batch size over the epoch.
  """
  model.train()
  optimizers = (opt_muon, opt_other)
  use_amp = device == "cuda"
  loss_sum, correct_count, seen = 0.0, 0.0, 0

  for inputs, labels in loader:
    inputs = inputs.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    for opt in optimizers:
      opt.zero_grad(set_to_none=True)

    with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
      outputs = model(inputs)
      batch_loss = criterion(outputs, labels)

    # Multi-optimizer AMP pattern: one scaler.step per optimizer, then a single
    # scaler.update(). Per the PyTorch AMP docs, each optimizer skips its step
    # independently if its own gradients contain inf/NaN.
    scaler.scale(batch_loss).backward()
    for opt in optimizers:
      scaler.step(opt)
    scaler.update()

    batch_size = inputs.size(0)
    loss_sum += batch_loss.item() * batch_size
    correct_count += (outputs.argmax(1) == labels).float().sum().item()
    seen += batch_size

  return loss_sum / seen, correct_count / seen

In [17]:
best_acc = 0.0
patience = 5     # stop after this many consecutive epochs without improvement
min_delta = 0.0  # test-accuracy gain that counts as an improvement (0.0 = any gain)
bad_epochs = 0

# NOTE(review): early stopping and the saved "best" checkpoint are selected on
# test_loader, which wraps CIFAR100(train=False), so best_acc is optimistically
# biased. Hold out a validation split from the training data if this number is
# meant as an unbiased estimate.
for epoch in range(1, EPOCHS + 1):
  train_loss, train_acc = train_one_epoch_muon(
      model,
      train_loader,
      opt_muon,
      opt_adamw,
      scaler,
      criterion,
  )
  test_loss, test_acc = evaluate(model, test_loader, criterion)

  # Advance both cosine schedules together, once per epoch.
  sch_muon.step()
  sch_other.step()

  print(f"Epoch {epoch} | "
        f"Train Loss {train_loss:.4f}, Train Accuracy {train_acc:.4f} | "
        f"Test Loss {test_loss:.4f}, Test Accuracy {test_acc:.4f}")

  if test_acc > best_acc + min_delta:
    best_acc = test_acc
    bad_epochs = 0
    torch.save(model.state_dict(), "vit_cifar100_best_muon.pt")
  else:
    bad_epochs += 1
    if bad_epochs >= patience:
      print(f"Early stopping at epoch {epoch}. Best Test Accuracy: {best_acc:.4f}")
      break

# `model` now holds the last-epoch weights; the best ones are in the .pt file.
print("Best Test Accuracy:", best_acc)

Epoch 1 | Train Loss 3.9198, Train Accuracy 0.1045 | Test Loss 3.4552, Test Accuracy 0.1802
Epoch 2 | Train Loss 3.2229, Train Accuracy 0.2136 | Test Loss 2.9700, Test Accuracy 0.2647
Epoch 3 | Train Loss 2.8142, Train Accuracy 0.2850 | Test Loss 2.6394, Test Accuracy 0.3226
Epoch 4 | Train Loss 2.5340, Train Accuracy 0.3440 | Test Loss 2.3993, Test Accuracy 0.3780
Epoch 5 | Train Loss 2.3101, Train Accuracy 0.3885 | Test Loss 2.2471, Test Accuracy 0.4044
Epoch 6 | Train Loss 2.1254, Train Accuracy 0.4286 | Test Loss 2.1298, Test Accuracy 0.4307
Epoch 7 | Train Loss 1.9541, Train Accuracy 0.4652 | Test Loss 2.0005, Test Accuracy 0.4659
Epoch 8 | Train Loss 1.8060, Train Accuracy 0.5009 | Test Loss 1.9028, Test Accuracy 0.4852
Epoch 9 | Train Loss 1.6658, Train Accuracy 0.5336 | Test Loss 1.8177, Test Accuracy 0.5055
Epoch 10 | Train Loss 1.5351, Train Accuracy 0.5681 | Test Loss 1.7778, Test Accuracy 0.5192
Epoch 11 | Train Loss 1.4111, Train Accuracy 0.5984 | Test Loss 1.7541, Test Ac