In [1]:
import torchvision
import torchvision.transforms as transforms
import torch

# Seed torch's global RNG so weight init, dropout masks and the DataLoader's
# shuffle order (which draws its seed from this RNG) are reproducible under
# Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

DATA_ROOT = './data'
batch_size = 100

# ToTensor scales image pixels to [0, 1]; Normalize with mean = std = 0.5 per
# channel then maps each channel to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Training split: shuffled every epoch.
train_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,
                                         download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, num_workers=2)

# Test split: fixed order so evaluation is deterministic.
test_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False,
                                        download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

# Human-readable class names, intended to line up with the dataset's integer
# labels (label i -> classes[i]); compare with train_set.classes if in doubt.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
from vit_pytorch import ViT

# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Small Vision Transformer for 32x32 inputs: 4x4 patches yield
# (32 // 4) ** 2 = 64 patch tokens per image, classified into 10 classes.
net = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=256,         # token embedding width
    depth=3,         # number of transformer blocks
    heads=4,         # attention heads per block
    mlp_dim=256,     # hidden width of each block's feed-forward layer
    dropout=0.1,
    emb_dropout=0.1,
)
# nn.Module.to moves parameters in place and returns the same module.
net = net.to(device)

In [8]:
import torch
import torch.optim as optim
import torch.nn as nn


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over `loader` and return (mean loss, mean accuracy).

    With an `optimizer` the model is switched to train mode and updated after
    every batch. Without one it is switched to eval mode and run with
    gradients disabled.

    Both metrics are averaged per sample (each batch weighted by its size), so
    they stay exact even when the final batch is smaller than the others.

    Args:
        model: network mapping a batch of inputs to class logits.
        loader: iterable of (inputs, labels) batches.
        criterion: loss returning the batch *mean* (e.g. nn.CrossEntropyLoss()
            with its default reduction); it is re-weighted by batch size here.
        device: device each batch is moved to before the forward pass.
        optimizer: optimizer to step, or None to evaluate only.

    Returns:
        (mean_loss, mean_accuracy) as Python floats; (0.0, 0.0) if `loader`
        yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_n = labels.size(0)
            # Undo the criterion's per-batch mean so batches are size-weighted.
            total_loss += loss.item() * batch_n
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_n

    if total_seen == 0:
        return 0.0, 0.0
    return total_loss / total_seen, total_correct / total_seen


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

epochs = 20
for epoch in range(epochs):
    epoch_train_loss, epoch_train_acc = run_epoch(
        net, train_loader, criterion, device, optimizer=optimizer)
    epoch_test_loss, epoch_test_acc = run_epoch(
        net, test_loader, criterion, device)

    print(f'Epoch {epoch+1} : train acc. {epoch_train_acc:.2f} train loss {epoch_train_loss:.2f}')
    print(f'Epoch {epoch+1} : test acc. {epoch_test_acc:.2f} test loss {epoch_test_loss:.2f}')

Epoch 1 : train acc. 0.20 train loss 2.12
Epoch 1 : test acc. 0.24 test loss 2.05
Epoch 2 : train acc. 0.23 train loss 2.03
Epoch 2 : test acc. 0.25 test loss 1.99
Epoch 3 : train acc. 0.26 train loss 1.97
Epoch 3 : test acc. 0.29 test loss 1.92
Epoch 4 : train acc. 0.28 train loss 1.90
Epoch 4 : test acc. 0.32 test loss 1.84
Epoch 5 : train acc. 0.31 train loss 1.84
Epoch 5 : test acc. 0.34 test loss 1.82
Epoch 6 : train acc. 0.33 train loss 1.79
Epoch 6 : test acc. 0.35 test loss 1.75
Epoch 7 : train acc. 0.35 train loss 1.74
Epoch 7 : test acc. 0.38 test loss 1.70
Epoch 8 : train acc. 0.38 train loss 1.69
Epoch 8 : test acc. 0.39 test loss 1.67
Epoch 9 : train acc. 0.40 train loss 1.64
Epoch 9 : test acc. 0.41 test loss 1.63
Epoch 10 : train acc. 0.42 train loss 1.59
Epoch 10 : test acc. 0.43 test loss 1.60
Epoch 11 : train acc. 0.43 train loss 1.55
Epoch 11 : test acc. 0.45 test loss 1.54
Epoch 12 : train acc. 0.45 train loss 1.51
Epoch 12 : test acc. 0.46 test loss 1.53
Epoch 13 :