#### [ML] ViT(20.10); Vision Transformer 코드 구현 및 설명 with pytorch
https://kimbg.tistory.com/31

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

In [2]:
from torchvision import transforms, datasets
img_size = 224

# Define image size
image_size = (img_size, img_size)  # Replace with your desired image dimensions

# Preprocessing shared by every split: PIL image -> CHW float tensor, per-channel
# normalization with mean=std=0.5, then resize the 32x32 CIFAR images to image_size.
base_transforms = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    transforms.Resize(image_size),
]

# Training-only augmentation, applied on top of the shared preprocessing.
data_augmentation = transforms.Compose(base_transforms + [
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip with 50% probability
    transforms.RandomRotation(degrees=(-15, 15)),  # Random rotation with range -15 to 15 degrees
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0), ratio=(0.75, 1.3333)),  # Random resized crop
])

# Evaluation transform: deterministic (no random augmentation), so test metrics are
# reproducible and computed on the actual test images rather than random crops/flips.
eval_transform = transforms.Compose(base_transforms)

BATCH_SIZE = 32
train_dataset = datasets.CIFAR100(root="./data/",
                                  train=True,
                                  download=True,
                                  transform=data_augmentation)

test_dataset = datasets.CIFAR100(root="./data/",
                                 train=False,
                                 download=True,
                                 transform=eval_transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

print(train_loader.dataset)


Files already downloaded and verified
Files already downloaded and verified
Dataset CIFAR100
    Number of datapoints: 50000
    Root location: ./data/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-15.0, 15.0], interpolation=nearest, expand=False, fill=0)
               RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=True)
           )


In [3]:
# Grab a single mini-batch to check tensor shapes and dtypes.
X_train, Y_train = next(iter(train_loader))
print(f"X_train: {X_train.size()} type: {X_train.type()}")
print(f"Y_train: {Y_train.size()} type: {Y_train.type()}")

X_train: torch.Size([32, 3, 224, 224]) type: torch.FloatTensor
Y_train: torch.Size([32]) type: torch.LongTensor


In [4]:
# Grab a single mini-batch (as `x`) to check tensor shapes and dtypes.
x, Y_train = next(iter(train_loader))
print(f"X_train: {x.size()} type: {x.type()}")
print(f"Y_train: {Y_train.size()} type: {Y_train.type()}")

X_train: torch.Size([32, 3, 224, 224]) type: torch.FloatTensor
Y_train: torch.Size([32]) type: torch.LongTensor


In [5]:
## input ##
print('X_train :', X_train.shape)

patch_size = 16 # 16x16 pixel patch
patches = rearrange(X_train, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', 
                    s1=patch_size, s2=patch_size)
print('patches :', patches.shape)

X_train : torch.Size([32, 3, 224, 224])
patches : torch.Size([32, 196, 768])


In [6]:
## input ##
x = torch.randn(32, 3, 224, 224)
print('x :', x.shape)

patch_size = 16 # 16x16 pixel patch
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', 
                    s1=patch_size, s2=patch_size)
print('patches :', patches.shape)

x : torch.Size([32, 3, 224, 224])
patches : torch.Size([32, 196, 768])


In [7]:
t = torch.randn(32, 3, 224, 224)

patch_size = 16
in_channels = 3
# One flattened patch holds in_channels * patch_size * patch_size = 3 * 16 * 16 = 768 values.
emb_size = in_channels * patch_size * patch_size

# kernel_size == stride == patch_size: the conv applies one shared linear map to
# every non-overlapping patch, yielding a 14x14 grid of emb_size-dim embeddings.
patch_conv = nn.Conv2d(in_channels, emb_size,
                       kernel_size=patch_size, stride=patch_size)
conv_t = patch_conv(t)
conv_t.shape

torch.Size([32, 768, 14, 14])

In [8]:
# Flatten the (h, w) grid of patch embeddings into a token sequence,
# moving the embedding axis last: (b, e, h, w) -> (b, h*w, e).
grid_to_tokens = 'b e (h) (w) -> b (h w) e'
patches = rearrange(conv_t, grid_to_tokens)
patches.size()

