In [4]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import f1_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("used device :", device)


used device : cuda


In [5]:
def load_mnist_images(path):
    with open(path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
        data = data.reshape(-1, 1, 28, 28).astype(np.float32) / 255.0
    return data

def load_mnist_labels(path):
    with open(path, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)
    return labels

train_images = load_mnist_images("/kaggle/input/train-images.idx3-ubyte")
train_labels = load_mnist_labels("/kaggle/input/train-labels.idx1-ubyte")

test_images = load_mnist_images("/kaggle/input/t10k-images.idx3-ubyte")
test_labels = load_mnist_labels("/kaggle/input/t10k-labels.idx1-ubyte")

print("Train images :", train_images.shape)
print("Test images  :", test_images.shape)


Train images : (60000, 1, 28, 28)
Test images  : (10000, 1, 28, 28)


In [6]:
class MnistDataset(Dataset):
    def __init__(self, images, labels):
        self.images = torch.tensor(images, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

train_dataset = MnistDataset(train_images, train_labels)
test_dataset  = MnistDataset(test_images, test_labels)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False)


In [7]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=28, patch_size=7, emb_dim=128):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_channels=1,
            out_channels=emb_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, emb_dim=128, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads

        self.query = nn.Linear(emb_dim, emb_dim)
        self.key   = nn.Linear(emb_dim, emb_dim)
        self.value = nn.Linear(emb_dim, emb_dim)

        self.out = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        B, N, D = x.shape
        H = self.num_heads

        q = self.query(x).reshape(B, N, H, D//H).transpose(1, 2)
        k = self.key(x).reshape(B, N, H, D//H).transpose(1, 2)
        v = self.value(x).reshape(B, N, H, D//H).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) / (D**0.5)
        att = torch.softmax(scores, dim=-1)
        out = att @ v

        out = out.transpose(1, 2).reshape(B, N, D)
        return self.out(out)


class TransformerEncoderBlock(nn.Module):
    def __init__(self, emb_dim=128, num_heads=4, mlp_ratio=2):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.mhsa = MultiHeadSelfAttention(emb_dim, num_heads)

        self.ln2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim * mlp_ratio),
            nn.ReLU(),
            nn.Linear(emb_dim * mlp_ratio, emb_dim)
        )

    def forward(self, x):
        x = x + self.mhsa(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class ViT(nn.Module):
    def __init__(self, img_size=28, patch_size=7, emb_dim=128, depth=4, num_heads=4, num_classes=10):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, emb_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim))

        self.encoder = nn.Sequential(
            *[TransformerEncoderBlock(emb_dim, num_heads) for _ in range(depth)]
        )

        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)

        x = x + self.pos_embed

        x = self.encoder(x)

        cls_output = x[:, 0]
        return self.head(cls_output)


model_vit = ViT().to(device)
print(model_vit)


ViT(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(1, 128, kernel_size=(7, 7), stride=(7, 7))
  )
  (encoder): Sequential(
    (0): TransformerEncoderBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mhsa): MultiHeadSelfAttention(
        (query): Linear(in_features=128, out_features=128, bias=True)
        (key): Linear(in_features=128, out_features=128, bias=True)
        (value): Linear(in_features=128, out_features=128, bias=True)
        (out): Linear(in_features=128, out_features=128, bias=True)
      )
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
      )
    )
    (1): TransformerEncoderBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mhsa): MultiHeadSelfAttention(
        (query): Linear(in_features=128, out_featur

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_vit.parameters(), lr=0.001)

num_epochs = 10

for epoch in range(num_epochs):
    model_vit.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model_vit(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)

        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    print(f"Epoch {epoch+1}/{num_epochs} | Loss={total_loss/total:.4f} | Acc={(correct/total)*100:.2f}%")


Epoch 1/10 | Loss=0.5353 | Acc=82.27%
Epoch 2/10 | Loss=0.1514 | Acc=95.27%
Epoch 3/10 | Loss=0.1096 | Acc=96.60%
Epoch 4/10 | Loss=0.0887 | Acc=97.17%
Epoch 5/10 | Loss=0.0765 | Acc=97.51%
Epoch 6/10 | Loss=0.0689 | Acc=97.82%
Epoch 7/10 | Loss=0.0615 | Acc=98.00%
Epoch 8/10 | Loss=0.0566 | Acc=98.13%
Epoch 9/10 | Loss=0.0509 | Acc=98.36%
Epoch 10/10 | Loss=0.0521 | Acc=98.30%


In [9]:
model_vit.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model_vit(images)
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy = (np.array(all_preds) == np.array(all_labels)).mean() * 100
f1 = f1_score(all_labels, all_preds, average="macro")

print("Accuracy ViT :", round(accuracy, 2), "%")
print("F1-score     :", round(f1, 4))


Accuracy ViT : 97.85 %
F1-score     : 0.9783
