In [1]:
import torch
from torch import nn
import numpy as np
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import torchvision
import torchvision.transforms as transforms
import os
%matplotlib inline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [33]:
# 1. Patch Embedding
# 0. Patch Embedding Variables
p = 4 # patch
w = 32 # width
h = 32 # height
c = 3 # channel
b = 128 # batch
d = 128 # Dim of patched embeddings
cls = 10 # Class token size
L = 12 # Transformer block size
head_num = 8 # attention heads
n = w//p

In [34]:
# Import Dataset
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = b

train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [35]:
# Trainable Linear Projection이 필요
# nn.Module로 구성
class PositionalEmbedding(nn.Module):
    def __init__(self):
        super(PositionalEmbedding, self).__init__()
        self.projection = nn.Linear(p*p*c, d) # These image patch vectors are now encoded using a linear transformation. Fixed size `d`

    def patchify(self,img):
      # Divide to patch
      patched_img = img.view(b,c,h//p,p,w//p,p) # 이미지 1개당 N*N개 패치가 나오고, 패치 하나의 이미지는 P*P*C
      patched_img = patched_img.transpose(3,4)
      patched_img = patched_img.transpose(1,3)
      patched_img = patched_img.transpose(1,2)
      patched_img = patched_img.reshape(b,n*n,p*p*c)
      return patched_img
    def class_emb(self, patch):
      x_class = nn.Parameter(torch.randn(b,1,d)).to(device)
      with_class = torch.cat((x_class, patch), dim = 1)
      # print("with class embedding : ", with_class.shape)
      return with_class

    def position_emb(self, class_patch):
      pos_emb = nn.Parameter(torch.randn(b,n*n+1,d)).to(device)
      with_class_pos = class_patch + pos_emb # 이게 맞나? 그냥 더하는게?
      # print("with class & positional embedding : ", with_class_pos.shape)
      return with_class_pos

    def forward(self, x):
      patched_ = self.patchify(x)
      patched_ = self.projection(patched_)
      patched_ = self.class_emb(patched_)
      patched_ = self.position_emb(patched_)
      return patched_

In [36]:
# Transformer


class Attention(nn.Module):
  def __init__(self):
    super(Attention, self).__init__()
    self.w_q = nn.Parameter(torch.randn(d, n*n+1))
    self.w_k = nn.Parameter(torch.randn(d, n*n+1))
    self.w_v = nn.Parameter(torch.randn(d, n*n+1))

  def forward(self, x):
    # W_q,W_k,W_v 를 정의
    q = x @ self.w_q
    k = x @ self.w_k
    v = x @ self.w_v
    # QK^T를 만들기
    qk_T = q @ k.T
    # k의 차원 : D (Latent vector)
    qk_T = qk_T / d
    soft_ = nn.Softmax(dim = 0)
    attention_ = soft_(qk_T)
    # print(attention_.shape, v.shape)
    ret = attention_ @ v
    return ret

class MultiHeadAttention(nn.Module):
  def __init__(self):
    super(MultiHeadAttention, self).__init__()
    self.attn = Attention()
    self.w_o = nn.Parameter(torch.randn(head_num*(n*n+1), d))
  def forward(self,x):
    # Head의 Concat이 필요
    head_list = []
    for h in range(head_num):

      x_h = self.attn(x)
      head_list.append(x_h)
    ret = torch.cat(head_list, dim =1)
    ret = ret @ self.w_o
    return ret

class VisionTransformerBlock(nn.Module):
  def __init__(self):
    super(VisionTransformerBlock, self).__init__()
    self.msa = MultiHeadAttention()
    self.bn1 = nn.LayerNorm(d) # Size of BatchNorm1d is the input's size
    self.bn2 = nn.LayerNorm(d) # Size of BatchNorm1d is the input's size
    self.mlp = nn.Linear(d,d)
  def forward(self, x):
    # Batch Norm 1d
    x = self.bn1(x)
    # Multi-head Attention (Done)
    x_attn = self.msa(x)
    # print(x_attn.shape, x.shape)
    # Residual connections
    x_attn = x_attn + x
    # Norm
    out = self.bn2(x_attn)
    # MLP
    out = self.mlp(x_attn)
    # Concat
    out = out + x_attn
    # print(out.shape)
    return out

class VisionTransformer(nn.Module):
  def __init__(self):
    super(VisionTransformer, self).__init__()
    self.vit = nn.ModuleList([VisionTransformerBlock()
                              for _ in range(L)])
    # Full Connected Layer
    self.mlp = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, cls)
        )
    self.pe = PositionalEmbedding()
  def forward(self,x):
    pe_out = self.pe(x)
    # Seqeuence L 반복
    # ViT가 계속 업데이트 되야되는데 ..
    outputs = []
    for d in pe_out:
      # print(d.shape)
      for layer in self.vit:
        d = layer(d)
      # print(d.shape)
      outputs.append(d)
      # 각 이미지에 대한 output을 의미해야되는데
      # label이 0,1이 아니라 1~10으로 구성이 되어 있다.
    outputs = torch.stack(outputs,dim = 0).to(device)
    out = self.mlp(outputs[:,0]) # 처음 토큰이 class token 이니까
    return out.to(device)

