In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import math
import time

In [None]:
# Fix RNG state so weight init and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)  # default generator (recent PyTorch seeds all devices here)
torch.cuda.manual_seed(SEED)  # explicit CUDA seed; silently ignored without CUDA

In [None]:
def positional_encoding(positions, d_model):
    """Build a fixed sinusoidal position encoding table.

    Even feature columns hold ``sin(pos * 10000**(-2i/d_model))`` and odd columns
    hold the matching ``cos`` term, where ``2i`` is the even column index.

    Args:
        positions: number of positions (sequence length).
        d_model: encoding width; may be odd (the last cosine column is dropped).

    Returns:
        float32 tensor of shape ``(1, positions, d_model)`` so it broadcasts over
        a batch dimension.
    """
    position = torch.arange(positions, dtype=torch.float32).unsqueeze(1)  # (positions, 1)
    even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)  # (ceil(d_model / 2),)
    inv_freq = torch.pow(10000.0, -even_dims / d_model)
    angle_rads = position * inv_freq  # (positions, ceil(d_model / 2))

    pos_encoding = torch.zeros((positions, d_model), dtype=torch.float32)
    pos_encoding[:, 0::2] = torch.sin(angle_rads)
    # With an odd d_model there is one fewer odd column than even column, so
    # only the first d_model // 2 cosine terms fit.
    pos_encoding[:, 1::2] = torch.cos(angle_rads[:, : d_model // 2])

    return pos_encoding.unsqueeze(0)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention without an output projection.

    Args:
        d_model: width of the input/output features; must equal
            ``n_heads * (d_model // n_heads)`` (checked with ``assert``).
        n_heads: number of attention heads.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert self.d_model == self.head_dim * self.n_heads

        # Construction order (query, key, value) fixes both the parameter names
        # and the order in which random initial weights are drawn.
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Self-attend over dim 1 of ``x`` (batch, seq, d_model); same shape out."""
        batch, seq_len, _ = x.shape
        per_head = (batch, seq_len, self.n_heads, self.head_dim)

        # Project, split into heads, and bring heads forward: (batch, heads, seq, head_dim).
        queries = self.query(x).reshape(per_head).transpose(1, 2)
        keys = self.key(x).reshape(per_head).transpose(1, 2)
        values = self.value(x).reshape(per_head).transpose(1, 2)

        scale = self.head_dim ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale  # (batch, heads, seq, seq)
        weights = scores.softmax(dim=-1)
        context = weights @ values  # (batch, heads, seq, head_dim)

        # Merge heads back into the feature dimension.
        return context.transpose(1, 2).reshape(batch, seq_len, -1)

In [None]:
class MyViT(nn.Module):
    """Minimal Vision-Transformer-style image classifier.

    The image is cut into a grid of non-overlapping patches; each patch is
    flattened (all channels together), linearly embedded, summed with a fixed
    sinusoidal position encoding, passed through ``n_blocks`` attention blocks,
    mean-pooled over patches and mapped to ``out_d`` logits.

    Args:
        input_shape: ``(C, H, W)`` of one input image. ``H`` and ``W`` should be
            divisible by ``n_patches``; otherwise the patch grid produced in
            ``forward`` may not be ``n_patches x n_patches`` and the position
            encoding will not match.
        n_patches: patches per image side.
        n_blocks: number of ``MultiHeadAttention`` blocks applied in sequence.
        hidden_d: embedding width; must be divisible by ``n_heads``.
        n_heads: attention heads per block.
        out_d: number of output logits.
    """

    def __init__(self, input_shape, n_patches, n_blocks, hidden_d, n_heads, out_d):
        super().__init__()
        self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
        self.d_model = hidden_d
        num_pixels_per_patch = self.patch_size[0] * self.patch_size[1] * input_shape[0]
        self.embedding = nn.Linear(num_pixels_per_patch, self.d_model)
        # A non-persistent buffer follows model.to(device) / .half() instead of
        # being copied on every forward, and keeps state_dict keys unchanged.
        self.register_buffer(
            "position_embedding",
            positional_encoding(n_patches * n_patches, self.d_model),
            persistent=False,
        )
        # NOTE(review): each block is bare attention, applied without residual
        # connections, LayerNorm or an MLP, unlike a standard ViT encoder layer.
        self.blocks = nn.ModuleList([
            MultiHeadAttention(self.d_model, n_heads) for _ in range(n_blocks)
        ])
        # No CLS token is prepended; this Identity is a placeholder kept for
        # module-layout compatibility. forward() classifies by mean pooling.
        self.to_cls_token = nn.Identity()
        self.classifier = nn.Linear(self.d_model, out_d)

    def forward(self, x):
        """Return logits of shape ``(B, out_d)`` for images ``x`` of shape ``(B, C, H, W)``."""
        B, C, H, W = x.shape
        ph, pw = self.patch_size
        # (B, C, nH, nW, ph, pw): non-overlapping patch grid, per channel.
        patches = x.unfold(2, ph, ph).unfold(3, pw, pw)
        # Put channels next to the pixel dims so every row holds all C channels
        # of ONE spatial patch: (B, nH * nW, C * ph * pw). For C == 1 this is the
        # same ordering as a plain reshape of the unfolded tensor.
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * ph * pw)
        x = self.embedding(patches)
        x = x + self.position_embedding  # add absolute position embedding
        x = self.to_cls_token(x)
        for block in self.blocks:
            x = block(x)
        x = x.mean(dim=1)  # mean-pool over patches
        return self.classifier(x)

In [None]:
def main():
    """Train ``MyViT`` on MNIST, then evaluate it on the test split.

    Downloads MNIST into ``./../datasets`` if it is not already present, trains
    for ``N_EPOCHS`` epochs with Adam and cross-entropy loss, and prints the
    per-epoch training loss, test loss, test accuracy and wall-clock timings.
    """
    # Load data
    transform = ToTensor()

    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=32)

    # Define model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"{torch.cuda.get_device_name(device)}" if torch.cuda.is_available() else "")
    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10)
    model = model.to(device)
    N_EPOCHS = 10
    LR = 0.02

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()

    model.train()
    train_start_time = time.perf_counter()
    for epoch in tqdm(range(N_EPOCHS)):
        loss_sum, n_samples = 0.0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight it by batch size so the
            # smaller final batch does not skew the epoch average.
            loss_sum += loss.item() * len(x)
            n_samples += len(x)

        train_loss = loss_sum / max(n_samples, 1)
        print(f"Epoch {epoch+1}/{N_EPOCHS} | loss: {train_loss:.2f}")

    train_duration = time.perf_counter() - train_start_time
    print(f"Training time is: {train_duration:.2f} sec")

    # Testing loop
    model.eval()
    test_start_time = time.perf_counter()
    correct, total = 0, 0
    loss_sum = 0.0
    with torch.no_grad():
        for x, y in tqdm(test_loader, desc="Testing"):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss_sum += criterion(y_hat, y).item() * len(x)
            correct += (y_hat.argmax(dim=1) == y).sum().item()
            total += len(x)

    test_loss = loss_sum / max(total, 1)
    # ".2f" = two decimals (the old ":2f" was a field width, printing 6 decimals).
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / max(total, 1) * 100:.2f}%")

    test_duration = time.perf_counter() - test_start_time
    print(f"Test time is: {test_duration:.2f} sec")

In [None]:
if __name__ == "__main__":
    # Runs when executed as a script; inside a Jupyter kernel __name__ is also
    # "__main__", so executing this cell starts training (see output below).
    main()

Using device:  cpu 


 10%|█         | 1/10 [00:20<03:08, 20.96s/it]

Epoch 1/10 | loss: 1.52


 20%|██        | 2/10 [00:42<02:50, 21.33s/it]

Epoch 2/10 | loss: 0.94


 30%|███       | 3/10 [01:04<02:30, 21.48s/it]

Epoch 3/10 | loss: 0.77


 40%|████      | 4/10 [01:25<02:08, 21.40s/it]

Epoch 4/10 | loss: 0.62


 50%|█████     | 5/10 [01:46<01:46, 21.33s/it]

Epoch 5/10 | loss: 0.51


 60%|██████    | 6/10 [02:07<01:25, 21.26s/it]

Epoch 6/10 | loss: 0.45


 70%|███████   | 7/10 [02:29<01:03, 21.31s/it]

Epoch 7/10 | loss: 0.42


 80%|████████  | 8/10 [02:51<00:42, 21.48s/it]

Epoch 8/10 | loss: 0.40


 90%|█████████ | 9/10 [03:12<00:21, 21.60s/it]

Epoch 9/10 | loss: 0.39


100%|██████████| 10/10 [03:34<00:00, 21.47s/it]


Epoch 10/10 | loss: 0.39
Training time is: 214.76 sec


Testing: 100%|██████████| 313/313 [00:02<00:00, 112.72it/s]

Test loss: 0.420764
Test accuracy: 85.93%
Test time is: 2.79 sec



