In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file,
# so the layout of the attached datasets is visible at a glance.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in files):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# Pin the versions this notebook was run with (vit_pytorch 1.2.0 pulled in einops 0.6.0).
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -q vit_pytorch==1.2.0 einops==0.6.0

Collecting vit_pytorch


  Downloading vit_pytorch-1.2.0-py3-none-any.whl (87 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/87.2 kB ? eta -:--:--
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 87.2/87.2 kB 3.7 MB/s eta 0:00:00


Collecting einops>=0.6.0
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/41.6 kB ? eta -:--:--
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 41.6/41.6 kB 4.3 MB/s eta 0:00:00




Installing collected packages: einops, vit_pytorch


Successfully installed einops-0.6.0 vit_pytorch-1.2.0
[0m

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from vit_pytorch import ViT

# Set device to use
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define data transformations.
# Normalize computes (x - 0.5) / 0.5 per channel, mapping ToTensor's [0, 1]
# range to [-1, 1].
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

# No Resize step: CIFAR-10 images are already 32x32 (RandomCrop(32) keeps that
# size), which matches the image_size=32 the ViT is built with. A Resize(32)
# placed after Normalize would only ever be a no-op on the normalized tensor.

# Training pipeline: light augmentation (random 32x32 crop from the image padded
# by 4 pixels, random horizontal flip), then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: deterministic, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Load the CIFAR-10 train/test splits. download=True fetches the archive into
# DATA_ROOT only when it is not already present.
DATA_ROOT = './data'
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 100
NUM_WORKERS = 2

trainset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True,
                            transform=transform_train)
testset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True,
                           transform=transform_test)

# Shuffle only the training batches; evaluation order stays fixed.
trainloader = DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                         num_workers=NUM_WORKERS)
testloader = DataLoader(testset, batch_size=TEST_BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS)


# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define model (hyperparameters unchanged) and move it to the target device
# BEFORE constructing the optimizer, so the optimizer holds the device parameters.
model = ViT(
    image_size = 32,
    patch_size = 4,
    num_classes = 10,
    dim = 64,
    depth = 6,
    heads = 8,
    mlp_dim = 128,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)

# define optimizer and loss function (constant learning rate, no scheduler)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: network to run; put in train mode when ``optimizer`` is given,
            eval mode otherwise.
        loader: iterable of ``(images, labels)`` batches exposing ``.dataset``.
        criterion: loss function applied to ``(outputs, labels)``.
        device: device each batch is moved to before the forward pass.
        optimizer: when provided, gradients are computed and the optimizer is
            stepped after every batch; when ``None`` gradients are disabled.

    Returns:
        Per-sample mean loss (each batch weighted by its size) and accuracy,
        both normalized by ``len(loader.dataset)``.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()
    total_loss = 0.0
    total_correct = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # loss is a batch mean; re-weight by batch size for a per-sample average.
            total_loss += loss.item() * images.size(0)
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()

    num_samples = len(loader.dataset)
    return total_loss / num_samples, total_correct / num_samples


# Train the model, evaluating on the test split after every epoch.
num_epochs = 200
for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(model, trainloader, criterion, device, optimizer)
    test_loss, test_acc = _run_epoch(model, testloader, criterion, device)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

# Save the trained weights (state_dict only). Tensors are saved on `device`;
# when reloading on a CPU-only machine pass map_location to torch.load.
torch.save(model.state_dict(), 'VIT_cifar10.pth')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


Files already downloaded and verified


Epoch [1/200], Train Loss: 1.9365, Train Acc: 0.2649, Test Loss: 1.6988, Test Acc: 0.3578


Epoch [2/200], Train Loss: 1.6570, Train Acc: 0.3889, Test Loss: 1.4528, Test Acc: 0.4726


Epoch [3/200], Train Loss: 1.4921, Train Acc: 0.4515, Test Loss: 1.3813, Test Acc: 0.4967


Epoch [4/200], Train Loss: 1.4101, Train Acc: 0.4834, Test Loss: 1.3017, Test Acc: 0.5202


Epoch [5/200], Train Loss: 1.3583, Train Acc: 0.5089, Test Loss: 1.2349, Test Acc: 0.5513


Epoch [6/200], Train Loss: 1.3234, Train Acc: 0.5193, Test Loss: 1.1650, Test Acc: 0.5740


Epoch [7/200], Train Loss: 1.2857, Train Acc: 0.5362, Test Loss: 1.1764, Test Acc: 0.5738