torch.Size([32, 196, 768])

In [9]:
# Re-draw an image batch into `x` (the previous cell replaced `x` with random noise).
x, Y_train = next(iter(train_loader))
print(f"X_train: {x.size()} type: {x.type()}")
print(f"Y_train: {Y_train.size()} type: {Y_train.type()}")

X_train: torch.Size([32, 3, 224, 224]) type: torch.FloatTensor
Y_train: torch.Size([32]) type: torch.LongTensor


In [10]:
# Full batch shape, then the per-sample shape (batch axis dropped) used by torchsummary.
for shape in (x.shape, x.shape[1:]):
    print(shape)

torch.Size([32, 3, 224, 224])
torch.Size([3, 224, 224])


In [11]:
patch_size = 16
in_channels = 3
emb_size = in_channels * patch_size * patch_size  # 3 * 16 * 16 = 768

# Patchify and project in one step with a strided conv (used instead of an
# explicit linear layer over flattened patches).
patch_conv = nn.Conv2d(in_channels, emb_size,
                       kernel_size=patch_size, stride=patch_size)  # -> (B, 768, 14, 14)
projection = nn.Sequential(
    patch_conv,
    Rearrange('b e (h) (w) -> b (h w) e'),  # -> (B, 196, 768)
)

summary(projection, x.shape[1:], device='cpu')
# Conv2d parameter count:
# 590592 = 16 * 16 (kernel) * 3 (in_channels) * 768 (out_channels) + 768 (bias)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 2.30
Params size (MB): 2.25
Estimated Total Size (MB): 5.12
----------------------------------------------------------------


In [12]:
emb_size = 768
img_size = 224
patch_size = 16
num_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196

# Split the images into patches and project each patch to emb_size.
projected_x = projection(x)
print('Projected X shape :', projected_x.shape)

# Learnable CLS token and position embeddings (one row per patch, plus one for CLS).
cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))
print('Cls Shape :', cls_token.shape, ', Pos Shape :', positions.shape)

Projected X shape : torch.Size([32, 196, 768])
Cls Shape : torch.Size([1, 1, 768]) , Pos Shape : torch.Size([197, 768])


In [13]:
# Repeat cls_token along the batch axis so it can be concatenated with every
# image's patch sequence. The batch size is read from the data rather than
# hard-coded, so this also works for a smaller final batch or another BATCH_SIZE.
batch_size = projected_x.shape[0]
cls_tokens = repeat(cls_token, '() n e -> b n e', b=batch_size)
print('Repeated Cls shape :', cls_tokens.shape)

Repeated Cls shape : torch.Size([32, 1, 768])


In [14]:
# Prepend the CLS token to the patch tokens along the sequence axis.
# With the batch of 32 above: (32, 1, 768) + (32, 196, 768) -> (32, 197, 768)
cat_x = torch.cat((cls_tokens, projected_x), dim=1)
cat_x.shape

torch.Size([32, 197, 768])

In [15]:
positions.size()

torch.Size([197, 768])

In [16]:
# Add the position embeddings: positions (197, 768) broadcasts over the batch axis
# of cat_x (B, 197, 768). Done out-of-place into a new name so re-running this
# cell does not add the embeddings a second time.
embedded_x = cat_x + positions
print('output : ', embedded_x.shape)

output :  torch.Size([32, 197, 768])


In [17]:
# Broadcasting: the 1-element tensor is expanded to shape (2, 2) before adding.
torch.ones(2, 2).add(torch.ones(1))

tensor([[2., 2.],
        [2., 2.]])

