<a href="https://colab.research.google.com/github/jeonggunlee/Vision-Transformer-Study/blob/main/VisionTransformer_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer with the Simple MNIST Dataset!

최근 주목 받고 있는 Vision Transformer를 효과적으로 이해하기 위해서 Simple 이미지 데이터셋인 MNIST를 이용하여 Vision Transformer의 동작을 살펴봄.

2021-02-04

In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

# MNIST data: training set / test set, normalized before being fed to the model.

torch.manual_seed(42)

# Directory the MNIST files are downloaded to. Override with the MNIST_DATA_DIR
# environment variable when /data is not writable (e.g. outside Colab).
DOWNLOAD_PATH = os.environ.get('MNIST_DATA_DIR', '/data/mnist')

# Batch size used during training
BATCH_SIZE_TRAIN = 100
# Batch size used during evaluation
BATCH_SIZE_TEST = 1000

# (mean,) and (std,) passed to Normalize for the single grayscale channel.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# MNIST is a dataset of 28x28 handwritten digit images.
transform_mnist = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

train_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True, transform=transform_mnist)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True)

# Evaluation aggregates over the whole test set, so a random order adds nothing;
# a fixed order keeps test passes deterministic.
test_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=False, download=True, transform=transform_mnist)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE_TEST, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw
Processing...
Done!





  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
# [einops] stands for Einstein-Inspired Notation for operations
# 텐서 연산을 보다 효과적으로 구성하기 위한 패키지
!pip install einops

Collecting einops
  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl
Installing collected packages: einops
Successfully installed einops-0.3.0


## Vision Transformer: Overall Model View

The Vision Transformer does not use any CNN-style convolution filters to detect features in images. Instead, it relies solely on the *self-attention* mechanism, with queries, keys, and values derived from the sequence of input image patches.


This is the most interesting aspect of the Vision Transformer!

![Vit](https://github.com/jeonggunlee/Vision-Transformer-Study/blob/main/image1.gif?raw=1)

Ref: https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html

In [None]:
import torch
import torch.nn.functional as F

from torch import nn
from einops import rearrange  # 이미지 구조 변경에 매우 용이한 유틸리티

# Vision Transformer (ViT)의 오리지널 소스 코드
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py

class Residual(nn.Module):
    """Skip connection around a sub-layer: returns ``fn(x, **kwargs) + x``.

    Extra keyword arguments (e.g. an attention mask) are forwarded to ``fn`` unchanged.
    """

    def __init__(self, fn):
        super().__init__()
        # Stored as an attribute so that, when fn is an nn.Module, it is registered
        # as a child module and its parameters are trained with the wrapper.
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x

class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension (of size ``dim``), then call ``fn``.

    This is the "pre-norm" arrangement: the wrapped layer always sees normalized input.
    Extra keyword arguments are forwarded to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden_dim) -> GELU -> Linear(hidden_dim -> dim).

    No activation follows the final Linear; its output is returned as-is, with the same
    last-dimension size ``dim`` as the input.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),  # output projection, intentionally no activation
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)




## Multi-head Attention Module