Epoch [8/200], Train Loss: 1.2638, Train Acc: 0.5427, Test Loss: 1.1803, Test Acc: 0.5677


Epoch [9/200], Train Loss: 1.2356, Train Acc: 0.5547, Test Loss: 1.1352, Test Acc: 0.5891


Epoch [10/200], Train Loss: 1.2100, Train Acc: 0.5655, Test Loss: 1.1417, Test Acc: 0.5997


Epoch [11/200], Train Loss: 1.1862, Train Acc: 0.5709, Test Loss: 1.0885, Test Acc: 0.6172


Epoch [12/200], Train Loss: 1.1686, Train Acc: 0.5810, Test Loss: 1.0393, Test Acc: 0.6299


Epoch [13/200], Train Loss: 1.1481, Train Acc: 0.5881, Test Loss: 1.0676, Test Acc: 0.6185


Epoch [14/200], Train Loss: 1.1273, Train Acc: 0.5929, Test Loss: 1.0390, Test Acc: 0.6280


Epoch [15/200], Train Loss: 1.1217, Train Acc: 0.5991, Test Loss: 1.0001, Test Acc: 0.6415


Epoch [16/200], Train Loss: 1.0874, Train Acc: 0.6100, Test Loss: 1.0654, Test Acc: 0.6290


Epoch [17/200], Train Loss: 1.0824, Train Acc: 0.6113, Test Loss: 1.0079, Test Acc: 0.6439


Epoch [18/200], Train Loss: 1.0603, Train Acc: 0.6204, Test Loss: 1.0175, Test Acc: 0.6368


Epoch [19/200], Train Loss: 1.0481, Train Acc: 0.6253, Test Loss: 0.9902, Test Acc: 0.6516


Epoch [20/200], Train Loss: 1.0367, Train Acc: 0.6296, Test Loss: 0.9225, Test Acc: 0.6772


Epoch [21/200], Train Loss: 1.0230, Train Acc: 0.6353, Test Loss: 0.9456, Test Acc: 0.6686


Epoch [22/200], Train Loss: 1.0097, Train Acc: 0.6390, Test Loss: 0.9050, Test Acc: 0.6837


Epoch [23/200], Train Loss: 1.0014, Train Acc: 0.6430, Test Loss: 0.8828, Test Acc: 0.6869


Epoch [24/200], Train Loss: 0.9871, Train Acc: 0.6455, Test Loss: 0.9269, Test Acc: 0.6748


Epoch [25/200], Train Loss: 0.9762, Train Acc: 0.6529, Test Loss: 0.8863, Test Acc: 0.6923


Epoch [26/200], Train Loss: 0.9610, Train Acc: 0.6573, Test Loss: 0.9052, Test Acc: 0.6802


Epoch [27/200], Train Loss: 0.9579, Train Acc: 0.6599, Test Loss: 0.8357, Test Acc: 0.7070


Epoch [28/200], Train Loss: 0.9405, Train Acc: 0.6638, Test Loss: 0.8667, Test Acc: 0.6891


Epoch [29/200], Train Loss: 0.9359, Train Acc: 0.6673, Test Loss: 0.8651, Test Acc: 0.7026


Epoch [30/200], Train Loss: 0.9238, Train Acc: 0.6731, Test Loss: 0.8983, Test Acc: 0.6904


Epoch [31/200], Train Loss: 0.9235, Train Acc: 0.6728, Test Loss: 0.8245, Test Acc: 0.7064


Epoch [32/200], Train Loss: 0.9149, Train Acc: 0.6760, Test Loss: 0.8222, Test Acc: 0.7084


Epoch [33/200], Train Loss: 0.9051, Train Acc: 0.6769, Test Loss: 0.8842, Test Acc: 0.6951


Epoch [34/200], Train Loss: 0.9034, Train Acc: 0.6820, Test Loss: 0.7913, Test Acc: 0.7265


Epoch [35/200], Train Loss: 0.8828, Train Acc: 0.6880, Test Loss: 0.8260, Test Acc: 0.7138


Epoch [36/200], Train Loss: 0.8784, Train Acc: 0.6901, Test Loss: 0.7934, Test Acc: 0.7202


Epoch [37/200], Train Loss: 0.8777, Train Acc: 0.6882, Test Loss: 0.8119, Test Acc: 0.7206


