In [10]:
import torch
from torch import nn
from VisionTransformers.ViT import VisionTransformer
from VisionTransformers.utils import ViTDsetPytorch, train_model, eval_model
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [None]:
# Per-channel mean and standard deviation handed to transforms.Normalize.
# The same values are applied to both the training and the test split.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.247, 0.243, 0.261)


def _make_cifar_transform():
    """Build a fresh ToTensor -> Normalize pipeline for one CIFAR-10 split."""
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
    return transforms.Compose(steps)


# Training split: fetched into the current directory if it is not there yet.
cifar_Dset = CIFAR10(
    '.',
    train=True,
    download=True,
    transform=_make_cifar_transform(),
)

# Test split: download is disabled here, so this relies on the files already
# being present under '.' (e.g. from the training-split call just above).
cifar_testDset = CIFAR10(
    '.',
    train=False,
    download=False,
    transform=_make_cifar_transform(),
)

In [11]:
# --- Training schedule / batching -------------------------------------------
epochs = 100
batch_size = 1024

# --- Patch geometry ----------------------------------------------------------
img_size = 72
patch_size = 16

# Floor division: 72 // 16 == 4 patches per side, i.e. 16 patches in total.
# NOTE(review): img_size is not a multiple of patch_size, so 4 * 16 = 64 pixels
# per side are covered by whole patches. How the remaining 8 pixels are handled
# depends on ViTDsetPytorch, which is not visible here -- confirm.
_patches_per_side = img_size // patch_size
num_patches = _patches_per_side * _patches_per_side

# Length of one flattened patch: 3 values per pixel (matching the three-channel
# normalisation used for the dataset) times patch_size x patch_size pixels.
num_features = 3 * patch_size ** 2

# Wrap the raw CIFAR-10 splits in the project's ViT dataset adapter.
# NOTE(review): presumably this resizes each image to img_size and cuts it into
# patch_size x patch_size patches -- confirm in VisionTransformers.utils.
Dset = ViTDsetPytorch(cifar_Dset, patch_size=patch_size, img_size=img_size)
test_Dset = ViTDsetPytorch(cifar_testDset, patch_size=patch_size, img_size=img_size)

# Training data is reshuffled so that batches differ from epoch to epoch.
Dloader = DataLoader(Dset, batch_size=batch_size, shuffle=True, num_workers=2)

# Evaluation gains nothing from shuffling; iterating in a fixed order keeps the
# composition of every test batch (including the final, smaller one) identical
# between runs, so reported test metrics are reproducible.
test_Dloader = DataLoader(test_Dset, batch_size=batch_size, shuffle=False, num_workers=2)

In [7]:
# Learning rate handed to the Adam optimizer below.
lr = 0.001

# Build the transformer from the patch geometry defined in the config cell,
# then move it to the selected device.
ViTModel = VisionTransformer(
    in_features=num_features,
    num_patches=num_patches,
    num_classes=10,
    proj_dim=64,
    num_heads=4,
    num_layers=8,
    dropout=0.5,
)
ViTModel = ViTModel.to(device)

# Classification loss and the optimizer over all model parameters.
crit = nn.CrossEntropyLoss()
optim = torch.optim.Adam(ViTModel.parameters(), lr=lr)

In [8]:
# Run one training pass per epoch and evaluate on the test loader periodically.
_eval_every = 5

for epoch in range(epochs):
    train_model(ViTModel, Dloader, crit, optim, epoch, epochs, device)

    # `epoch` is zero-based, so this is true after the 5th, 10th, 15th, ...
    # completed epoch.
    finished_epochs = epoch + 1
    if finished_epochs % _eval_every == 0:
        eval_model(ViTModel, test_Dloader, crit, device)

[epoch: 1/100] loss: 2.0216 acc:0.23832
[epoch: 2/100] loss: 1.7870 acc:0.33046
[epoch: 3/100] loss: 1.6183 acc:0.39896
[epoch: 4/100] loss: 1.5037 acc:0.44714
[epoch: 5/100] loss: 1.4212 acc:0.4809
------------------------------------------------------------------------------------------
[Test eval] loss: 1.3635 acc: 0.50450000
------------------------------------------------------------------------------------------
[epoch: 6/100] loss: 1.3581 acc:0.51024
[epoch: 7/100] loss: 1.3156 acc:0.52272
[epoch: 8/100] loss: 1.2739 acc:0.54072
[epoch: 9/100] loss: 1.2404 acc:0.55054
[epoch: 10/100] loss: 1.2044 acc:0.56626
------------------------------------------------------------------------------------------
[Test eval] loss: 1.2146 acc: 0.56290000
------------------------------------------------------------------------------------------
[epoch: 11/100] loss: 1.1727 acc:0.57604
[epoch: 12/100] loss: 1.1443 acc:0.58602
[epoch: 13/100] loss: 1.1202 acc:0.59742
[epoch: 14/100] loss: 1.1003 ac