
### Trying out the Vision Transformer
Reference: https://recruit.gmo.jp/engineer/jisedai/blog/vision_transformer/

```python
import torch
import torchvision
import torchvision.transforms as transforms

batch_size = 50

# ToTensor() scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) on each
# RGB channel then maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR-10: 50,000 training and 10,000 test images, 32x32 RGB, 10 classes.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, num_workers=2)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

# Class names, indexed by label (0-9)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
```

```
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
```
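`transforms.Normalize` computes `(x - mean) / std` per channel. As a rough plain-Python sketch (no torchvision; the helper name `normalize` is hypothetical), with mean = std = 0.5 the [0, 1] range that `ToTensor()` produces from 8-bit pixels (value / 255) lands on [-1, 1]:

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel normalization as done by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

# ToTensor() divides 8-bit pixel values by 255, giving values in [0, 1].
for pixel in (0, 255):
    scaled = pixel / 255
    print(pixel, '->', scaled, '->', normalize(scaled))
# 0 -> 0.0 -> -1.0
# 255 -> 1.0 -> 1.0
```

Centering inputs around zero like this is a common convention; the exact mean/std values are a choice made in this notebook, not something ViT requires.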



```python
# The ViT implementation used here comes from the vit-pytorch package (it also installs einops).
%pip install vit-pytorch
```

```
Collecting vit-pytorch
  Downloading https://files.pythonhosted.org/packages/31/40/56919a1be6b596f30a692d4855f2d7bde8945e95cd4eb1c6588c109d4581/vit_pytorch-0.6.7-py3-none-any.whl
Collecting einops>=0.3
  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl
Installing collected packages: einops, vit-pytorch
Successfully installed einops-0.3.0 vit-pytorch-0.6.7
```


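```python
import torch
from vit_pytorch import ViT

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# A small ViT sized for 32x32 CIFAR-10 images.
net = ViT(
    image_size=32,    # input height/width in pixels
    patch_size=4,     # split each image into non-overlapping 4x4 patches
    num_classes=10,   # CIFAR-10 classes
    dim=256,          # embedding size of each token
    depth=3,          # number of Transformer encoder blocks
    heads=4,          # attention heads per block
    mlp_dim=256,      # hidden size of the feed-forward layer in each block
    dropout=0.1,      # dropout inside the Transformer blocks
    emb_dropout=0.1   # dropout on the embedded token sequence
).to(device)
```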
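For the configuration above, the token sizes can be worked out by hand. This is a back-of-the-envelope sketch in plain Python (the function `vit_token_shape` is hypothetical, not part of vit-pytorch); it assumes non-overlapping square patches, RGB input, and one extra class token prepended to the patch sequence, as in the original ViT design.

```python
def vit_token_shape(image_size, patch_size, channels=3, use_cls_token=True):
    """Return (number of patches, raw values per flattened patch, sequence length)."""
    if image_size % patch_size != 0:
        raise ValueError('image_size must be divisible by patch_size')
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # Each patch is flattened to patch_size * patch_size * channels numbers
    # before being linearly projected to `dim`.
    patch_dim = patch_size * patch_size * channels
    seq_len = num_patches + (1 if use_cls_token else 0)
    return num_patches, patch_dim, seq_len

print(vit_token_shape(image_size=32, patch_size=4))  # (64, 48, 65)
```

So each 32x32 image becomes 64 patches of 48 values, each projected to a `dim=256` embedding. Because self-attention cost grows with the square of the sequence length, halving `patch_size` to 2 would quadruple the patch count to 256 and make each layer noticeably more expensive.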

```python
import torch
import torch.optim as optim
from torch import nn

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
epochs = 20

for epoch in range(epochs):
    epoch_train_loss = 0.0
    epoch_train_acc = 0.0
    epoch_test_loss = 0.0
    epoch_test_acc = 0.0

    # Training pass (dropout active)
    net.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Running mean of per-batch loss/accuracy; .item() turns the
        # 0-dim tensors into Python floats
        epoch_train_loss += loss.item() / len(train_loader)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        epoch_train_acc += acc.item() / len(train_loader)

    # Evaluation pass (dropout off, no gradient tracking)
    net.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            epoch_test_loss += loss.item() / len(test_loader)
            test_acc = (outputs.argmax(dim=1) == labels).float().mean()
            epoch_test_acc += test_acc.item() / len(test_loader)

    print(f'Epoch {epoch+1} : train acc. {epoch_train_acc:.2f} train loss {epoch_train_loss:.2f}')
    print(f'Epoch {epoch+1} : test acc. {epoch_test_acc:.2f} test loss {epoch_test_loss:.2f}')
```

```
Epoch 1 : train acc. 0.17 train loss 2.21
Epoch 1 : test acc. 0.22 test loss 2.09
Epoch 2 : train acc. 0.22 train loss 2.07
Epoch 2 : test acc. 0.28 test loss 1.96
Epoch 3 : train acc. 0.29 train loss 1.94
Epoch 3 : test acc. 0.33 test loss 1.88
Epoch 4 : train acc. 0.32 train loss 1.88
Epoch 4 : test acc. 0.34 test loss 1.83
Epoch 5 : train acc. 0.34 train loss 1.82
Epoch 5 : test acc. 0.36 test loss 1.77
Epoch 6 : train acc. 0.36 train loss 1.78
Epoch 6 : test acc. 0.38 test loss 1.73
Epoch 7 : train acc. 0.37 train loss 1.75
Epoch 7 : test acc. 0.39 test loss 1.70
Epoch 8 : train acc. 0.39 train loss 1.71
Epoch 8 : test acc. 0.41 test loss 1.66
Epoch 9 : train acc. 0.40 train loss 1.67
Epoch 9 : test acc. 0.42 test loss 1.63
Epoch 10 : train acc. 0.41 train loss 1.64
Epoch 10 : test acc. 0.44 test loss 1.59
Epoch 11 : train acc. 0.42 train loss 1.61
Epoch 11 : test acc. 0.44 test loss 1.56
Epoch 12 : train acc. 0.43 train loss 1.58
Epoch 12 : test acc. 0.46 test loss 1.53
Epoch 13 :
```
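The training loop reports each epoch metric as a mean of per-batch means (each batch's value divided by `len(loader)`). That equals the true per-sample average only when every batch has the same size; with `batch_size = 50` it holds here, since 50,000 / 50 = 1,000 and 10,000 / 50 = 200 leave no partial batch. A minimal plain-Python sketch of the difference, using made-up correct counts (illustrative numbers, not results from this notebook) and hypothetical helper names:

```python
def mean_of_batch_means(correct, sizes):
    """What the training loop computes: the average of per-batch accuracies."""
    return sum(c / n for c, n in zip(correct, sizes)) / len(sizes)

def sample_weighted_accuracy(correct, sizes):
    """True accuracy over all samples: total correct / total samples."""
    return sum(correct) / sum(sizes)

# Both CIFAR-10 splits split into full batches of 50.
print(divmod(50_000, 50), divmod(10_000, 50))  # (1000, 0) (200, 0)

# Equal-sized batches: both methods agree.
print(mean_of_batch_means([20, 25, 30], [50, 50, 50]))       # 0.5
print(sample_weighted_accuracy([20, 25, 30], [50, 50, 50]))  # 0.5

# A smaller final batch: the unweighted mean over-weights that last batch.
print(mean_of_batch_means([20, 25, 10], [50, 50, 10]))       # ≈ 0.633
print(sample_weighted_accuracy([20, 25, 10], [50, 50, 10]))  # 0.5
```

A related detail: the training numbers are accumulated with dropout active and while the weights are still changing during the epoch, which is one plausible reason the logged test accuracy sits slightly above the training accuracy in these early epochs.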

```
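This isn't meant as a big reveal after coming this far, but training ViT actually requires some care.
ViT is first pre-trained on JFT-300M, a dataset of 300 million images (which does not appear to be public), and is then fine-tuned for each downstream task.
According to the paper, pre-training on a very large dataset is what allows it to perform well on the individual tasks.
Pre-trained weights have been released (https://github.com/google-research/vision_transformer), so the better approach is probably to fine-tune from them for a specific task (the weights pre-trained on JFT-300M themselves do not appear to have been released).
Searching the web for ViT turns up articles reporting poor accuracy, and this is likely the reason.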
```