#### [ML] ViT(20.10); Vision Transformer 코드 구현 및 설명 with pytorch



https://kimbg.tistory.com/31

#### Another Implementation

In [1]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from torchsummary import summary

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 96, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', 
                 channels = 3, dim_head = 96, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)



In [2]:
#import torch
#import torch.nn.functional as F
#import matplotlib.pyplot as plt

#from torch import nn
#from torch import Tensor
#from PIL import Image
#from torchvision.transforms import Compose, Resize, ToTensor
#from einops import rearrange, reduce, repeat
#from einops.layers.torch import Rearrange, Reduce

from torchvision import transforms, datasets

img_size = 224

# Define image size
image_size = (img_size, img_size)  # Replace with your desired image dimensions

# Create data augmentation transforms
data_augmentation = transforms.Compose([
    transforms.ToTensor(),  # Convert image to PyTorch tensor (CHW format)
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize pixel values (common practice)
    transforms.Resize(image_size),  # Resize image to specified dimensions
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip with 50% probability
    transforms.RandomRotation(degrees=(-15, 15)),  # Random rotation with range -15 to 15 degrees
    #transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0), ratio=(0.75, 1.3333))  # Random resized crop
])

#BATCH_SIZE = 512
BATCH_SIZE = 256
train_dataset = datasets.CIFAR100(root="./data/",
                                 train=True,
                                 download=True,
                                 transform=data_augmentation)

test_dataset = datasets.CIFAR100(root="./data/",
                                train=False,
                                download=True,
                                transform=data_augmentation)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

print(train_loader.dataset)

Files already downloaded and verified
Files already downloaded and verified
Dataset CIFAR100
    Number of datapoints: 50000
    Root location: ./data/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-15.0, 15.0], interpolation=nearest, expand=False, fill=0)
           )


In [3]:
model = ViT(
    image_size = 224,
    patch_size = 16,
    #num_classes = 1000,
    num_classes = 100, #cifar100
    #dim = 768,
    dim = 64,
    depth = 12,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

#summary(model, (3,224,224), device='cpu')

In [None]:
#import torch
#from torch import nn
from tqdm.auto import tqdm
from torch.optim import AdamW

# Assuming your model is defined in `model`

num_epochs = 10

weight_decay = 0.001
learning_rate = 0.001


DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(DEVICE)

# Define optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Define loss function (assuming model outputs logits)
criterion = nn.CrossEntropyLoss()  # For logits

def top_k_accuracy_score(y_true, y_pred, k):
  """
  Calculates top-k accuracy score.

  Args:
      y_true: Ground truth labels (one-hot encoded or integer).
      y_pred: Predicted probabilities (2D array).
      k: The value of k for top-k accuracy.

  Returns:
      Top-k accuracy score.
  """
  correct = 0
  for y_t, y_p in zip(y_true, y_pred):
    # Get top k predictions (indices with highest scores)
    top_k_indices = y_p.argsort()[-k:]

    # Check if true label is in top k predictions
    if y_t in top_k_indices:
      correct += 1

  #accuracy = correct / len(y_true)
  return correct


def train(model, train_loader, optimizer, log_interval):
    model.train()
    for batch_idx, (image, label) in tqdm(enumerate(train_loader)):
        image = image.to(DEVICE)
        label = label.to(DEVICE)
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            correct = top_k_accuracy_score(label.clone(), output.clone(), k=1) 
            #accuracy = correct / len(label)
            print(
                f"train Epoch: {Epoch} [{batch_idx * len(image)}/{len(train_loader.dataset)}({100. * batch_idx / len(train_loader):.0f}%)] \
                           \tTrain Loss: {loss.item()} \
                           \tTrain Accuracy: {100. * correct / len(label):.2f}%")

def evaluate(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    top_k_correct = 0 
    with torch.no_grad():
        for image, label in tqdm(test_loader):
            image = image.to(DEVICE)
            label = label.to(DEVICE)
            output = model(image)
            #print(f"output.shape: {output.shape}") #output.shape: torch.Size([512, 100])
            #print(f"label.shape: {label.shape}") #label.shape: torch.Size([512])
            
            test_loss += criterion(output, label).item()
            #prediction = output.max(1, keepdim=True)[1]
            #print(prediction.shape) # torch.Size([512, 1])
            #correct += prediction.eq(label.view_as(prediction)).sum().item()
            correct += top_k_accuracy_score(label, output, k=1)
            #print(f"correct: {correct}")
            top_k_correct += top_k_accuracy_score(label, output, k=5)
            #print(f"top_k_correct: {top_k_correct}")

    
    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    test_top_k_accuracy = 100. * top_k_correct / len(test_loader.dataset)
    return test_loss, test_accuracy, test_top_k_accuracy

for Epoch in range(1, num_epochs + 1):
    train(model, train_loader, optimizer, log_interval=20)
    test_loss, test_accuracy, test_top_k_accuracy = evaluate(model, test_loader)
    print(f"\n[EPOCH: {Epoch}]\tTest Loss: {test_loss:.4f}\tTest Accuracy: {test_accuracy} %\tTest top k Accuracy: {test_top_k_accuracy} % \n")

0it [00:00, ?it/s]



  0%|          | 0/40 [00:00<?, ?it/s]


[EPOCH: 1]	Test Loss: 0.0157	Test Accuracy: 9.5 %	Test top k Accuracy: 28.77 % 



0it [00:00, ?it/s]

