# An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

Daniel Kofler 2024

In [None]:
import torch
import torch.nn.functional as F

# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Bare expression so the notebook cell displays the chosen device.
device

## Model

In [None]:
# Samples per optimisation step; also passed to the model summary.
batch_size = 256
# Side length, in pixels, of each square patch.
patch_size = 4
# A 32-pixel side holds 32 // patch_size patches, so the grid has that many squared.
num_patches = pow(32 // patch_size, 2)

In [None]:
from vision_transformer import VisionTransformer

# Build the Vision Transformer for 3-channel inputs and 10 output classes,
# using the patch geometry configured above.
vit = VisionTransformer(
    patch_size=patch_size,
    num_patches=num_patches,
    in_channels=3,
    num_classes=10,
    embed_dim=384,
    num_layers=6,
    num_heads=6,
    dropout=0.3,
)
# Move the model onto the selected compute device.
vit = vit.to(device)

In [None]:
from torchsummary import summary

# Summarise the model for inputs of shape (3, 32, 32) at the notebook's batch
# size; (3, 32, 32) is the per-image shape produced by the data-loading cell.
# NOTE(review): output format depends on the installed torchsummary package.
summary(vit, (3, 32, 32), batch_size=batch_size, device=device)

## Data

In [None]:
def unpickle(file):
    """Load and return the object stored in the pickle file at `file`.

    The file is opened in binary mode and decoded with ``encoding="bytes"``,
    so string keys written by Python 2 come back as ``bytes`` (this is why
    the callers in this notebook index the result with ``b'data'`` and
    ``b'labels'``).

    Args:
        file: Path of the pickle file to read.

    Returns:
        Whatever object was pickled into the file.

    Raises:
        OSError: If the file cannot be opened.
        pickle.UnpicklingError: If the contents are not a valid pickle.
    """
    import pickle

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only call this on dataset files obtained from a trusted source.
    with open(file, "rb") as fo:
        # Named `batch` rather than `dict` so the builtin is not shadowed.
        batch = pickle.load(fo, encoding="bytes")
    return batch

In [None]:
# Paths of the five training batch files.
train_batches = ["data/data_batch_1", "data/data_batch_2", "data/data_batch_3", "data/data_batch_4", "data/data_batch_5"]

# Read every file exactly once; each batch supplies both b'data' and b'labels',
# so unpickling separately for images and labels would double the disk I/O.
_train_raw = [unpickle(p) for p in train_batches]
# Reshape the concatenated rows to (N, 3, 32, 32) and divide by 255.0.
X_train = torch.concat([torch.tensor(b[b'data']) for b in _train_raw], dim=0).view(-1, 3, 32, 32) / 255.0
y_train = torch.concat([torch.tensor(b[b'labels']) for b in _train_raw], dim=0)

_test_raw = unpickle("data/test_batch")
X_test = torch.tensor(_test_raw[b'data']).view(-1, 3, 32, 32) / 255.0
y_test = torch.tensor(_test_raw[b'labels'])

# Drop the raw batches so they do not stay resident alongside the tensors.
del _train_raw, _test_raw

# Class index -> human-readable name, built by numbering the names in order.
labels = dict(enumerate((
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)))

In [None]:
from random import randint
import matplotlib.pyplot as plt

# Preview 16 randomly chosen training images in a 4x4 grid, titled by class name.
plt.figure(figsize=(10, 10))

for i in range(16):
    plt.subplot(4, 4, i+1)
    # randint includes BOTH endpoints, so the upper bound must be len - 1;
    # using len(X_train) could produce an out-of-range index.
    idx = randint(0, len(X_train) - 1)
    # Images are stored channel-first; imshow expects the channel axis last.
    plt.imshow(X_train[idx].permute(1, 2, 0))
    plt.title(labels[y_train[idx].item()])
    plt.axis("off")

In [None]:
# Shuffle images and labels with one shared permutation so pairs stay aligned.
rand_idx = torch.randperm(len(X_train))
X_train_shuffled = X_train[rand_idx]
y_train_shuffled = y_train[rand_idx]

# The first 80% of the shuffled data becomes the training set, the rest validation.
n = int(0.8 * len(X_train))
X_train, X_val = X_train_shuffled[:n], X_train_shuffled[n:]
y_train, y_val = y_train_shuffled[:n], y_train_shuffled[n:]

# Report the shapes of each split.
print(X_train.shape, y_train.shape)
print(X_val.shape, y_val.shape)
print(X_test.shape, y_test.shape)

## Training

In [None]:
# Number of passes over the training set, and the learning rate given to the optimizer.
epochs, lr = 100, 3e-4

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# AdamW over all model parameters at the configured learning rate.
optim = torch.optim.AdamW(vit.parameters(), lr=lr)

# Training batches are reshuffled every epoch.
train_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
# Validation only measures the model, so it needs no shuffling; a fixed order
# keeps the batch composition (and hence the reported metrics) reproducible.
val_dl = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False)

