In [None]:
!pip install vit_pytorch



In [None]:
import torch
from vit_pytorch import ViT
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import pandas as pd

In [None]:
# MNIST: 28x28 handwritten-digit images, 10 classes.
img_size = 28
batch_size = 256

# Preprocessing applied to every image, in order:
#   1. replicate the single gray channel into 3 identical channels
#      (presumably to match the channel count the model expects -- TODO confirm),
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. normalise each channel with mean 0.5 / std 0.5, i.e. map [0, 1] -> [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# Both splits are downloaded into ./data on first use and share the same transform.
mnist_train = MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = MNIST(root='./data', train=False, download=True, transform=transform)

# Training batches are reshuffled every epoch: with shuffle=False each epoch
# would present exactly the same batches in exactly the same order to SGD.
# Evaluation order does not affect the metrics, so the test loader stays ordered.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)


In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# Small Vision Transformer classifier for the 10 digit classes.
net = ViT(
  image_size=img_size,   # side length of the (square) input image
  patch_size=4,          # the image is cut into 4x4 patches
  num_classes=10,
  dim=256,
  depth=3,
  heads=4,
  mlp_dim=256,
  dropout=0.1,
  emb_dropout=0.1,
)
# Move the model's parameters to the selected device.
net = net.to(device)

In [None]:
# Classification loss computed from the raw class scores returned by the network.
criterion = nn.CrossEntropyLoss()

# Plain SGD with momentum over all parameters of `net`.
optimizer = optim.SGD(params=net.parameters(), lr=1e-3, momentum=0.9)

# Empty containers for loss histories (one for training, one for test).
record_loss_train, record_loss_test = [], []

In [None]:

# Train for a fixed number of epochs; after each epoch evaluate on the test set
# and print loss/accuracy for both splits.
#
# Metrics are accumulated as sample-weighted sums and divided by the number of
# samples seen. Averaging the per-batch means instead (sum of mean / number of
# batches) gives the last, smaller batch the same weight as a full one whenever
# the dataset size is not a multiple of the batch size.
#
# Note: the training metrics come from the forward passes done while training
# (net.train() mode, weights still changing during the epoch), whereas the test
# metrics are measured in net.eval() mode after the epoch. The two are therefore
# not strictly comparable.
epochs = 20
for epoch in range(epochs):
  train_loss_sum, train_correct, train_seen = 0.0, 0, 0
  test_loss_sum, test_correct, test_seen = 0.0, 0, 0

  # ---- training pass ----
  net.train()
  for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()

    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    n_batch = labels.size(0)
    # criterion returns the mean loss over the batch; scale back to a sum.
    train_loss_sum += loss.item() * n_batch
    # .item() turns the count into a Python int so no device tensors are
    # carried across iterations.
    train_correct += (outputs.argmax(dim=1) == labels).sum().item()
    train_seen += n_batch

  # ---- evaluation pass (no gradients needed) ----
  net.eval()
  with torch.no_grad():
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = net(inputs)
      loss = criterion(outputs, labels)

      n_batch = labels.size(0)
      test_loss_sum += loss.item() * n_batch
      test_correct += (outputs.argmax(dim=1) == labels).sum().item()
      test_seen += n_batch

  # Per-sample averages as plain Python floats; max(..., 1) keeps an empty
  # loader from raising ZeroDivisionError (metrics are then reported as 0).
  epoch_train_loss = train_loss_sum / max(train_seen, 1)
  epoch_train_acc = train_correct / max(train_seen, 1)
  epoch_test_loss = test_loss_sum / max(test_seen, 1)
  epoch_test_acc = test_correct / max(test_seen, 1)

  print(f'Epoch {epoch+1} : train acc. {epoch_train_acc:.2f} train loss {epoch_train_loss:.2f}')
  print(f'Epoch {epoch+1} : test acc. {epoch_test_acc:.2f} test loss {epoch_test_loss:.2f}')

Epoch 1 : train acc. 0.16 train loss 2.28
Epoch 1 : test acc. 0.17 test loss 2.23
Epoch 2 : train acc. 0.20 train loss 2.21
Epoch 2 : test acc. 0.24 test loss 2.11
Epoch 3 : train acc. 0.25 train loss 2.05
Epoch 3 : test acc. 0.34 test loss 1.87
Epoch 4 : train acc. 0.35 train loss 1.82
Epoch 4 : test acc. 0.52 test loss 1.52
Epoch 5 : train acc. 0.55 train loss 1.38
Epoch 5 : test acc. 0.69 test loss 0.91
Epoch 6 : train acc. 0.69 train loss 0.94
Epoch 6 : test acc. 0.79 test loss 0.65
Epoch 7 : train acc. 0.77 train loss 0.71
Epoch 7 : test acc. 0.84 test loss 0.51
Epoch 8 : train acc. 0.82 train loss 0.56
Epoch 8 : test acc. 0.88 test loss 0.41
Epoch 9 : train acc. 0.85 train loss 0.47
Epoch 9 : test acc. 0.90 test loss 0.35
Epoch 10 : train acc. 0.87 train loss 0.42
Epoch 10 : test acc. 0.91 test loss 0.31
Epoch 11 : train acc. 0.88 train loss 0.38
Epoch 11 : test acc. 0.92 test loss 0.28
Epoch 12 : train acc. 0.89 train loss 0.35
Epoch 12 : test acc. 0.93 test loss 0.26
Epoch 13 :

In [None]:
for i in range(5):
  net.train()
  loss_train = 0
  for j, (x, t) in enumerate(train_loader):
    x, t = x.cuda(), t.cuda()
    print("Input shape:", x.shape)
    y = net(x)