In [8]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

## Patches embedding
##### 1. Conv layer를 활용하여 패치 나누기 
(논문의 Hybrid Architecture로 언급)
##### 2. Trainable linear projection을 통해 x의 각 패치를 flatten한 벡터를 D차원으로 변환한 후, 이를 패치 임베딩으로 사용

##### 3. Learnable class 임베딩(cls token)을 패치 임베딩 앞에 붙인 뒤, learnable position 임베딩을 더함

In [9]:
class PatchEmbedding(nn.Module):
    """Convert an image batch into a sequence of patch tokens.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to an ``embedding_size`` vector.
    A learnable class token is then prepended and a learnable position
    embedding is added to every token.

    Args:
        input_channels: Number of channels of the input images.
        patch_size: Side length of each square patch, in pixels.
        embedding_size: Dimension of every output token.
        img_size: Side length of the input images. It fixes the size of the
            position table at ``(img_size // patch_size) ** 2 + 1`` rows, so
            inputs that yield a different token count will fail in ``forward``.
    """

    def __init__(self, input_channels: int=3, patch_size: int=16,
                 embedding_size: int=768, img_size: int=224):
        # nn.Module has to be initialised before attributes are set on it;
        # assigning first only happens to work for plain Python values.
        super().__init__()
        self.patch_size = patch_size
        self.partition = nn.Sequential(
            # With the default arguments the conv output is (b, 768, 14, 14).
            nn.Conv2d(input_channels, embedding_size,
                      kernel_size=patch_size, stride=patch_size),
            # Flatten the patch grid into a token axis: (b, h*w, e).
            Rearrange('b e (h) (w) -> b (h w) e')
        )
        self.cls_token = nn.Parameter(torch.randn(1,1, embedding_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, embedding_size))

    def forward(self, x: Tensor) -> Tensor:
        """Return tokens of shape ``(batch, num_patches + 1, embedding_size)``."""
        batch_size = x.shape[0]
        x = self.partition(x)
        # Replicate the single learnable class token for every sample.
        cls_token = repeat(self.cls_token, '() n e -> b n e', b=batch_size)
        # dim=1 is the token axis, so the class token becomes token 0.
        x = torch.cat([cls_token, x], dim=1)
        # Out-of-place add: the position table broadcasts over the batch axis.
        return x + self.positions

## Transformer encoder
##### 4. 임베딩을 Transformer encoder에 input으로 넣어 마지막 layer에서 class embedding에 대한 output인 image representation을 도출

#### 1. Multi-Head Attention
패치에 대해 self-attention 메커니즘 적용