In [18]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a token sequence for the transformer.

    Steps: non-overlapping patch_size x patch_size patches are projected to
    emb_size by a strided Conv2d, a learnable CLS token is prepended, and a
    learnable position embedding is added.

    Input is expected as (batch, in_channels, H, W); the position table has
    (img_size // patch_size) ** 2 + 1 rows, so H and W should equal img_size
    for the addition to line up.
    Output: (batch, (img_size // patch_size) ** 2 + 1, emb_size).
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16,
                 emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        # A conv layer with kernel == stride == patch_size stands in for a
        # linear layer over flattened patches.
        patch_conv = nn.Conv2d(in_channels, emb_size,
                               kernel_size=patch_size, stride=patch_size)
        self.projection = nn.Sequential(patch_conv,
                                        Rearrange('b e (h) (w) -> b (h w) e'))
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        batch_size, _channels, _height, _width = x.shape
        tokens = self.projection(x)
        # One copy of the CLS token per sample, placed in front of the patches.
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch_size)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        return tokens + self.positions

PE = PatchEmbedding()
image_shape = (3, 224, 224)  # (channels, height, width) of a single sample
summary(PE, image_shape, device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 2.30
Params size (MB): 2.25
Estimated Total Size (MB): 5.12
----------------------------------------------------------------


In [19]:
emb_size = 768
num_heads = 8

# Separate emb_size -> emb_size projections, created in the order keys, queries, values.
keys, queries, values = (nn.Linear(emb_size, emb_size) for _ in range(3))
print(keys, queries, values)

x = PE(x)  # images -> tokens
print(queries(x).shape)  # batch, n, emb_size

Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True)
torch.Size([32, 197, 768])


In [20]:
# Split the embedding axis into num_heads heads: (b, n, h*d) -> (b, h, n, d).
# NOTE(review): this rebinds `queries` from the nn.Linear layer to its output
# tensor, so re-run the cell that creates the layers before re-running this one.
projected_queries = queries(x)
queries = rearrange(projected_queries, "b n (h d) -> b h n d", h=num_heads)
queries.shape

torch.Size([32, 8, 197, 96])

In [21]:
# Split keys and values into heads exactly like the queries.
# NOTE(review): like the queries cell, this replaces the nn.Linear layers `keys`
# and `values` with tensors; re-run the layer-creation cell before re-running.
keys, values = (
    rearrange(keys(x), "b n (h d) -> b h n d", h=num_heads),
    rearrange(values(x), "b n (h d) -> b h n d", h=num_heads),
)
print('shape :', queries.shape, keys.shape, values.shape)

shape : torch.Size([32, 8, 197, 96]) torch.Size([32, 8, 197, 96]) torch.Size([32, 8, 197, 96])


In [22]:
# Attention logits: dot product of every query with every key, per head.
# (b, h, q, d) x (b, h, k, d) -> (b, h, q, k)
qk_equation = 'bhqd, bhkd -> bhqk'
energy = torch.einsum(qk_equation, queries, keys)
print('energy :', energy.shape)

energy : torch.Size([32, 8, 197, 197])


In [23]:
# Attention weights. Scaled dot-product attention divides the logits by sqrt(d_k),
# where d_k is the per-head width (the last axis of `queries`, emb_size // num_heads),
# not by sqrt(emb_size) - the latter over-smooths the softmax.
head_dim = queries.shape[-1]
scaling = head_dim ** (1/2)
att = F.softmax(energy / scaling, dim=-1)
print('att :', att.shape)


att : torch.Size([32, 8, 197, 197])


In [24]:
# Weighted sum of the value vectors over the key axis:
# (b, h, q, k) x (b, h, k, d) -> (b, h, q, d)
av_equation = 'bhal, bhlv -> bhav '
out = torch.einsum(av_equation, att, values)
print('out :', out.shape)

out : torch.Size([32, 8, 197, 96])


In [25]:
# Merge the heads back into a single embedding axis: (b, h, n, d) -> (b, n, h*d).
# Stored under a new name so re-running the cell does not feed the already
# merged 3-D tensor into the 4-D pattern.
out_merged = rearrange(out, "b h n d -> b n (h d)")
print('out2 : ', out_merged.shape)

out2 :  torch.Size([32, 197, 768])


In [26]:
# Re-draw a fresh image batch: `x` was overwritten with token embeddings above.
# The datasets and loaders are built once at the top of the notebook; rebuilding
# them here was a verbatim copy, including random augmentation on the test split.
x, Y_train = next(iter(train_loader))
print(f"X_train: {x.size()} type: {x.type()}")
print(f"Y_train: {Y_train.size()} type: {Y_train.type()}")


Files already downloaded and verified
Files already downloaded and verified
X_train: torch.Size([32, 3, 224, 224]) type: torch.FloatTensor
Y_train: torch.Size([32]) type: torch.LongTensor


In [27]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    Args:
        emb_size: width of each input/output token embedding. Must be divisible
            by num_heads; each head works on emb_size // num_heads features.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    forward(x, mask=None):
        x: tokens of shape (batch, tokens, emb_size).
        mask: optional boolean tensor broadcastable to
            (batch, num_heads, tokens, tokens); positions where it is False
            are excluded from attention.
        Returns a tensor of shape (batch, tokens, emb_size).

    Raises:
        ValueError: if emb_size is not divisible by num_heads.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # One projection, then split the last axis into (heads, head_dim, qkv):
        # (b, n, 3*emb) -> (3, b, h, n, d).
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # Attention logits: every query against every key, per head.
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            # masked_fill is out-of-place, so its result must be kept. Masked-out
            # logits get the dtype's most negative value so softmax sends them to ~0.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # Scaled dot-product attention divides by sqrt(d_k), the per-head width.
        scaling = self.head_dim ** 0.5
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # Weighted sum of values over the key axis: (b, h, q, k) x (b, h, k, d) -> (b, h, q, d).
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.projection(out)
        
        
# Rebuild the patch embedding and turn the fresh image batch into tokens.
PE = PatchEmbedding()
print(x.shape)  # images
x = PE(x)
print(x.shape)  # tokens

MHA = MultiHeadAttention()
token_shape = x.shape[1:]  # torchsummary takes the per-sample shape
summary(MHA, token_shape, device='cpu')

torch.Size([32, 3, 224, 224])
torch.Size([32, 197, 768])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1            [-1, 197, 2304]       1,771,776
           Dropout-2          [-1, 8, 197, 197]               0
            Linear-3             [-1, 197, 768]         590,592
Total params: 2,362,368
Trainable params: 2,362,368
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.58
Forward/backward pass size (MB): 6.99
Params size (MB): 9.01
Estimated Total Size (MB): 16.57
----------------------------------------------------------------


In [28]:
x.size()

torch.Size([32, 197, 768])

In [29]:
# Inspect the fused QKV projection and its split into per-head queries/keys/values.
qkv = MHA.qkv(x)
print(qkv.shape)  # (B, N, 3 * emb_size)
# Use the module's own head count instead of a hard-coded 8 so the split
# always matches how MHA itself splits the projection.
qkv = rearrange(qkv, "b n (h d qkv) -> (qkv) b h n d", h=MHA.num_heads, qkv=3)
print(qkv.shape)  # (3, B, heads, N, head_dim)

torch.Size([32, 197, 2304])
torch.Size([3, 32, 8, 197, 96])


In [30]:
# Re-draw a fresh image batch: `x` currently holds token embeddings.
# The datasets and loaders are built once at the top of the notebook; rebuilding
# them here was a verbatim copy, including random augmentation on the test split.
x, Y_train = next(iter(train_loader))
print(f"X_train: {x.size()} type: {x.type()}")
print(f"Y_train: {Y_train.size()} type: {Y_train.type()}")

Files already downloaded and verified
Files already downloaded and verified
X_train: torch.Size([32, 3, 224, 224]) type: torch.FloatTensor
Y_train: torch.Size([32]) type: torch.LongTensor


In [31]:
class ResidualAdd(nn.Module):
    """Wrap a module `fn` with a residual (skip) connection: returns fn(x) + x.

    Extra keyword arguments given to forward are passed through to fn.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. An in-place `+=` on fn's output overwrites the
        # caller's tensor when fn returns its input unchanged (e.g. nn.Identity),
        # and breaks autograd when fn's last op saves its output for backward
        # (e.g. sigmoid).
        return self.fn(x, **kwargs) + x

class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear(emb -> expansion*emb) -> GELU -> Dropout -> Linear(back to emb).

    Args:
        emb_size: input and output feature width.
        expansion: hidden width multiplier.
        drop_p: dropout probability after the activation.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)

class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm transformer encoder layer made of two residual branches.

    1. x + Dropout(MultiHeadAttention(LayerNorm(x)))
    2. x + Dropout(FeedForwardBlock(LayerNorm(x)))

    Extra keyword arguments are forwarded to MultiHeadAttention.
    """

    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        )
        mlp_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(ResidualAdd(attention_branch), ResidualAdd(mlp_branch))

# Images -> tokens -> one attention layer; then summarize a full encoder block
# for inputs of that per-sample token shape.
x = MHA(PE(x))
TE = TransformerEncoderBlock()
summary(TE, x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         LayerNorm-1             [-1, 197, 768]           1,536
            Linear-2            [-1, 197, 2304]       1,771,776
           Dropout-3          [-1, 8, 197, 197]               0
            Linear-4             [-1, 197, 768]         590,592
MultiHeadAttention-5             [-1, 197, 768]               0
           Dropout-6             [-1, 197, 768]               0
       ResidualAdd-7             [-1, 197, 768]               0
         LayerNorm-8             [-1, 197, 768]           1,536
            Linear-9            [-1, 197, 3072]       2,362,368
             GELU-10            [-1, 197, 3072]               0
          Dropout-11            [-1, 197, 3072]               0
           Linear-12             [-1, 197, 768]       2,360,064
          Dropout-13             [-1, 197, 768]               0
      ResidualAdd-14             [-1, 1

In [32]:
#print(representation.shape) # torch.Size([8, 197, 768])

In [33]:
#cls_head = reduce(representation, 'b n e -> b e', reduction='mean')
#print(cls_head.shape)
#cls_head = nn.Linear(768, 1000)(cls_head)
#print(cls_head.shape)

In [34]:
class TransformerEncoder(nn.Sequential):
    """A stack of `depth` TransformerEncoderBlock layers.

    Keyword arguments are passed to every TransformerEncoderBlock.
    """

    def __init__(self, depth: int = 12, **kwargs):
        blocks = [TransformerEncoderBlock(**kwargs) for _layer in range(depth)]
        super().__init__(*blocks)

# Default encoder: 12 stacked TransformerEncoderBlocks.
tr_encoder = TransformerEncoder()

In [35]:
# TransformerEncoder is already defined, identically, in the previous cell.
# Redefining it here would silently shadow that class, so the existing one is reused.

class ClassificationHead(nn.Sequential):
    """Map the encoder's token sequence to class logits.

    Mean-pools over the token axis (so the CLS token is averaged together with
    the patch tokens rather than used on its own), then applies LayerNorm and
    a Linear(emb_size, n_classes) layer.
    """

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        pool = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        classifier = nn.Linear(emb_size, n_classes)
        super().__init__(pool, norm, classifier)

class ViT(nn.Sequential):
    """Vision Transformer: PatchEmbedding -> TransformerEncoder -> ClassificationHead.

    Args:
        in_channels, patch_size, emb_size, img_size: forwarded to PatchEmbedding.
        depth: number of encoder blocks.
        n_classes: width of the classification output.
        **kwargs: forwarded to every TransformerEncoderBlock.
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)
        
input_shape = (3, 224, 224)  # (channels, height, width) of a single sample
summary(ViT(), input_shape, device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
      ResidualAdd-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 19

In [36]:
# Re-draw a fresh image batch for the full ViT forward pass below.
# The datasets and loaders are built once at the top of the notebook; rebuilding
# them here was a verbatim copy, including random augmentation on the test split.
x, Y_train = next(iter(train_loader))
print(f"X_train: {x.size()} type: {x.type()}")
print(f"Y_train: {Y_train.size()} type: {Y_train.type()}")

Files already downloaded and verified
Files already downloaded and verified
X_train: torch.Size([32, 3, 224, 224]) type: torch.FloatTensor
Y_train: torch.Size([32]) type: torch.LongTensor


In [None]:
# CIFAR-100 has 100 classes, so size the classification head to match
# (the ViT default is n_classes=1000).
vit = ViT(n_classes=100)
# Shape check only: no_grad avoids storing activations for backprop through
# all encoder blocks on a full batch.
with torch.no_grad():
    out = vit(x)
print(out.shape)

In [None]:
#### cifar10_cnn (CNN baseline; the cell below trains on CIFAR-10, not CIFAR-100)

In [None]:
import ssl
import torch
import torch.nn as nn
from torchvision import transforms, datasets

# ssl._create_default_https_context = ssl._create_unverified_context


BATCH_SIZE = 32
# CIFAR-10 images converted to float tensors with ToTensor only (no normalization/augmentation).
train_dataset = datasets.CIFAR10(root="./data/",
                                 train=True,
                                 download=True,
                                 transform=transforms.ToTensor())

test_dataset = datasets.CIFAR10(root="./data/",
                                train=False,
                                download=True,
                                transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

print(train_loader.dataset)

# Peek at one mini-batch to confirm shapes and dtypes.
X_train, Y_train = next(iter(train_loader))
print(f"X_train: {X_train.size()} type: {X_train.type()}")
print(f"Y_train: {Y_train.size()} type: {Y_train.type()}")


class CNN(nn.Module):
    """Small CNN baseline for 3x32x32 images with 10 output classes.

    Two [conv 3x3 (padding 1) -> ReLU -> 2x2 max-pool] stages reduce a 32x32
    input to 16 channels of 8x8, followed by a 1024 -> 64 -> 32 -> 10 MLP.

    forward returns log-probabilities (log_softmax over the classes). Passing
    them to nn.CrossEntropyLoss still gives the right loss, because log_softmax
    is idempotent, but the normalization is then applied twice.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=8,
            kernel_size=3,
            padding=1)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            padding=1)
        self.pool = nn.MaxPool2d(
            kernel_size=2,
            stride=2
        )
        self.fc1 = nn.Linear(8 * 8 * 16, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # (B, 8, 16, 16) for 32x32 input
        x = self.pool(torch.relu(self.conv2(x)))  # (B, 16, 8, 8) for 32x32 input
        # Flatten everything except the batch axis. Unlike view(-1, 1024), an
        # input of the wrong spatial size cannot be silently re-batched; fc1 raises.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.log_softmax(self.fc3(x), dim=1)


# Train on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using PyTorch version: {torch.__version__}, Device: {DEVICE}")

model = CNN().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(model, train_loader, optimizer, log_interval, epoch=None):
    """Run one training epoch over train_loader.

    Uses the module-level DEVICE and criterion. Prints the current batch loss
    every `log_interval` batches.

    Args:
        model: network to train (switched to train mode here).
        train_loader: DataLoader yielding (image, label) batches.
        optimizer: optimizer over model's parameters.
        log_interval: log every this many batches.
        epoch: epoch number shown in the log line. Passed explicitly instead of
            being read from the global loop variable `Epoch`.
    """
    model.train()
    for batch_idx, (image, label) in enumerate(train_loader):
        image = image.to(DEVICE)
        label = label.to(DEVICE)
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(
                f"train Epoch: {epoch} [{batch_idx * len(image)}/{len(train_loader.dataset)}({100. * batch_idx / len(train_loader):.0f}%)]\tTrain Loss: {loss.item()}")


def evaluate(model, test_loader):
    """Return (average per-sample loss, accuracy in %) of model on test_loader.

    Uses the module-level DEVICE and criterion; runs without gradient tracking.
    """
    model.eval()
    test_loss = 0.
    correct = 0
    with torch.no_grad():
        for image, label in test_loader:
            image = image.to(DEVICE)
            label = label.to(DEVICE)
            output = model(image)
            # criterion (nn.CrossEntropyLoss with its default 'mean' reduction)
            # averages over the batch; weight by the batch size so dividing by
            # the dataset size below gives a true per-sample mean, even when the
            # last batch is smaller.
            test_loss += criterion(output, label).item() * label.size(0)
            prediction = output.argmax(dim=1)
            correct += (prediction == label).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    test_accuracy = 100. * correct / num_samples
    return test_loss, test_accuracy

EPOCHS = 10
for Epoch in range(1, EPOCHS + 1):
    train(model, train_loader, optimizer, log_interval=200, epoch=Epoch)
    test_loss, test_accuracy = evaluate(model, test_loader)
    print(f"\n[EPOCH: {Epoch}]\tTest Loss: {test_loss:.4f}\tTest Accuracy: {test_accuracy} % \n")

In [48]:
import torch
import sklearn
import numpy as np
from sklearn.metrics import top_k_accuracy_score
# Toy example for sklearn's top_k_accuracy_score. The tensors stay on the CPU:
# sklearn needs CPU data, and a .cuda() call raises on machines without a GPU.
y_true = torch.tensor([0, 1, 2, 2])
y_score = torch.tensor([[0.5, 0.2, 0.2],   # true 0 ranked 1st -> top-1 hit
                        [0.3, 0.4, 0.2],   # true 1 ranked 1st -> top-1 hit
                        [0.2, 0.4, 0.3],   # true 2 ranked 2nd -> hit only for k >= 2
                        [0.7, 0.2, 0.1]])  # true 2 ranked 3rd -> hit only for k >= 3
print(top_k_accuracy_score(y_true.numpy(), y_score.numpy(), k=1))
# Not normalizing gives the number of "correctly" classified samples
print(top_k_accuracy_score(y_true.numpy(), y_score.numpy(), k=1, normalize=False))

0.5
2


In [9]:
import torch
# Same toy example as the sklearn cell. Use the GPU when one is available, but
# keep the cell runnable on CPU-only machines, where .cuda() raises.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
y_true = torch.tensor([0, 1, 2, 2], device=device)
y_pred = torch.tensor([[0.5, 0.2, 0.2],   # true 0 ranked 1st
                       [0.3, 0.4, 0.2],   # true 1 ranked 1st
                       [0.2, 0.4, 0.3],   # true 2 ranked 2nd
                       [0.7, 0.2, 0.1]],  # true 2 ranked 3rd
                      device=device)

k = 3

def top_k_accuracy(y_true, y_pred, k):
    """Compute top-k accuracy.

    A sample counts as correct when its true label is among the k classes
    with the highest predicted scores.

    Args:
        y_true: ground-truth labels, either integer class indices of shape (N,)
            or one-hot rows of shape (N, C). Tensors or array-likes.
        y_pred: predicted scores / probabilities of shape (N, C).
        k: number of top-scoring classes to consider; values above C are
            clamped to C.

    Returns:
        (accuracy, correct): the fraction of correct samples as a float (0.0
        for empty input) and the number of correct samples as an int.
    """
    y_true = torch.as_tensor(y_true)
    y_pred = torch.as_tensor(y_pred)
    if y_true.dim() == 2:
        # One-hot labels -> class indices.
        y_true = y_true.argmax(dim=1)

    num_samples = y_true.shape[0]
    if num_samples == 0:
        return 0.0, 0

    k = min(k, y_pred.shape[1])
    # Indices of the k highest-scoring classes for each sample: (N, k).
    top_k_indices = y_pred.topk(k, dim=1).indices
    hits = (top_k_indices == y_true.to(top_k_indices.device).unsqueeze(1)).any(dim=1)
    correct = int(hits.sum().item())
    return correct / num_samples, correct

# Calculate top-k accuracy for the k chosen above.
# The results get their own names so the `top_k_accuracy` function (and sklearn's
# `top_k_accuracy_score`) are not overwritten, which keeps this cell re-runnable.
topk_acc, topk_correct = top_k_accuracy(y_true, y_pred, k=k)

print(f"Top-{k} accuracy: {topk_acc}")
print(f"Top-{k} correct predictions: {topk_correct}")

Top-3 accuracy: 1.0
Top-3 accuracy: 4