Epoch [38/200], Train Loss: 0.8640, Train Acc: 0.6962, Test Loss: 0.7948, Test Acc: 0.7233


Epoch [39/200], Train Loss: 0.8659, Train Acc: 0.6928, Test Loss: 0.8511, Test Acc: 0.7103


Epoch [40/200], Train Loss: 0.8633, Train Acc: 0.6946, Test Loss: 0.7991, Test Acc: 0.7198


Epoch [41/200], Train Loss: 0.8466, Train Acc: 0.7018, Test Loss: 0.8714, Test Acc: 0.7020


Epoch [42/200], Train Loss: 0.8510, Train Acc: 0.6978, Test Loss: 0.8221, Test Acc: 0.7221


Epoch [43/200], Train Loss: 0.8373, Train Acc: 0.7039, Test Loss: 0.7581, Test Acc: 0.7323


Epoch [44/200], Train Loss: 0.8401, Train Acc: 0.7050, Test Loss: 0.7612, Test Acc: 0.7334


Epoch [45/200], Train Loss: 0.8392, Train Acc: 0.7033, Test Loss: 0.7890, Test Acc: 0.7267


Epoch [46/200], Train Loss: 0.8306, Train Acc: 0.7056, Test Loss: 0.7535, Test Acc: 0.7343


Epoch [47/200], Train Loss: 0.8292, Train Acc: 0.7095, Test Loss: 0.7808, Test Acc: 0.7305


Epoch [48/200], Train Loss: 0.8232, Train Acc: 0.7086, Test Loss: 0.8103, Test Acc: 0.7185


Epoch [49/200], Train Loss: 0.8128, Train Acc: 0.7120, Test Loss: 0.7517, Test Acc: 0.7387


Epoch [50/200], Train Loss: 0.8125, Train Acc: 0.7125, Test Loss: 0.7578, Test Acc: 0.7398


Epoch [51/200], Train Loss: 0.8126, Train Acc: 0.7128, Test Loss: 0.7194, Test Acc: 0.7524


Epoch [52/200], Train Loss: 0.8012, Train Acc: 0.7171, Test Loss: 0.7900, Test Acc: 0.7271


Epoch [53/200], Train Loss: 0.8059, Train Acc: 0.7143, Test Loss: 0.7709, Test Acc: 0.7311


Epoch [54/200], Train Loss: 0.7939, Train Acc: 0.7176, Test Loss: 0.7536, Test Acc: 0.7321


Epoch [55/200], Train Loss: 0.7982, Train Acc: 0.7186, Test Loss: 0.7441, Test Acc: 0.7416


Epoch [56/200], Train Loss: 0.7954, Train Acc: 0.7185, Test Loss: 0.7965, Test Acc: 0.7246


Epoch [57/200], Train Loss: 0.7891, Train Acc: 0.7215, Test Loss: 0.7537, Test Acc: 0.7433


Epoch [58/200], Train Loss: 0.7826, Train Acc: 0.7242, Test Loss: 0.7662, Test Acc: 0.7373


Epoch [59/200], Train Loss: 0.7887, Train Acc: 0.7183, Test Loss: 0.7349, Test Acc: 0.7466


Epoch [60/200], Train Loss: 0.7852, Train Acc: 0.7215, Test Loss: 0.7352, Test Acc: 0.7458


Epoch [61/200], Train Loss: 0.7830, Train Acc: 0.7222, Test Loss: 0.6904, Test Acc: 0.7585


Epoch [62/200], Train Loss: 0.7793, Train Acc: 0.7252, Test Loss: 0.7856, Test Acc: 0.7374


Epoch [63/200], Train Loss: 0.7769, Train Acc: 0.7244, Test Loss: 0.7127, Test Acc: 0.7569


Epoch [64/200], Train Loss: 0.7694, Train Acc: 0.7293, Test Loss: 0.7640, Test Acc: 0.7361


Epoch [65/200], Train Loss: 0.7696, Train Acc: 0.7272, Test Loss: 0.6917, Test Acc: 0.7628


Epoch [66/200], Train Loss: 0.7672, Train Acc: 0.7283, Test Loss: 0.6955, Test Acc: 0.7598


Epoch [67/200], Train Loss: 0.7620, Train Acc: 0.7316, Test Loss: 0.7167, Test Acc: 0.7519


Epoch [68/200], Train Loss: 0.7598, Train Acc: 0.7303, Test Loss: 0.7270, Test Acc: 0.7488
