In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import math
import time

In [None]:
# Seed PyTorch's default and CUDA random number generators with a fixed value
# so that repeated runs of the notebook start from the same random state.
torch.manual_seed(0)
torch.cuda.manual_seed(0)

In [None]:
def positional_encoding(positions, d_model):
    """Build a fixed sinusoidal positional-encoding table.

    Column ``2i`` holds ``sin(pos * w_i)`` and column ``2i + 1`` holds
    ``cos(pos * w_i)``, where ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        positions: Number of positions (rows) to encode.
        d_model: Width of each encoding vector. Odd widths are supported; the
            last frequency then contributes only its sine column.

    Returns:
        A float32 tensor of shape ``(1, positions, d_model)``. The leading
        singleton axis lets the table broadcast over a batch dimension.
    """
    position_ids = torch.arange(positions, dtype=torch.float32).unsqueeze(1)
    exponents = -torch.arange(0, d_model, 2, dtype=torch.float32) / d_model
    # Broadcasts (positions, 1) * (ceil(d_model / 2),) -> (positions, ceil(d_model / 2)).
    angle_rads = position_ids * torch.pow(10000, exponents)

    pos_encoding = torch.zeros((positions, d_model), dtype=torch.float32)
    pos_encoding[:, 0::2] = torch.sin(angle_rads)
    # There are only d_model // 2 odd-indexed columns; when d_model is odd this
    # is one fewer than the number of frequencies, so trim the cosine source to
    # avoid a shape mismatch. For even d_model the slice is a no-op.
    pos_encoding[:, 1::2] = torch.cos(angle_rads[:, : d_model // 2])

    return pos_encoding.unsqueeze(0)

In [None]:
class FastAttention(nn.Module):
    """Multi-head attention that routes tokens through ``m`` random features.

    Queries and keys are projected onto a fixed random matrix ``R`` of shape
    ``(head_dim, m)``. The projected keys are softmax-normalised over the token
    axis and used to pool the values into ``m`` summary vectors per head; each
    token's projected query then mixes those summaries. No ``N x N`` attention
    matrix is ever materialised.

    Args:
        d_model: Width of the input and output token vectors.
        n_heads: Number of heads; must divide ``d_model`` evenly.
        m: Number of random features per head.
    """

    def __init__(self, d_model, n_heads, m):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert self.d_model == self.head_dim * n_heads, "d_model must be divisible by n_heads"

        # NOTE(review): queries and keys are *multiplied* by sqrt(head_dim) in
        # forward(); conventional scaled attention divides instead. Kept as in
        # the original -- confirm the intended scaling.
        self.scale = math.sqrt(self.head_dim)  # Scaling factor
        self.m = m  # Random Features

        self.qkv = nn.Linear(d_model, d_model * 3)
        self.projection = nn.Linear(d_model, d_model)

        # Fixed (non-trainable) random matrix for the feature mapping; stored
        # as a buffer so it follows the module across devices.
        self.register_buffer("R", torch.randn(self.head_dim, m))

    def forward(self, x):
        """Apply random-feature attention.

        Args:
            x: Tensor of shape ``(B, N, d_model)``.

        Returns:
            Tensor of shape ``(B, N, d_model)``.
        """
        B, N, _ = x.shape
        # Keep the (B, N, H, D) layout: the einsum subscripts below label axis 1
        # as tokens ('n') and axis 2 as heads ('h'). Permuting to (B, H, N, D)
        # first would make the softmax and the key/value contraction run over
        # heads, leaving every token processed in isolation.
        q, k, v = (
            part.reshape(B, N, self.n_heads, self.head_dim)
            for part in self.qkv(x).chunk(3, dim=-1)
        )

        q = q * self.scale
        k = k * self.scale

        # Random feature mapping: (B, N, H, D) x (D, m) -> (B, H, m, N)
        q_prime = torch.einsum('bnhd,dk->bhkn', q, self.R)
        k_prime = torch.einsum('bnhd,dk->bhkn', k, self.R)

        # Normalise each feature's weights over the token axis.
        k_prime = F.softmax(k_prime, dim=-1)

        # Pool values over tokens into one summary per feature: (B, H, m, D)
        v_prime = torch.einsum('bhkn,bnhd->bhkd', k_prime, v)
        # Mix the summaries back to every token: (B, N, H, D)
        y = torch.einsum('bhkn,bhkd->bnhd', q_prime, v_prime)

        # Heads are already adjacent to head_dim, so merging them is a reshape.
        y = y.reshape(B, N, self.d_model)
        y = self.projection(y)

        return y

In [None]:
class MyViT(nn.Module):
    """Small vision-transformer-style classifier built from FastAttention blocks.

    The image is cut into a grid of non-overlapping patches, each patch is
    linearly embedded, a fixed sinusoidal positional encoding is added, the
    tokens pass through a stack of ``FastAttention`` blocks, and the mean
    token is fed to a linear classifier.

    Args:
        input_shape: ``(C, H, W)`` of the input images.
        n_patches: Number of patches along each spatial axis.
        n_blocks: Number of stacked attention blocks.
        hidden_d: Token embedding width.
        n_heads: Number of heads in every attention block.
        out_d: Number of output classes.
        m_features: Number of random features used by every attention block.
    """

    def __init__(self, input_shape, n_patches, n_blocks, hidden_d, n_heads, out_d, m_features):
        super().__init__()
        self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
        self.d_model = hidden_d
        num_pixels_per_patch = self.patch_size[0] * self.patch_size[1] * input_shape[0]
        self.embedding = nn.Linear(num_pixels_per_patch, self.d_model)
        # Registered as a buffer so it moves with the module (.to/.cuda) instead
        # of being copied to the input's device on every forward pass.
        # persistent=False keeps it out of state_dict, which therefore stays
        # identical to when this was a plain tensor attribute.
        self.register_buffer(
            "position_embedding",
            positional_encoding(n_patches * n_patches, self.d_model),
            persistent=False,
        )
        self.blocks = nn.ModuleList([
            FastAttention(self.d_model, n_heads, m_features) for _ in range(n_blocks)
        ])
        self.to_cls_token = nn.Identity()
        self.classifier = nn.Linear(self.d_model, out_d)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, out_d)`` with unnormalised class scores.
        """
        B, C, H, W = x.shape
        patch_h, patch_w = self.patch_size
        # Unfold the image into patches: (B, C, nH, nW, patch_h, patch_w)
        x = x.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
        # Bring the patch-grid axes ahead of the channel axis before flattening
        # so each token contains exactly one patch across all channels. Without
        # this, multi-channel inputs would be split at the wrong boundaries;
        # for C == 1 the element order is unchanged.
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * patch_h * patch_w)
        x = self.embedding(x)
        # Add absolute position embedding (out of place; broadcasts over batch).
        x = x + self.position_embedding
        x = self.to_cls_token(x)
        # Apply attention blocks
        for block in self.blocks:
            x = block(x)
        # Average over tokens to get one vector per image.
        x = x.mean(dim=1)
        return self.classifier(x)

In [None]:
def main():
  """Train MyViT on MNIST, then report loss, accuracy and timing on the test split.

  Downloads MNIST into ``./../datasets`` if needed, trains for a fixed number
  of epochs with Adam and cross-entropy loss, and prints per-epoch training
  loss followed by test loss, test accuracy and wall-clock durations.
  """
  # Load data
  transform = ToTensor()

  train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
  test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

  train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
  test_loader = DataLoader(test_set, shuffle=False, batch_size=32)

  # Define model and training options
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print("Using device: ", device, f"{torch.cuda.get_device_name(device)}" if torch.cuda.is_available() else "")
  model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10, m_features=128)
  model = model.to(device)
  N_EPOCHS = 10
  LR = 0.02

  # Training loop
  optimizer = Adam(model.parameters(), lr=LR)
  criterion = CrossEntropyLoss()

  train_start_time = time.time()
  for epoch in tqdm(range(N_EPOCHS)):
    model.train()
    loss_sum, n_seen = 0.0, 0
    for x, y in train_loader:
      x, y = x.to(device), y.to(device)
      y_hat = model(x)
      loss = criterion(y_hat, y)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Weight each batch by its size: the last batch can be smaller, so a
      # plain mean of per-batch means would over-weight its samples.
      loss_sum += loss.item() * len(x)
      n_seen += len(x)

    train_loss = loss_sum / n_seen
    print(f"Epoch {epoch+1}/{N_EPOCHS} loss: {train_loss:.2f}")

  train_duration = time.time() - train_start_time
  print(f"Training time is: {train_duration:.2f} sec")

  # Testing loop
  model.eval()
  test_start_time = time.time()
  with torch.no_grad():
    correct, total = 0, 0
    loss_sum = 0.0
    for x, y in tqdm(test_loader, desc="Testing"):
      x, y = x.to(device), y.to(device)
      y_hat = model(x)
      loss_sum += criterion(y_hat, y).item() * len(x)

      correct += (y_hat.argmax(dim=1) == y).sum().item()
      total += len(x)

    test_loss = loss_sum / total
    # ".2f" (two decimals); the previous "2f" spec only set a minimum width.
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct/total*100:.2f}%")

  test_duration = time.time() - test_start_time
  print(f"Test time is: {test_duration:.2f} sec")

In [None]:
# Run the training and evaluation routine only when this code is executed as
# the main module, not when it is imported.
if __name__ == "__main__":
    main()

Using device:  cpu 


 10%|█         | 1/10 [02:39<23:52, 159.14s/it]

Epoch 1/10 loss: 39.93


 20%|██        | 2/10 [05:18<21:13, 159.13s/it]

Epoch 2/10 loss: 1.67


 30%|███       | 3/10 [07:58<18:36, 159.44s/it]

Epoch 3/10 loss: 1.31


 40%|████      | 4/10 [10:37<15:56, 159.34s/it]

Epoch 4/10 loss: 1.11


 50%|█████     | 5/10 [13:16<13:16, 159.28s/it]

Epoch 5/10 loss: 1.00


 60%|██████    | 6/10 [15:54<10:35, 158.91s/it]

Epoch 6/10 loss: 0.91


 70%|███████   | 7/10 [18:33<07:56, 158.80s/it]

Epoch 7/10 loss: 0.85


 80%|████████  | 8/10 [21:12<05:17, 158.95s/it]

Epoch 8/10 loss: 0.81


 90%|█████████ | 9/10 [23:51<02:38, 158.83s/it]

Epoch 9/10 loss: 0.75


100%|██████████| 10/10 [26:29<00:00, 158.98s/it]


Epoch 10/10 loss: 0.70
Training time is: 1589.82 sec


Testing: 100%|██████████| 313/313 [00:16<00:00, 18.77it/s]

Test loss: 0.856483
Test accuracy: 72.84%
Test time is: 16.69 sec