In [37]:
def accuracy(dataloader, model):
    correct = 0
    total = 0
    running_loss = 0
    n = len(dataloader)
    criterion = nn.CrossEntropyLoss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        model.eval()
        for data in dataloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs, _ = model(images)
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

        loss_result = running_loss / n

    acc = 100 * correct / total
    model.train()
    return acc, loss_result

In [38]:
import torch.optim as optim
from tqdm import tqdm

ViT = VisionTransformer()
ViT.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(ViT.parameters(), lr=0.001, momentum=0.9)
# 지금은 patch를 1D로 만들고, cls, pos 를 붙임
# patch , cls, pos를 붙인 다음에

ViT.train()
# print(ViT)
n = len(train_loader)
for epoch in range(10):
  running_loss = 0
  for img, label in tqdm(train_loader):
    img = img.to(device)
    label = label.to(device)
    out = ViT(img)
    label_f32 = label.type('torch.LongTensor').to(device)
    # print(out.dtype, label_f32.dtype)
    loss = criterion(out, label_f32)

    # loss
    optimizer.zero_grad()
    loss.backward() #retain_graph=True
    optimizer.step()

  train_loss = running_loss / n
  val_acc, val_loss = accuracy(test_loader, ViT)
  # if epoch % 5 == 0:
  print('[%d] train loss: %.3f, validation loss: %.3f, validation acc %.2f %%' % (epoch, train_loss, val_loss, val_acc))
  torch.save(ViT, '/content/drive/MyDrive/Vision-Transformer/model_new.pth')
  # ViT.eval()
  # test_loss = 0.0
  # correct = 0

  # with torch.no_grad():
  #     for images, labels in test_loader:

  #         images = images.to(device)
  #         labels = labels.to(device)

  #         outputs = ViT(images)
  #         predicted = torch.max(outputs, 1)[1]
  #         loss = criterion(outputs, labels)

  #         test_loss += loss.item()
  #         correct += (labels == predicted).sum()
  # # Test Loss Logging
  # print(
  #     f"epoch {epoch+1} - test loss: {test_loss / len(test_loader):.4f}"
  # )

  0%|          | 0/5000 [00:00<?, ?it/s]


RuntimeError: shape '[10, 25000000, 48]' is invalid for input of size 30720

In [12]:
model = torch.load('/content/drive/MyDrive/Vision-Transformer/model.pth')
model.eval()
test_loss = 0.0
correct = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        print(outputs)
        print(labels)
        break

  self.pid = os.fork()
  self.pid = os.fork()


tensor([[ 0.3103,  0.1640, -0.4686, -0.0148, -0.2606, -0.3470, -0.6482, -0.1207,
          0.2620,  0.3200],
        [ 0.0631,  0.6368,  0.1217,  0.0908, -0.1239,  0.2969, -0.2306, -0.2415,
          0.4690,  0.3366],
        [-0.0240, -0.1394, -0.0156,  0.2074,  0.1959,  0.0371,  0.0841,  0.0255,
         -0.1048,  0.1100],
        [ 0.0179, -1.1196,  0.8682,  0.6601,  0.9979,  0.5115,  0.4984,  0.5505,
         -0.2318, -0.9103],
        [ 0.2933,  0.1393, -0.4507, -0.0085, -0.2379, -0.3338, -0.6183, -0.1007,
          0.2344,  0.2979],
        [ 0.1185,  0.8088,  0.0116,  0.0674, -0.2360,  0.2458, -0.3930, -0.3428,
          0.5987,  0.5082],
        [ 0.0695,  0.6647,  0.1017,  0.0847, -0.1439,  0.2872, -0.2572, -0.2616,
          0.4855,  0.3657],
        [-0.3389, -0.2026,  0.6575,  0.4197,  0.8786,  0.8239,  0.9859,  0.4691,
         -0.7461, -0.3810],
        [ 0.2837,  0.1157, -0.4415, -0.0087, -0.2257, -0.3303, -0.5978, -0.0878,
          0.2106,  0.2784],
        [ 0.3209,  