In [1]:
import torch, timm
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm.utils import ModelEmaV2
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
device

  from .autonotebook import tqdm as notebook_tqdm


'cuda'

In [2]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(2,9),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

trainset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
testset  = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)

trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader  = DataLoader(testset, batch_size=256, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170M/170M [04:04<00:00, 698kB/s]  


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    num_classes=10,
    img_size=32
).to(device)

ema = ModelEmaV2(model, decay=0.9999)
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False

In [4]:
model = timm.create_model(
    'vit_small_patch16_224',
    pretrained=False,
    num_classes=10,
    img_size=32,
    patch_size=4
).to(device)

ema = ModelEmaV2(model, decay=0.9999)
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(4, 4), stride=(4, 4))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
  

In [5]:
mixup = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5, num_classes=10)
criterion = SoftTargetCrossEntropy()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [6]:
def evaluate(model):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            pred = model(x).softmax(1).argmax(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return 100 * correct / total

In [None]:
for epoch in range(200):
    model.train()

    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        x, y = mixup(x, y)

        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        ema.update(model)

    scheduler.step()
    acc = evaluate(ema.module)
    print(f"Epoch {epoch+1}: Test Acc = {acc:.2f}%")

Epoch 1: Test Acc = 13.71%
Epoch 2: Test Acc = 13.49%
Epoch 3: Test Acc = 13.94%
Epoch 4: Test Acc = 15.31%
Epoch 5: Test Acc = 19.74%
Epoch 6: Test Acc = 21.11%
Epoch 7: Test Acc = 20.25%
Epoch 8: Test Acc = 20.09%
Epoch 9: Test Acc = 20.70%
Epoch 10: Test Acc = 21.52%
Epoch 11: Test Acc = 21.91%
Epoch 12: Test Acc = 21.91%
Epoch 13: Test Acc = 21.11%
Epoch 14: Test Acc = 21.29%
Epoch 15: Test Acc = 21.84%
Epoch 16: Test Acc = 21.96%
Epoch 17: Test Acc = 22.19%
Epoch 18: Test Acc = 22.35%
Epoch 19: Test Acc = 22.87%
Epoch 20: Test Acc = 23.11%
Epoch 21: Test Acc = 23.58%
Epoch 22: Test Acc = 24.10%
Epoch 23: Test Acc = 24.79%
Epoch 24: Test Acc = 25.19%
Epoch 25: Test Acc = 26.01%
Epoch 26: Test Acc = 26.78%
Epoch 27: Test Acc = 27.75%
Epoch 28: Test Acc = 29.12%
Epoch 29: Test Acc = 30.43%
Epoch 30: Test Acc = 32.30%
Epoch 31: Test Acc = 34.10%
Epoch 32: Test Acc = 35.63%
Epoch 33: Test Acc = 37.06%
Epoch 34: Test Acc = 38.69%
Epoch 35: Test Acc = 40.12%
Epoch 36: Test Acc = 41.50%
E