# VIsion Transformer (ViT)

- image 를 word sequence 처럼 취급 -> input image 를 patch (word -> token 처럼) 로 분할 후 transformer에 입력

- 구조 : Imaqge Patching -> Embedding -> Transformer (Encoder Only) 

    Image Patching : image를 작은 patch로 분할하는 작업

    Embedding : patch를 Flatten -> patch embedding vector 로 변환 -> class token embedding 추가 -> positional embedding 추가

    Transformer : multi-head self-attention 으로 patch들의 관계 학습 -> class token vector를 통해 classification

In [37]:
import torch
from torchvision.transforms import Compose, ToTensor,Normalize
from torchvision.datasets import CIFAR10
from torch.nn import Module, Dropout, Linear, CrossEntropyLoss, init, utils, LayerNorm, ModuleList, Conv2d, Parameter
from torch.optim import AdamW
from torch.utils.data import DataLoader

from time import time
import math

In [38]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.xpu.is_available():
    device = torch.device("xpu")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else: 
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


In [39]:
transform = Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 
])

train_data = CIFAR10('data_cifar10', train=True, transform=transform)
valid_data = CIFAR10('data_cifar10', train=False, transform=transform)
test_data = CIFAR10('data_cifar10', train=False, transform=transform)

batch_size = 256
train_iterator = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_iterator = DataLoader(valid_data, batch_size=batch_size)

In [40]:
class PositionalEncoding(Module):
    def __init__(self, model_dim, max_len=5000):
        super().__init__()
        
        position_encoding_matrics = torch.zeros(max_len, model_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, model_dim, 2).float() * (-math.log(10000.0) / model_dim))
        position_encoding_matrics[:, 0::2] = torch.sin(position * div_term)
        position_encoding_matrics[:, 1::2] = torch.cos(position * div_term)
        position_encoding_matrics = position_encoding_matrics.unsqueeze(0)
        self.register_buffer('position_encoding_matrics', position_encoding_matrics)
        
    def forward(self, x):
        # input sequence의 길이로 잘라서 리턴 (ViT에서는 [CLS] + N_patches)
        return self.position_encoding_matrics[:, :x.size(1)]

