In [27]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from tqdm import tqdm

from dataset import ToRGB, PatchCollate
from model import ViT

In [28]:
# ---------------------------------------------------------------------------
# Configuration: every tunable value used by the cells below lives here.
# ---------------------------------------------------------------------------

# Reproducibility: seed torch's global RNG before any weights are initialised
# or any data is shuffled.
SEED = 42
torch.manual_seed(SEED)

# Data
TRAIN_SIZE = 0.8              # fraction of the dataset used for training
TEST_SIZE = 1 - TRAIN_SIZE    # remaining fraction is held out for evaluation
BATCH_SIZE = 32
IMG_SIZE = 256                # images are resized to IMG_SIZE x IMG_SIZE
PATCH_SIZE = 16               # defined once (it was previously assigned twice)
IN_CHANNELS = 3

# The image must tile exactly into patches.
assert IMG_SIZE % PATCH_SIZE == 0, "IMG_SIZE must be a multiple of PATCH_SIZE"

# Model
D_MODEL = 768
NUM_HEADS = 8
NUM_LAYERS = 6
# NOTE(review): DROPOUT_P is not forwarded to the ViT constructor in this
# notebook -- confirm whether the model should receive it.
DROPOUT_P = 0.1
NUM_CLASSES = 257

# Optimisation
EPOCHS = 10
LR = 1e-4

# Device: prefer Apple MPS (the original target), then CUDA, then CPU, so the
# notebook still runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")

In [29]:
# Per-image preprocessing: resize to a fixed square, convert to a tensor, then
# apply the project-local ToRGB transform.
# NOTE(review): ToRGB comes from dataset.py; presumably it forces 3 channels
# for grayscale images -- confirm it expects a tensor (it runs after ToTensor).
T = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    ToRGB()
])

dataset = datasets.Caltech256("./dataset", download=True, transform=T)

# Split with a dedicated, seeded generator so the train/test membership is the
# same on every run. Without it, a checkpoint reloaded in a later session could
# be evaluated on images it was trained on, and accuracies are not comparable
# between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [TRAIN_SIZE, TEST_SIZE], generator=split_generator
)
assert len(train_dataset) + len(test_dataset) == len(dataset)

# Bidirectional lookup between category names and integer labels.
class_to_idx = {c: i for i, c in enumerate(dataset.categories)}
idx_to_class = {value: key for key, value in class_to_idx.items()}

# Project-local collate function that builds the batches fed to the model.
# NOTE(review): the two arguments are presumably patch height and width --
# verify against PatchCollate in dataset.py.
patch_collate_fn = PatchCollate(PATCH_SIZE, PATCH_SIZE)
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, collate_fn=patch_collate_fn)
# Evaluation needs no gradients, so a larger batch is used.
test_loader = DataLoader(test_dataset, BATCH_SIZE * 2, collate_fn=patch_collate_fn)

Files already downloaded and verified


In [30]:
# Instantiate the Vision Transformer from the configuration constants, then
# place it on the selected device.
model = ViT(
    in_channels=IN_CHANNELS,
    patch_size=PATCH_SIZE,
    img_size=IMG_SIZE,
    num_classes=NUM_CLASSES,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    n_layers=NUM_LAYERS,
)
model = model.to(DEV)

# Classification loss, and an Adam optimiser over all model parameters.
crit = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=LR)

In [32]:
# Each epoch runs one optimisation pass over train_loader, then measures
# accuracy on test_loader and overwrites the checkpoint "vit.pth".
for e in range(1, EPOCHS + 1):
    # ---- training pass -------------------------------------------------
    model.train()
    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=True, position=0)
    loop.set_description(f"Train Epoch : [{e}/{EPOCHS}]")
    total_loss = 0
    for i, (imgs, labels) in loop:
        imgs = imgs.to(DEV)
        labels = labels.to(DEV)

        # Standard step: clear gradients, forward, loss, backward, update.
        opt.zero_grad()
        yhat = model(imgs)
        loss = crit(yhat, labels)
        loss.backward()
        opt.step()

        # Show the running mean loss over the batches seen so far this epoch.
        total_loss += loss.item()
        loop.set_postfix(loss=total_loss / (i + 1))

    # ---- evaluation pass -----------------------------------------------
    model.eval()
    test_loop = tqdm(test_loader, total=len(test_loader), position=0, leave=True)
    test_loop.set_description(f"Test Epoch : [{e}/{EPOCHS}]")
    total_correct = 0
    with torch.no_grad():  # no gradients are needed while scoring
        for imgs, labels in test_loop:
            imgs = imgs.to(DEV)
            labels = labels.to(DEV)
            # Predicted class is the index of the largest output on the last axis.
            yhat = model(imgs).argmax(dim=-1)
            total_correct += (yhat == labels).sum().item()

    print(f"Test accuracy : {total_correct / len(test_dataset)}")
    torch.save(model.state_dict(), "vit.pth")

Train Epoch : [1/10]:   0%|          | 0/766 [00:05<?, ?it/s]
Test Epoch : [1/10]: 100%|██████████| 96/96 [01:18<00:00,  1.22it/s]


Test accuracy : 0.003920927952948864


Train Epoch : [2/10]:   0%|          | 0/766 [00:01<?, ?it/s]
Test Epoch : [2/10]:   5%|▌         | 5/96 [00:05<01:38,  1.09s/it]


KeyboardInterrupt: 