In [None]:
import torch
import torch.nn.functional as F

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

## Model

In [None]:
patch_size = 4
num_patches = (32 // patch_size)**2

In [None]:
from vision_transformer import VisionTransformer

# Classifier for RGB CIFAR-10 (3 input channels, 10 classes). The encoder size
# (768-dim embeddings, 12 heads, 12 layers) matches the ViT-Base configuration.
vit = VisionTransformer(
    in_channels=3,
    num_classes=10,
    patch_size=patch_size,
    num_patches=num_patches,
    embed_dim=768,
    num_layers=12,
    num_heads=12,
    dropout=0.1,
).to(device)

In [None]:
from torchsummary import summary

# Layer-by-layer report for one channels-first 3x32x32 input on the active device.
summary(
    vit,
    (3, 32, 32),
    device=device,
)

## Data

In [None]:
def unpickle(file):
    """Load one pickled CIFAR-10 batch file.

    Args:
        file: path to the pickled batch, e.g. ``"data/data_batch_1"``.

    Returns:
        The unpickled object. ``encoding="bytes"`` decodes Python 2 ``str``
        objects as ``bytes``, which is why callers index the result with
        ``b'data'`` / ``b'labels'``.

    Security:
        ``pickle.load`` can execute arbitrary code while loading. Only call this
        on trusted files, such as the official CIFAR-10 download.
    """
    import pickle

    with open(file, "rb") as fo:
        # Named `batch` rather than `dict` so the builtin is not shadowed.
        batch = pickle.load(fo, encoding="bytes")
    return batch

In [None]:
train_batches = ["data/data_batch_1", "data/data_batch_2", "data/data_batch_3", "data/data_batch_4", "data/data_batch_5"]

# Unpickle every file once and reuse the result for both images and labels.
train_dicts = [unpickle(p) for p in train_batches]
test_dict = unpickle("data/test_batch")

# Each row of b'data' is reshaped to (3, 32, 32) (channels first) and scaled by 1/255.
X_train = torch.concat([torch.tensor(d[b'data']) for d in train_dicts], dim=0).view(-1, 3, 32, 32) / 255.0
y_train = torch.concat([torch.tensor(d[b'labels']) for d in train_dicts], dim=0)
X_test = torch.tensor(test_dict[b'data']).view(-1, 3, 32, 32) / 255.0
y_test = torch.tensor(test_dict[b'labels'])

# The raw batch dicts are no longer needed once converted to tensors.
del train_dicts, test_dict

# Class index -> human-readable CIFAR-10 class name.
labels = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck"
}

In [None]:
from random import randint
import matplotlib.pyplot as plt

# 4x4 grid of random training images titled with their class names.
plt.figure(figsize=(10, 10))

for i in range(16):
    plt.subplot(4, 4, i+1)
    # randint is inclusive on BOTH ends, so the last valid index is len - 1.
    idx = randint(0, len(X_train) - 1)
    # Images are stored channels-first (3, 32, 32); imshow expects (H, W, C).
    plt.imshow(X_train[idx].permute(1, 2, 0))
    plt.title(labels[y_train[idx].item()])
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Hold out 20% of the training set for validation. A fixed-seed generator makes
# the split identical on every Restart & Run All.
# NOTE: this cell overwrites X_train / y_train. Re-running it without first
# re-running the data-loading cell would split the already-reduced set again.
split_seed = 0
split_gen = torch.Generator().manual_seed(split_seed)
rand_idx = torch.randperm(len(X_train), generator=split_gen)
X_train_shuffled, y_train_shuffled = X_train[rand_idx], y_train[rand_idx]
n = int(0.8 * len(X_train))

X_train = X_train_shuffled[:n]
y_train = y_train_shuffled[:n]
X_val = X_train_shuffled[n:]
y_val = y_train_shuffled[n:]

# The two splits must partition the shuffled set exactly.
assert len(X_train) + len(X_val) == len(X_train_shuffled)
assert len(X_train) == len(y_train) and len(X_val) == len(y_val)

print(X_train.shape, y_train.shape)
print(X_val.shape, y_val.shape)
print(X_test.shape, y_test.shape)

## Training

In [None]:
# Training schedule.
epochs, batch_size = 10, 64

# Optimizer settings: learning rate, moment decay rates, and weight decay.
lr = 0.003
betas = (0.9, 0.999)
weight_decay = 0.1

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# AdamW applies weight decay directly to the weights, decoupled from the adaptive
# gradient scaling. With plain Adam, weight_decay is added to the gradient as an
# L2 term and then rescaled per parameter, so weight_decay=0.1 would not behave
# as true weight decay (Loshchilov & Hutter, "Decoupled Weight Decay Regularization").
optim = torch.optim.AdamW(vit.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
train_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
# Shuffling is unnecessary for evaluation; a fixed order keeps validation deterministic.
val_dl = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False)

# Number of optimizer steps per epoch (used for progress display).
train_steps = len(train_dl)

In [None]:
for e in range(1, epochs + 1):

    # ---- training pass ----
    vit.train()
    # Sample-weighted running loss kept on-device to avoid a host sync every step.
    train_loss_sum = torch.zeros((), device=device)
    train_seen = 0
    for step, batch in enumerate(train_dl):
        print(f"step {step + 1}/{train_steps}", end="\r")
        X = batch[0].to(device)
        y = batch[1].to(device)

        # Clear gradients before the forward pass so stale gradients (e.g. from an
        # interrupted earlier run of this cell) never leak into the update.
        optim.zero_grad()
        logits = vit(X)
        train_loss = F.cross_entropy(logits, y)

        train_loss.backward()
        optim.step()

        train_loss_sum += train_loss.detach() * X.shape[0]
        train_seen += X.shape[0]

    # Epoch-average training loss, not just the final batch's loss.
    train_loss_avg = train_loss_sum.item() / max(train_seen, 1)

    # ---- validation pass ----
    vit.eval()
    val_loss_sum = 0.0
    val_correct = 0
    val_seen = 0
    # No autograd graph is needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for batch in val_dl:
            X = batch[0].to(device)
            y = batch[1].to(device)
            logits = vit(X)
            # Sum-reduce so the final average is exact even if the last batch is smaller.
            val_loss_sum += F.cross_entropy(logits, y, reduction="sum").item()
            # argmax of the logits equals argmax of the softmax probabilities.
            val_correct += (logits.argmax(-1) == y).sum().item()
            val_seen += X.shape[0]

    val_loss = val_loss_sum / max(val_seen, 1)
    val_accuracy = val_correct / max(val_seen, 1)

    print(f"epoch {e}/{epochs} | train_loss {train_loss_avg:.4f} | val_loss {val_loss:.4f} | val_acc {val_accuracy:.4f}")