# Number of optimisation steps per epoch, used by the progress display.
train_steps = len(train_dl)

In [None]:
def accuracy(logits, target) -> float:
    """Return the fraction of predictions in `logits` that match `target`.

    The predicted class is the arg-max over the last dimension of `logits`.
    Softmax is deliberately not applied first: it preserves ordering, so it
    cannot improve the arg-max, and its floating-point rounding can collapse
    nearly equal logits into a tie that selects the wrong index.

    Args:
        logits: Tensor of class scores with classes on the last dimension.
        target: Tensor of class indices, comparable element-wise with
            ``logits.argmax(-1)``.

    Returns:
        Mean of the element-wise matches as a Python float.
    """
    # Metric only: no autograd graph is needed.
    with torch.no_grad():
        predicted = logits.argmax(dim=-1)
        return (predicted == target).float().mean().item()

In [None]:
# Train for `epochs` epochs. Each epoch runs one pass over train_dl with
# gradient updates, one pass over val_dl without gradients, prints the
# epoch's metrics, and writes a checkpoint every `ckpt_interval` epochs.
ckpt_interval = 25

for e in range(1, epochs + 1):

    # training
    vit.train()
    train_loss = train_accuracy = 0.0
    train_seen = 0
    # 1-based step so the progress counter ends at train_steps/train_steps.
    for step, batch in enumerate(train_dl, start=1):
        print(f"step {step}/{train_steps}", end="\r")
        X = batch[0].to(device)
        y = batch[1].to(device)

        logits = vit(X)
        loss = F.cross_entropy(logits, y)

        # Both the loss and accuracy() are per-batch means. Weight them by the
        # batch size so the (usually smaller) final batch is not over-counted
        # when the epoch average is taken.
        batch_n = len(y)
        train_loss += loss.item() * batch_n
        train_accuracy += accuracy(logits, y) * batch_n
        train_seen += batch_n

        loss.backward()
        optim.step()
        # Clear gradients so the next step starts from zero.
        optim.zero_grad()

    train_loss /= train_seen
    train_accuracy /= train_seen

    # validation
    vit.eval()
    val_loss = val_accuracy = 0.0
    val_seen = 0
    with torch.no_grad():
        for batch in val_dl:
            X = batch[0].to(device)
            y = batch[1].to(device)
            logits = vit(X)

            # Same sample-weighted accumulation as in training.
            batch_n = len(y)
            val_loss += F.cross_entropy(logits, y).item() * batch_n
            val_accuracy += accuracy(logits, y) * batch_n
            val_seen += batch_n

    val_loss /= val_seen
    val_accuracy /= val_seen

    # save checkpoints
    if e > 1 and e % ckpt_interval == 0:
        # Model and optimizer state plus the epoch number.
        sd = {
            "model": vit.state_dict(),
            "optim": optim.state_dict(),
            "epoch" : e
        }
        torch.save(sd, f"vit_{e}.pt")

    print(f"epoch {e}/{epochs} | train_loss {train_loss:.4f} | train_acc {train_accuracy:.4f} | val_loss {val_loss:.4f} | val_acc {val_accuracy:.4f}")