In [41]:
class MultiHeadAttention(Module):
    def __init__(self, model_dim, n_heads, dropout):
        super().__init__()

        self.model_dim = model_dim
        self.n_heads = n_heads
        self.head_dim = model_dim // n_heads 
        self.fc_q = Linear(model_dim, model_dim)
        self.fc_k = Linear(model_dim, model_dim)
        self.fc_v = Linear(model_dim, model_dim)
        self.fc_o = Linear(model_dim, model_dim)
        self.dropout = Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
        
    def forward(self, query, key, value):
        batch_size = query.shape[0]
        
        q = self.fc_q(query)
        k = self.fc_k(key)
        v = self.fc_v(value)
        
        q = q.reshape(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        energy = torch.matmul(q, k.permute(0, 1, 3, 2)) / self.scale
        
        attention = torch.softmax(energy, dim = -1)
        x = torch.matmul(self.dropout(attention), v)
        
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(batch_size, -1, self.model_dim)
        
        x = self.fc_o(x)
        
        return x, attention

In [42]:
class PositionwiseFeedforward(Module):
    def __init__(self, model_dim, feedforward_dim, dropout):
        super().__init__()
        self.fc_1 = Linear(model_dim, feedforward_dim)
        self.fc_2 = Linear(feedforward_dim, model_dim)
        self.dropout = Dropout(dropout)
        
    def forward(self, x):
        x = self.dropout(torch.relu(self.fc_1(x)))
        x = self.fc_2(x)
        
        return x

In [43]:
class SublayerConnection_PreNorm(Module):
    def __init__(self, model_dim, dropout):
        super().__init__()
        self.norm = LayerNorm(model_dim)
        self.dropout = Dropout(dropout)
        
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

In [44]:
class EncoderLayer(Module):
    def __init__(self, model_dim, n_heads, feedforward_dim, dropout):
        super().__init__()
        self.self_attention = MultiHeadAttention(model_dim, n_heads, dropout)
        self.self_attention_sublayer = SublayerConnection_PreNorm(model_dim, dropout) 
        self.feed_forward = PositionwiseFeedforward(model_dim, feedforward_dim, dropout)
        self.feedforward_sublayer = SublayerConnection_PreNorm(model_dim, dropout)
        
    def forward(self, src):
        src = self.self_attention_sublayer(src, lambda x: self.self_attention(x, x, x)[0])
        src = self.feedforward_sublayer(src, self.feed_forward)
        
        return src

In [45]:
class Encoder(Module):
    def __init__(self, model_dim, n_layers, n_heads, feedforward_dim, dropout):
        super().__init__()
        # ViT 에서는 embedding과 positional encoding을 ViT에서 처리
        self.dropout = Dropout(dropout)
        self.layers = ModuleList([
            EncoderLayer(model_dim, n_heads, feedforward_dim, dropout) for _ in range(n_layers)
        ])
        # ViT 는 encoder only -> 마지막 출력이 된다.
        self.norm = LayerNorm(model_dim)
        
    def forward(self, src):
        for layer in self.layers:
            src = layer(src) 
            
        return self.norm(src)

In [46]:
class VisionTransformer(Module):
    def __init__(self, img_size, patch_size, in_channels, num_classes, model_dim, n_layers, n_heads, feedforward_dim, dropout):
        super().__init__()
        
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # patch embedding : image -> patch -> embedding vector
        # patch size와 stride를 똑같이 설정 = 겹치지 않는 patch 생성
        self.patch_embedding = Conv2d(
            in_channels, model_dim, 
            kernel_size=patch_size, stride=patch_size
        )
        
        # classification을 위해 label을 token으로 연결 및 초기화
        self.class_token = Parameter(torch.zeros(1, 1, model_dim))
        init.trunc_normal_(self.class_token, std=.02)
        
        # Position Encoding (num_patches + 1 길이)
        self.position_embedding = PositionalEncoding(model_dim, max_len=self.num_patches + 1)
        
        self.dropout = Dropout(dropout)
        
        # transformer의 encoder 사용
        self.transformer_encoder = Encoder(model_dim, n_layers, n_heads, feedforward_dim, dropout)
        
        # classifier 역할을 수행하는 head
        self.classifier_head = Linear(model_dim, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]

        x = self.patch_embedding(x)
        x = x.flatten(2).transpose(1, 2)

        # class token 추가
        class_tokens = self.class_token.expand(batch_size, -1, -1)
        x = torch.cat((class_tokens, x), dim=1)
        
        # positional encoding 추가
        x = self.dropout(x + self.position_embedding(x))
        
        x = self.transformer_encoder(x)
        
        # classification
        class_output = x[:, 0]
        output = self.classifier_head(class_output)
        
        return output

In [47]:
img_size = 32
# 4x4 패치 -> (32/4)^2 = 64개 패치 시퀀스
patch_size = 4
in_channels = 3
# 0-9 클래스
output_dim = 10
model_dim = 512
n_heads = 8
n_layers = 6
feedforward_dim = model_dim * 4
dropout = 0.2

In [48]:
model = VisionTransformer(img_size, patch_size, in_channels, output_dim, model_dim, n_layers, n_heads, feedforward_dim, dropout).to(device)

In [49]:
# 가중치 초기화
# layer에 따라 다른 가중치 초기화 기법 사용
def init_weights(model):
    # Linear 일 경우 : Truncated Normal Initialization
    if isinstance(model, Linear):
        init.trunc_normal_(model.weight, std=.02)
        if model.bias is not None:
            init.constant_(model.bias, 0)
    # Layer Normalization 일 경우 : weight 는 1 / bias 는 0
    elif isinstance(model, LayerNorm):
        init.constant_(model.bias, 0)
        init.constant_(model.weight, 1.0)

    # Convolution 일 경우 : Kaiming Normal Initialization (카이밍 정규분포 초기화)
    elif isinstance(model, Conv2d):
        init.kaiming_normal_(model.weight, mode='fan_out', nonlinearity='relu')
        if model.bias is not None:
            init.constant_(model.bias, 0)
    

model.apply(init_weights)

VisionTransformer(
  (patch_embedding): Conv2d(3, 512, kernel_size=(4, 4), stride=(4, 4))
  (position_embedding): PositionalEncoding()
  (dropout): Dropout(p=0.2, inplace=False)
  (transformer_encoder): Encoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attention): MultiHeadAttention(
          (fc_q): Linear(in_features=512, out_features=512, bias=True)
          (fc_k): Linear(in_features=512, out_features=512, bias=True)
          (fc_v): Linear(in_features=512, out_features=512, bias=True)
          (fc_o): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (self_attention_sublayer): SublayerConnection_PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (feed_forward): PositionwiseFeedforward(
          (fc_1): Linear(in_features=512, out_feat

In [50]:
class NoamOptimizer:
    def __init__(self, model_dim, warmup_steps, optimizer):
        self.optimizer = optimizer
        self.model_dim = model_dim
        self.warmup_steps = warmup_steps
        self._step = 0
        self._rate = 0.

    def step(self):
        self._step += 1
        rate = self.rate()
        for param in self.optimizer.param_groups:
            param['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        if step is None:
            step = self._step
        
        scale = self.model_dim ** (-0.5)
        
        return scale * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5)))
    
    def zero_grad(self):
        self.optimizer.zero_grad()

In [51]:
loss_function = CrossEntropyLoss()

learning_rate = 0.0001
epsilon = 1e-9
betas=(0.9, 0.98)
weight_decay = 0.01 
optimizer = AdamW(model.parameters(), lr=learning_rate, betas=betas, eps=epsilon, weight_decay=weight_decay)

warmup_steps = 2000
optimizer = NoamOptimizer(model_dim, warmup_steps, optimizer)

In [52]:
def train(model, iterator, optimizer, loss_function):
    
    model.train()
    epoch_loss = 0
    
    for images, labels in iterator:
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        output = model(images)
        loss = loss_function(output, labels)
        
        loss.backward()
        
        utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [53]:
def evaluate(model, iterator, loss_function):
    
    model.eval()
    epoch_loss = 0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels in iterator:
            images = images.to(device)
            labels = labels.to(device)
            
            output = model(images) 
            loss = loss_function(output, labels)
            epoch_loss += loss.item()

            preds = output.argmax(dim=1)
            correct_predictions += (preds == labels).sum().item()
            total_samples += labels.size(0)
    
    accuracy = correct_predictions / total_samples
    
    return epoch_loss / len(iterator), accuracy

In [54]:
epochs = 100

total_time = 0
total_acc = list()

best_val_loss = float('inf')
patience = 15
counter = 0
min_delta = 0.001
best_model_path = 'vit_cifar10_best.pt'

for epoch in range(1, epochs + 1):
    start_time = time()
    
    train_loss = train(model, train_iterator, optimizer, loss_function)
    valid_loss, valid_acc = evaluate(model, valid_iterator, loss_function)

    # early stopping
    if valid_loss < best_val_loss - min_delta:
        best_val_loss = valid_loss
        counter = 0
        torch.save(model.state_dict(), best_model_path)
    else:
        counter += 1
        
    if counter >= patience:
        print(f"EarlyStopping (patience : {patience})")
        break

    end_time = time()
    
    epoch_time = end_time - start_time
    total_time += epoch_time

    epoch_mins = int((epoch_time) / 60)
    epoch_secs = int((epoch_time) % 60)

    total_acc.append(valid_acc)

    print(f"epoch: {epoch:3d}/{epochs} \t train loss: {train_loss:4.3f} val loss: {valid_loss:4.3f} \t val acc: {valid_acc*100:5.2f}% \t {epoch_mins}m {epoch_secs}s")


if total_acc:
    avg_total_acc = sum(total_acc) / len(total_acc)
    # early stopping으로 인해 patience 이후의 가중치(=overfitting) 대신 최적의 모델 가중치 사용
    model.load_state_dict(torch.load(best_model_path))
else:
    avg_total_acc = 0.0

print(f"total time : {total_time} \t total acc : {avg_total_acc*100:5.2f}%")

epoch:   1/100 	 train loss: 2.243 val loss: 2.112 	 val acc: 20.38% 	 0m 28s
epoch:   2/100 	 train loss: 1.886 val loss: 1.651 	 val acc: 40.48% 	 0m 28s
epoch:   3/100 	 train loss: 1.646 val loss: 1.538 	 val acc: 44.62% 	 0m 28s
epoch:   4/100 	 train loss: 1.517 val loss: 1.489 	 val acc: 46.96% 	 0m 28s
epoch:   5/100 	 train loss: 1.440 val loss: 1.456 	 val acc: 47.40% 	 0m 28s
epoch:   6/100 	 train loss: 1.389 val loss: 1.377 	 val acc: 50.44% 	 0m 28s
epoch:   7/100 	 train loss: 1.346 val loss: 1.347 	 val acc: 52.40% 	 0m 28s
epoch:   8/100 	 train loss: 1.311 val loss: 1.318 	 val acc: 53.07% 	 0m 28s
epoch:   9/100 	 train loss: 1.302 val loss: 1.329 	 val acc: 52.87% 	 0m 28s
epoch:  10/100 	 train loss: 1.296 val loss: 1.275 	 val acc: 54.67% 	 0m 28s
epoch:  11/100 	 train loss: 1.283 val loss: 1.239 	 val acc: 55.49% 	 0m 28s
epoch:  12/100 	 train loss: 1.243 val loss: 1.244 	 val acc: 55.84% 	 0m 28s
epoch:  13/100 	 train loss: 1.196 val loss: 1.142 	 val acc: 58

In [55]:
test_iterator = DataLoader(test_data, batch_size=batch_size)

test_loss, test_acc = evaluate(model, test_iterator, loss_function)

print(f"evaluation - loss: {test_loss:4.3f} \t acc: {test_acc*100:5.2f}%")

evaluation - loss: 0.907 	 acc: 69.44%


In [56]:
def predict_image(model, image, transform):
    model.eval()
    
    if not torch.is_tensor(image):
        image = transform(image)
    
    # batch dim 추가 (c, h, w) -> (1, c, h, w)
    image = image.unsqueeze(0).to(device)
    
    with torch.no_grad():
        predict = model(image)
    
    predicted_class = predict.argmax(dim=1).item()
    
    return predicted_class


In [57]:
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

for i in range(10):
    test_image_pil, label_index = test_data[i]

    predict = predict_image(model, test_image_pil, transform)

    label_name = class_names[label_index]
    predict_name = class_names[predict]

    print(f"label : {label_name:8} \t predict : {predict_name:8} \t {label_name == predict_name}")


label : cat      	 predict : cat      	 True
label : ship     	 predict : ship     	 True
label : ship     	 predict : ship     	 True
label : plane    	 predict : plane    	 True
label : frog     	 predict : frog     	 True
label : frog     	 predict : frog     	 True
label : car      	 predict : car      	 True
label : frog     	 predict : frog     	 True
label : cat      	 predict : deer     	 False
label : car      	 predict : car      	 True