With queries, keys, and values!

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: embedding size of each token; must be divisible by ``heads``.
        heads: number of attention heads; each head works on ``dim // heads`` features.

    ``forward`` maps (batch, n, dim) token embeddings to a tensor of the same shape.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        assert dim % heads == 0, 'dim must be divisible by the number of heads'
        self.heads = heads  # number of heads
        # Scaled dot-product attention divides the logits by sqrt(d_k), where d_k is the
        # per-head feature size (dim // heads), not the full embedding size.
        self.scale = (dim // heads) ** -0.5

        # One projection produces Q, K and V together: dim -> 3 * dim.
        # Self-attention: all three are derived from the same input x.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        """Attend over the tokens of ``x``.

        Args:
            x: (batch, n, dim) token embeddings. In this ViT, n = number of patches + 1
               (the class token sits at position 0), e.g. (100, 17, 64) for MNIST.
            mask: optional boolean tensor that flattens (from dim 1) to (batch, n - 1);
               True marks patch tokens to keep. The class token is always kept.

        Returns:
            (batch, n, dim) tensor.
        """
        h = self.heads

        qkv = self.to_qkv(x)  # (b, n, 3 * dim), e.g. (100, 17, 192)
        # Split into Q, K, V and into heads; each is (b, h, n, d) with d = dim // heads.
        #   b: batch, n: sequence length (# patches + 1), h: heads, d: per-head size
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)

        # Scaled dot product of every query i with every key j: (b, h, n, n) logits,
        # i.e. one n x n attention map per head.
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask = mask.flatten(1).bool()
            # Prepend a "keep" entry for the class token at position 0.
            cls_keep = torch.ones(mask.shape[0], 1, dtype=torch.bool, device=mask.device)
            mask = torch.cat((cls_keep, mask), dim=1)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Pairwise keep-mask shaped (b, 1, n, n) so it broadcasts over the head axis
            # (a (b, n, n) mask would broadcast batch against heads instead).
            pair_mask = mask[:, None, :, None] & mask[:, None, None, :]
            # Most negative finite value instead of -inf: a fully masked query row then
            # softmaxes to a uniform distribution rather than NaN.
            dots = dots.masked_fill(~pair_mask, -torch.finfo(dots.dtype).max)

        # Softmax over keys turns the logits into attention weights: (b, h, n, n).
        attn = dots.softmax(dim=-1)

        # Weighted sum of values for each head and batch element: (b, h, n, d).
        out = torch.einsum('bhij,bhjd->bhid', attn, v)

        # Merge the heads back into the embedding dimension: (b, n, dim).
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)



By printing `attn`, you can see how the attention maps are derived from the input patch sequence.

What does the attention between image patches actually mean? This may be a very fundamental question for understanding the Transformer's ability to classify images.

## Transformer (Encoder Part Only) = Multi-head Attention + Feed-Forward Network

In [None]:
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm Transformer encoder layers.

    Each layer applies (LayerNorm -> multi-head attention, plus skip connection) and then
    (LayerNorm -> feed-forward, plus skip connection); token shape is preserved.

    Args:
        dim: token embedding size.
        depth: number of encoder layers.
        heads: attention heads per layer.
        mlp_dim: hidden size of each feed-forward network.
    """

    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        # ModuleList of [attention, feed-forward] pairs, so parameter names remain
        # layers.<i>.0.* (attention) and layers.<i>.1.* (feed-forward).
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim))),
            ])
            for _ in range(depth)
        )

    def forward(self, x, mask=None):
        for attention_block, feedforward_block in self.layers:
            x = feedforward_block(attention_block(x, mask=mask))
        return x



## Vision Transformer