In [10]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are produced by one fused linear layer, split
    into ``num_heads`` heads, attended independently, then concatenated and
    passed through an output projection.

    Args:
        embedding_size: Dimension of the input and output tokens.
        num_heads: Number of attention heads; must divide ``embedding_size``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``embedding_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, embedding_size: int=768,
                 num_heads: int=8,
                 dropout: float=0):
        super().__init__()
        if embedding_size % num_heads != 0:
            raise ValueError(
                f"embedding_size ({embedding_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.head_dim = embedding_size // num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(embedding_size, embedding_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(embedding_size, embedding_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tokens of shape ``(batch, tokens, embedding_size)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch, num_heads, tokens, tokens)``. ``True`` marks
                query/key pairs that may attend; ``False`` pairs are
                suppressed before the softmax.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # dot product of every query with every key, per head
        score = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        # Scale by sqrt of the per-head dimension (the length of the vectors
        # actually being dotted), not the full embedding size.
        score = score / (self.head_dim ** 0.5)

        if mask is not None:
            # Most negative finite value of the score dtype, so masked pairs
            # get ~0 weight after softmax without producing NaNs.
            fill_value = torch.finfo(score.dtype).min
            # masked_fill is out-of-place; the result must be kept.
            score = score.masked_fill(~mask, fill_value)

        attention = F.softmax(score, dim=-1)
        attention = self.att_drop(attention)

        # weighted sum of the values, then merge the heads back together
        output = torch.einsum('bhal, bhlv -> bhav', attention, values)
        output = rearrange(output, "b h n d -> b n (h d)")
        output = self.projection(output)

        return output

### 2. Transformer Encoder Block
MLP(Feed Forward) 블럭을 만들어주고 Multi-Head Attention과 함께 하나의 블럭으로 묶음

In [11]:
class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a skip connection, computing ``fn(x, **kwargs) + x``.

    Args:
        fn: Callable (typically an ``nn.Module``) whose output has a shape
            that can be added to its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``self.fn(x, **kwargs) + x`` without modifying ``x``."""
        # Out-of-place add on purpose: an in-place ``+=`` would overwrite the
        # caller's tensor whenever ``fn`` returns its input unchanged, and
        # would break autograd for any ``fn`` whose output is saved for the
        # backward pass.
        return self.fn(x, **kwargs) + x

class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        embedding_size: Dimension of the input and output features.
        expansion: The hidden layer has ``expansion * embedding_size`` units.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, embedding_size: int, expansion: int=4, drop_p: float=0.):
        hidden_size = expansion * embedding_size
        layers = [
            nn.Linear(embedding_size, hidden_size),   # widen
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, embedding_size),   # project back
        ]
        super().__init__(*layers)

class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder layer made of two residual sub-blocks.

    The first sub-block is LayerNorm -> MultiHeadAttention -> Dropout, the
    second is LayerNorm -> FeedForwardBlock -> Dropout; each is wrapped in
    ``ResidualAdd``.

    Args:
        embedding_size: Token dimension used by every layer.
        drop_p: Dropout probability applied at the end of both sub-blocks.
        forward_expansion: Hidden-layer expansion factor of the MLP.
        forward_drop_p: Dropout probability inside the MLP.
        **kwargs: Forwarded to ``MultiHeadAttention`` (e.g. ``num_heads``).
    """

    def __init__(self, embedding_size: int=768,
                 drop_p: float=0.,
                 forward_expansion: int=4,
                 forward_drop_p: float=0.,
                 **kwargs):
        attention_branch = nn.Sequential(
            nn.LayerNorm(embedding_size),
            MultiHeadAttention(embedding_size, **kwargs),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(embedding_size),
            FeedForwardBlock(embedding_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )


## Classification Head
##### 5. classification을 위한 MLP Head 부분

In [12]:
class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` identically configured encoder blocks.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers, applied in order.
        **kwargs: Forwarded unchanged to every ``TransformerEncoderBlock``.
    """

    def __init__(self, depth: int=12, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)

class ClassificationHead(nn.Sequential):
    """Map encoder tokens to class logits.

    The tokens are averaged over the token axis (every token, including the
    class token, contributes to the mean), layer-normalised and linearly
    projected to ``n_classes`` logits.

    Args:
        embedding_size: Dimension of the incoming tokens.
        n_classes: Number of output logits.
    """

    def __init__(self, embedding_size: int=768,
                 n_classes: int=1000):
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalize = nn.LayerNorm(embedding_size)
        classifier = nn.Linear(embedding_size, n_classes)
        super().__init__(pool_tokens, normalize, classifier)

class ViT(nn.Sequential):
    """Vision Transformer: patch embedding -> encoder stack -> classifier.

    Args:
        input_channels: Number of channels of the input images.
        patch_size: Side length of each square patch, in pixels.
        embedding_size: Token dimension shared by all three stages.
        img_size: Side length of the input images.
        depth: Number of encoder blocks.
        n_classes: Number of output logits.
        **kwargs: Forwarded to ``TransformerEncoder``.
    """

    def __init__(self,
                 input_channels: int=3,
                 patch_size: int=16,
                 embedding_size: int=768,
                 img_size: int=224,
                 depth: int=12,
                 n_classes: int=1000, **kwargs):
        patch_embedding = PatchEmbedding(
            input_channels=input_channels,
            patch_size=patch_size,
            embedding_size=embedding_size,
            img_size=img_size,
        )
        encoder = TransformerEncoder(depth=depth, embedding_size=embedding_size, **kwargs)
        head = ClassificationHead(embedding_size=embedding_size, n_classes=n_classes)
        super().__init__(patch_embedding, encoder, head)

In [13]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import torch.optim as optim
from tqdm import tqdm
import os

# Training settings
SEED = 42   # fixes the train/validation split and the weight initialisation
BATCH_SIZE = 64
EPOCHS = 20
LR = 3e-4
CHECKPOINT_DIR = './checkpoints/'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the global RNG before anything random happens so a fresh
# "Restart & Run All" reproduces the same initial weights.
torch.manual_seed(SEED)

# CIFAR-10 dataset loading with transforms
transform = Compose([
    Resize((224, 224)),  # the model below is built with img_size=224
    ToTensor(),
    Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])  # CIFAR-10 normalization
])

# Load CIFAR-10 dataset
full_train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = CIFAR10(root='./data', train=False, transform=transform, download=True)

# Split train into train/validation (80/20). A dedicated seeded generator
# keeps the split identical from run to run, independent of other RNG use.
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# DataLoader for train, validation, and test sets
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize model
model = ViT(input_channels=3, patch_size=16, embedding_size=768, img_size=224, depth=12, n_classes=10).to(DEVICE)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)
# Multiply the learning rate by 0.1 every 10 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Function to save model checkpoint
def save_checkpoint(epoch, model, optimizer, path):
    """Write the training state to ``path`` with ``torch.save``.

    The saved dictionary has the keys ``'epoch'``, ``'model_state_dict'``
    and ``'optimizer_state_dict'``.

    Args:
        epoch: Value stored under ``'epoch'``.
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        path: Destination file; missing parent directories are created.
    """
    # Create the directory the file is actually written to, rather than a
    # fixed global one; exist_ok avoids a check-then-create race.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(state, path)

# Training and validation loop
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Args:
        model: Module called as ``model(imgs)`` to produce class logits.
        loader: Iterable of ``(imgs, labels)`` batches.
        criterion: Loss returning the mean loss of a batch.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy)``: the per-sample mean loss and the fraction of
        correctly classified samples.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            n_samples = labels.size(0)
            # Weight each batch mean by its size so a short final batch
            # does not skew the average.
            loss_sum += criterion(outputs, labels).item() * n_samples
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n_samples
    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return loss_sum / total, correct / total


train_losses, val_losses = [], []
train_accs, val_accs = [], []

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_samples = labels.size(0)
        loss_sum += loss.item() * n_samples
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += n_samples

    train_loss = loss_sum / total
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # ---- validation pass ----
    val_loss, val_acc = evaluate(model, val_loader, criterion, DEVICE)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f"Epoch [{epoch+1}/{EPOCHS}] - "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save model checkpoint (os.path.join avoids a doubled path separator)
    save_checkpoint(epoch, model, optimizer, os.path.join(CHECKPOINT_DIR, f'vit_epoch_{epoch+1}.pth'))

    # Adjust learning rate
    scheduler.step()

# Plotting the training loss and accuracy
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Accuracy')
plt.plot(val_accs, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.show()

# Test the model on test dataset
_, test_acc = evaluate(model, test_loader, criterion, DEVICE)
print(f"Test Accuracy: {test_acc:.4f}")


Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/625 [00:00<?, ?it/s]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  0%|          | 1/625 [00:11<1:58:48, 11.42s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  0%|          | 2/625 [00:23<2:00:10, 11.57s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  0%|          | 3/625 [00:31<1:46:02, 10.23s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  1%|          | 4/625 [00:40<1:39:32,  9.62s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  1%|          | 5/625 [00:49<1:36:05,  9.30s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  1%|          | 6/625 [00:57<1:33:46,  9.09s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  1%|          | 7/625 [01:06<1:32:14,  8.96s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  1%|▏         | 8/625 [01:15<1:31:16,  8.88s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  1%|▏         | 9/625 [01:24<1:31:01,  8.87s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  2%|▏         | 10/625 [01:32<1:30:38,  8.84s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  2%|▏         | 11/625 [01:41<1:30:13,  8.82s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  2%|▏         | 12/625 [01:50<1:29:56,  8.80s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  2%|▏         | 13/625 [01:59<1:29:33,  8.78s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  2%|▏         | 14/625 [02:07<1:29:25,  8.78s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  2%|▏         | 15/625 [02:16<1:29:46,  8.83s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  3%|▎         | 16/625 [02:25<1:29:48,  8.85s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  3%|▎         | 17/625 [02:34<1:29:21,  8.82s/it]

torch.Size([64, 1, 768])
torch.Size([64, 197, 768])


  3%|▎         | 17/625 [02:43<1:37:13,  9.60s/it]


KeyboardInterrupt: 