In [1]:
!pip install pytorch-ignite




[notice] A new release of pip is available: 23.3.1 -> 24.3.1
[notice] To update, run: C:\Users\hwnam\AppData\Local\Programs\Python\Python312\python.exe -m pip install --upgrade pip


In [2]:
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [3]:
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DataLoader, random_split, ConcatDataset

In [4]:
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [5]:
trainval_data = datasets.OxfordIIITPet(root="data", split="trainval", target_types="category", download=True, transform=train_transform)
test_data = datasets.OxfordIIITPet(root="data", split="test", target_types="category", download=True, transform=train_transform)
combined_data = ConcatDataset([trainval_data, test_data])

train_size = int(0.7 * len(combined_data))
val_size = int(0.15 * len(combined_data))
test_size = len(combined_data) - train_size - val_size
train_data, val_data, test_data = random_split(combined_data, [train_size, val_size, test_size])

In [6]:
class LinearNorm(nn.Module):
  def __init__(self, in_features, out_features):
    super(LinearNorm, self).__init__()
    self.linear = nn.Linear(in_features, out_features, bias=False)
    self.norm = nn.BatchNorm1d(out_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

  def forward(self, x):
    if x.dim() == 3:
      B, N, C = x.shape
      x = x.reshape(B * N, C)
      x = self.linear(x)
      x = self.norm(x)
      x = x.reshape(B, N, -1)
    else:
      x = self.linear(x)
      x = self.norm(x)

    return x

In [7]:
class attention_bias(nn.Module):
    def __init__ (self, num_heads, size):
        super(attention_bias, self).__init__()
        self.num_heads = num_heads
        self.size = size
        self.bias = nn.Parameter(torch.zeros(num_heads, size, size, size, size))
        
    def forward(self, x):
        bias = self.bias.clone()
        for i in range(self.size):
            for j in range(self.size):
                for i_prime in range(self.size):
                    for j_prime in range(self.size):
                        x_diff = abs(i - i_prime)
                        y_diff = abs(j - j_prime)
                        
                        distance = (x_diff ** 2 + y_diff ** 2) ** 0.5
                        bias[:, i, j, i_prime, j_prime] = -distance # distance가 클수록 작은 값을 가짐
        
        bias = bias.view(self.num_heads, self.size ** 2, self.size ** 2)
        return bias

In [63]:
class Attention(nn.Module):
  def __init__(self, dim, head_dim, patch_num):
    super(Attention, self).__init__()
    self.heads = dim // head_dim
    self.head_dim = head_dim
    self.qkv = LinearNorm(dim, dim * 3)
    self.bias = attention_bias(self.heads, patch_num)
    self.proj = nn.Sequential(
      nn.Hardswish(),
      LinearNorm(dim, dim)
    )
    
  def forward(self, x):
    B, N, C = x.shape
    H = int(N ** 0.5)
    W = int(N ** 0.5)
    
    qkv = self.qkv(x)
    qkv = qkv.view(B, N, 3, self.heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    
    #bias
    b = self.bias(x)
    
    similarity_score = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + b.unsqueeze(0)
    similarity_score = similarity_score.softmax(dim=-1)
    out = similarity_score @ v
    out = self.proj(out.transpose(1, 2).flatten(2))
    out = out.reshape(B, N, C)
    return out

In [64]:
class ConvNorm(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
    super(ConvNorm, self).__init__()
    self.linear = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    self.norm = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

  def forward(self, x):
    x = self.linear(x)
    x = self.norm(x)

    return x

In [65]:
class NormLinear(nn.Module):
  def __init__(self, in_features, out_features):
    super(NormLinear, self).__init__()
    self.bn = nn.BatchNorm1d(in_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    self.drop = nn.Dropout(p = 0.0, inplace = False)
    self.linear = nn.Linear(in_features, out_features, bias=False)

  def forward(self, x):
    x = self.bn(x)
    x = self.drop(x)
    x = self.linear(x)

    return x

In [110]:
class Downsample(nn.Module):
  def __init__(self):
    super(Downsample, self).__init__()
    self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
  
  def forward(self,x):
    x = x.reshape(x.shape[0], x.shape[1], int(x.shape[2] ** 0.5), int(x.shape[2] ** 0.5))
    down_x = self.pool(x)
    down_x = down_x.flatten(2)
    return down_x
  
class AttentionDownsample(nn.Module):
  def __init__(self, dim, heads, patch_num, out_dim):
    super(AttentionDownsample, self).__init__()
    self.heads = heads
    self.head_dim = dim // heads
    self.kv = LinearNorm(dim, dim * 5)
    self.downsample = Downsample()
    self.q = LinearNorm(dim, dim)
    self.proj = nn.Sequential(
        nn.Hardswish(),
        LinearNorm(dim * 4, out_dim)
    )
    self.bias = attention_bias(self.heads, patch_num)
  def forward(self, x):
    B, N, C = x.shape
    print(x.shape)
    H = int(N ** 0.5)
    W = int(N ** 0.5)
    
    kv = self.kv(x)
    k, v = kv.split([C, C * 4], dim=-1)
    q = self.downsample(self.q(x))
    q_t = q.transpose(1, 2)
    similarity_score = (torch.bmm(q_t, k)) / (self.head_dim ** 0.5)
    b = self.bias(x)
    # b = b.unsqueeze(0)  # 배치 차원을 추가
    #similarity_score = similarity_score + b
    similarity_score = similarity_score.softmax(dim=-1)
    print(similarity_score.shape)
    print(v.shape)
    attention_out = attention_out.reshape(B, -1, H // 2, W // 2)
    proj = self.proj(x)

    return proj

In [111]:
class LeViTMlp(nn.Module):
  def __init__(self ,in_features, hidden_features):
    super(LeViTMlp, self).__init__()
    self.linear1 = LinearNorm(in_features, hidden_features)
    self.act = nn.Hardswish()
    self.drop = nn.Dropout(p = 0.0)
    self.linear2 = LinearNorm(hidden_features, in_features)

  def forward(self, x):
    x = self.linear1(x)
    x = self.act(x)
    x = self.drop(x)
    x = self.linear2(x)

    return x

In [112]:
class LeViTBlock(nn.Module):
  def __init__(self, in_features, num_heads, patch_num):
    super(LeViTBlock, self).__init__()
    self.attn = Attention(in_features, num_heads, patch_num)
    self.drop_path1 = nn.Identity()
    self.mlp = LeViTMlp(in_features, in_features * 2)
    self.drop_path2 = nn.Identity()

  def forward(self, x):
    x = x + self.drop_path1(self.attn(x))
    x = x + self.drop_path2(self.mlp(x))

    return x

In [113]:
class LeViT_swin(nn.Module):
  def __init__(self):
    super(LeViT_swin, self).__init__()
    self.stem = nn.Sequential(
        ConvNorm(3, 32, kernel_size=3, stride=2, padding=1),
        nn.Hardswish(),
        ConvNorm(32, 64, kernel_size=3, stride=2, padding=1),
        nn.Hardswish(),
        ConvNorm(64, 128, kernel_size=3, stride=2, padding=1),
        nn.Hardswish(),
        ConvNorm(128, 256, kernel_size=3, stride=2, padding=1)
    )
    self.LeViTBlock0 = LeViTBlock(256, 64, 14)
    self.attentiondownsample0 = AttentionDownsample(256, 8, 14, 384)
    self.LeViTMlp0 = LeViTMlp(384, 384)
    self.drop_path = nn.Identity()
    
    self.LeViTBlock1 = LeViTBlock(384, 64, 7)
    self.attentiondownsample1 = AttentionDownsample(384, 12, 7, 512)
    self.LeViTMlp1 = LeViTMlp(512, 512)
    self.drop_path = nn.Identity()
    
    self.LeViTBlock2 = LeViTBlock(512, 64, 4)
    self.head = NormLinear(512, 37)
    self.head_dist = NormLinear(512, 37)


  def forward(self, x):
    x = self.stem(x)
    B, C, H, W = x.shape
    x = x.view(B, C, -1).transpose(1, 2)
    x = self.LeViTBlock0(x)
    x = self.LeViTBlock0(x)
    x = self.LeViTBlock0(x)
    x = self.LeViTBlock0(x)
    x = self.attentiondownsample0(x)
    x = x + self.drop_path(self.LeViTMlp0(x))
    
    x = self.LeViTBlock1(x)
    x = self.LeViTBlock1(x)
    x = self.LeViTBlock1(x)
    x = self.LeViTBlock1(x)
    x = self.attentiondownsample1(x)
    x = x + self.drop_path(self.LeViTMlp1(x))
    
    x = self.LeViTBlock2(x)
    x = self.LeViTBlock2(x)
    x = self.LeViTBlock2(x)
    x = self.LeViTBlock2(x)
    
    x = x.mean(1)
    result = self.head(x)
    
    return result
    

In [114]:
model = LeViT_swin()
print(model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

batch_size = 32
learning_rate = 0.001
num_epochs = 10

LeViT_swin(
  (stem): Sequential(
    (0): ConvNorm(
      (linear): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Hardswish()
    (2): ConvNorm(
      (linear): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): Hardswish()
    (4): ConvNorm(
      (linear): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): Hardswish()
    (6): ConvNorm(
      (linear): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (LeViTBlock0): LeViTBlock(
    (attn): Attention(
      (qkv): LinearNorm(
        (linea

In [115]:
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

print(f"Train set size: {len(train_data)}")
print(f"Validation set size: {len(val_data)}")
print(f"Test set size: {len(test_data)}")

Train set size: 5144
Validation set size: 1102
Test set size: 1103


In [116]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [117]:
from tqdm import tqdm

In [118]:
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    accuracy = 100 * correct / total
    print(f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {accuracy:.2f}%")

In [119]:
def evaluate(model, data_loader, criterion, device, phase="Validation"):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(data_loader)
    accuracy = 100 * correct / total
    print(f"{phase} Loss: {epoch_loss:.4f}, {phase} Accuracy: {accuracy:.2f}%")

In [120]:
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    train(model, train_loader, criterion, optimizer, device)
    evaluate(model, val_loader, criterion, device, phase="Validation")


Epoch 1/10


Training:   0%|          | 0/161 [00:00<?, ?it/s]

torch.Size([32, 196, 256])


Training:   0%|          | 0/161 [00:06<?, ?it/s]

torch.Size([32, 64, 256])
torch.Size([32, 196, 1024])





UnboundLocalError: cannot access local variable 'attention_out' where it is not associated with a value