![Transformer](https://nlpinkorean.github.io/images/transformer/transformer_resideual_layer_norm_2.png)

Ref: https://nlpinkorean.github.io/illustrated-transformer/

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Splits a square image into non-overlapping ``patch_size`` x ``patch_size`` patches,
    linearly embeds each flattened patch, prepends a learnable class token, adds learnable
    position embeddings, runs the Transformer encoder and classifies from the class token.

    Args:
        image_size: height and width of the square input image, in pixels.
        patch_size: side length of each patch; must divide ``image_size``.
        num_classes: number of output logits.
        dim: token embedding size.
        depth: number of Transformer encoder layers.
        heads: attention heads per layer.
        mlp_dim: hidden size of the feed-forward networks and of the classification head.
        channels: number of input image channels (1 for MNIST).
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        # Number of patches = (patches along y) * (patches along x). Channels do not add
        # patches; they are folded into each patch vector below.
        # MNIST (28x28 images, 7x7 patches): 4 * 4 = 16 patches, so the sequence length
        # including the class token is 17.
        num_patches = (image_size // patch_size) ** 2
        # Length of one flattened patch: channels * patch_size^2 (MNIST: 1 * 7 * 7 = 49),
        # later embedded to `dim` (64 in this notebook).
        patch_dim = channels * patch_size ** 2

        # Expected input height/width, checked in forward().
        self.image_size = image_size
        # Side length of one patch, in pixels.
        self.patch_size = patch_size

        # Learnable position embeddings: one per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.patch_to_embedding = nn.Linear(patch_dim, dim)  # e.g. 49 -> 64

        # Learnable class token, shared by every image in the batch.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        """Return logits of shape (batch, num_classes).

        ``img`` must be (batch, channels, image_size, image_size), e.g. (100, 1, 28, 28).
        ``mask`` is passed through to the Transformer's attention layers.
        """
        expected_hw = (self.image_size, self.image_size)
        assert tuple(img.shape[-2:]) == expected_hw, (
            f'expected {expected_hw[0]}x{expected_hw[1]} images, got {tuple(img.shape[-2:])}')
        p = self.patch_size

        # Cut the image into patches and flatten each one:
        #   b c (h p1) (w p2) -> b (h w) (p1 p2 c)
        # h / w: number of patches vertically / horizontally (MNIST: h = w = 4, p1 = p2 = 7, c = 1).
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)  # (b, num_patches, patch_dim)
        x = self.patch_to_embedding(x)                                              # (b, num_patches, dim)

        # Prepend the class token to every sequence in the batch.
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)                    # (b, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)                                       # (b, num_patches + 1, dim)
        x = x + self.pos_embedding
        x = self.transformer(x, mask)

        # Classify from the class token's final representation (position 0).
        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)

## Train and Validation

In [None]:
# training / evaluate function

def train_epoch(model, optimizer, data_loader, loss_history, log_interval=100):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: classifier returning raw logits of shape (batch, num_classes).
        optimizer: optimizer over ``model``'s parameters.
        data_loader: yields (inputs, integer class targets) batches.
        loss_history: list; the current batch loss is appended every ``log_interval`` batches.
        log_interval: how often (in batches) to print progress and record the loss.

    Each batch is moved to the device holding the model's parameters, so the loop keeps
    working after ``model.to('cuda')``.
    """
    total_samples = len(data_loader.dataset)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    model.train()

    for i, (data, target) in enumerate(data_loader):
        if device is not None:
            data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # log_softmax followed by nll_loss is the cross-entropy of the logits.
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if i % log_interval == 0:
            print('[' + '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)]  Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())

def evaluate(model, data_loader, loss_history):
    """Compute the average loss and accuracy of ``model`` over ``data_loader``.

    Args:
        model: classifier returning raw logits of shape (batch, num_classes).
        data_loader: yields (inputs, integer class targets) batches.
        loss_history: list; the average per-sample loss is appended to it.

    Prints a summary line and returns the accuracy as a fraction in [0, 1].
    Batches are moved to the device holding the model's parameters.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0.0

    # Average test loss / accuracy; no gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in data_loader:
            if device is not None:
                data, target = data.to(device), target.to(device)
            output = F.log_softmax(model(data), dim=1)
            # Sum (not mean) per batch so the total can be divided by the dataset size.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            # Accumulate as a Python int rather than a tensor.
            correct_samples += pred.eq(target).sum().item()

    avg_loss = total_loss / total_samples
    accuracy = correct_samples / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          '  Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * accuracy) + '%)\n')
    return accuracy

In [None]:
import time

N_EPOCHS = 1          # 25 for a full run; 1 is enough for a quick debugging pass
LEARNING_RATE = 0.003

# perf_counter is monotonic, so the measured duration is unaffected by system clock changes.
start_time = time.perf_counter()

# Vision Transformer for 28x28 single-channel MNIST: 7x7 patches -> 16 patch tokens + 1 class token.
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=64, depth=6, heads=8, mlp_dim=128)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    train_epoch(model, optimizer, train_loader, train_loss_history)
    evaluate(model, test_loader, test_loss_history)

print('Execution time:', '{:5.2f}'.format(time.perf_counter() - start_time), 'seconds')

Epoch: 1

Average test loss: 0.1301  Accuracy: 9582/10000 (95.82%)

Execution time: 88.26 seconds
