In [86]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
from torchsummary import summary

import numpy as np
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [87]:
class Attention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim // heads
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.MHA = nn.MultiheadAttention(dim, heads, batch_first=True)       #dim means input sequence's dim

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)   #example: https://pytorch.org/docs/stable/generated/torch.chunk.html
        q, k, v = [token for token in qkv]
        result = self.MHA(q,k,v, need_weights=True)[0]
        return result


In [88]:
class FeedForward(nn.Module):
    def __init__(self, dim, mlp_dim):
        super().__init__()
        layers = []
        layers.append(nn.LayerNorm(dim))
        layers.append(nn.Linear(dim, mlp_dim))
        layers.append(nn.GELU())
        layers.append(nn.Linear(mlp_dim, dim))
        self.net = nn.Sequential(*layers)

    def forward(self,x):
        return self.net(x)

In [157]:
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):       #in paper head_dim = dim * 4
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([(Attention(dim, heads)),(FeedForward(dim, mlp_dim))]))         
        self.layers = nn.Sequential()

    def forward(self,x):
        for attn, ffn in self.layers:           #이전의 layers를 list에서 iterating하는 방식은 cuda, cpu device 오류남
            x = attn(x) + x
            x = ffn(x) + x
        return self.norm(x)
        

In [140]:
class VisionTransformer(nn.Module):

    def __init__(self, batch_size, num_classes, dim, depth, heads, mlp_dim, output_dim, img_dim = [3,224,224], patch_dim = [3,56,56], dim_head = 64):
        super().__init__()
        image_h = img_dim[1]
        image_w = img_dim[2]
        patch_h = patch_dim[1]
        patch_w = patch_dim[2]

        n_patches = (image_h // patch_h) * (image_w // patch_w)
        embedding_dim = img_dim[0] * patch_h * patch_w

        self.patch_dim = patch_dim
        self.img_dim = img_dim
        self.batch_size = batch_size
        self.n_patches = n_patches
        self.embedding_dim = embedding_dim

        #so we flatten the patches and map to D dimensions with a trainable linear projection (Eq. 1).
        self.projection = nn.Sequential(     
            nn.LayerNorm(embedding_dim),                            #layernorm에 대한 언급은 못찾겠음
            nn.Linear(embedding_dim, dim),
            nn.LayerNorm(dim)
        )
        self.cls_token =nn.Parameter(torch.randn(1, dim))
        self.pos_embedding =nn.Parameter(torch.randn(1, n_patches+1, dim))

        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.classification_head = nn.Linear(dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)


    def forward(self, img):
        channels=img.shape[1]

        x = img.unfold(2, self.patch_dim[1], self.patch_dim[2]).unfold(3, self.patch_dim[1], self.patch_dim[2])
        x = x.contiguous().view(self.batch_size, channels, self.n_patches, self.patch_dim[1], self.patch_dim[2])
        patches = x.permute(0, 2, 3, 4, 1)
        x = patches.contiguous().view(self.batch_size, self.n_patches, self.embedding_dim)
        x = self.projection(x)

        cls_tokens = self.cls_token.repeat(self.batch_size, 1, 1)

        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embedding[:, :(self.n_patches+1)]

        x = self.transformer(x)
        x = x[:,0]

        x = self.classification_head(x)
        x = self.norm(x)                    #is this order right?
        #많은 구현체에서 norm순서나 유무 vatiation이 많았다

        return x

In [141]:
model = VisionTransformer(batch_size=16, num_classes=10, dim=64, depth=8, heads=64, mlp_dim=256, output_dim=1000)

In [142]:
x = torch.randn([16,3,224,224])    #(b,c,h,w)
y = model(x)

In [163]:
F.softmax(outputs)

  F.softmax(outputs)


tensor([[0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0.0438, 0.2338, 0.1396, 0.3822, 0.0206, 0.0663, 0.0175, 0.0277, 0.0306,
         0.0379],
        [0

In [143]:
print(x.shape)
print(y.shape)

torch.Size([16, 3, 224, 224])
torch.Size([16, 1000])


# Train test

## define dataset, dataloader

In [144]:
BATCH_SIZE = 32

In [145]:
std10_train = datasets.STL10(
    "../Datasets/STL10_PyTorch/",
    split = "train",
    transform=transforms.ToTensor(),
    download=True
)

Files already downloaded and verified


In [146]:
from sklearn.model_selection import train_test_split
import copy

In [147]:
train_indices, val_indices = train_test_split(range(len(std10_train)), test_size=0.1)
std10_val = copy.deepcopy(std10_train)
std10_train = Subset(std10_train, train_indices)
std10_val = Subset(std10_val, val_indices)

In [148]:
print(len(train_indices), len(val_indices))
print(len(std10_train), len(std10_val))

4500 500
4500 500


In [149]:
train_loader = DataLoader(
    std10_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True
)

val_loader = DataLoader(
    std10_val,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True
)

## train

In [150]:
epochs=100

In [151]:
std10_train[0][0].shape

torch.Size([3, 96, 96])

In [170]:
model = VisionTransformer(batch_size=BATCH_SIZE, num_classes=10, dim=64, depth=8, heads=64, mlp_dim=256, output_dim=10, img_dim=[3,96,96], patch_dim=[3,24,24]).to(device)
# model = model.to(device)

In [171]:
optimizer = optim.Adam(model.parameters())      #weight_decay 넣으면 학습 진행 X

In [172]:
criterion = nn.CrossEntropyLoss()

In [173]:
for epoch in range(epochs):
    loss = 0
    loss_epoch = 0
    model.train()
    acc = 0.0
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()
    loss_epoch /= len(train_loader)
    print(loss_epoch)

2.5313214795930046
2.463254565852029
2.41932840858187
2.3775292481694903
2.3482827356883456
2.3274914503097532
2.313667665209089
2.308606333392007
2.3061716999326434
2.3056612168039594
2.3046116045543124
2.30450291293008
2.3052499566759383
2.3044195158141
2.30444803408214
2.3042026145117624
2.3035844922065736
2.304558348655701
2.30431923866272
2.304315437589373
2.3040058987481253
2.3046348316328866
2.3043256878852842
2.3042466248784748
2.303660695893424
2.304019810472216
2.3040760108402796
2.3038133723395213
2.303826837880271
2.3039268919399807
2.304162836074829
2.30404965366636
2.304058085169111
2.303848930767604
2.3039709448814394
2.3039488213402883
2.3036889723369054
2.3037174105644227
2.3036822523389544
2.303731768471854
2.303824932234628
2.303983645779746
2.3034564818654744
2.303376683167049
2.303542254652296
2.303319081238338
2.303509041241237
2.303492774282183
2.3035590750830512
2.3036773800849915
2.3033979075295585
2.303585549763271
2.303572876112802
2.303557046822139
2.3033667

KeyboardInterrupt: 