In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import math
import time

In [None]:
# Fix the random seed once, up front, so weight initialisation and data
# shuffling are repeatable from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)       # default (CPU) generator
torch.cuda.manual_seed(SEED)  # current CUDA device's generator

In [None]:
def positional_encoding(positions, d_model):
    """Build a fixed sinusoidal positional-encoding table.

    Even feature columns hold ``sin(pos * w_k)`` and odd feature columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        positions: Number of positions (rows) to encode.
        d_model: Width of each encoding vector. May be even or odd.

    Returns:
        A float32 tensor of shape ``(1, positions, d_model)``; the leading
        singleton axis lets it broadcast over a batch dimension.
    """
    position_ids = torch.arange(positions, dtype=torch.float32).unsqueeze(1)
    exponents = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model
    angle_rads = position_ids * torch.pow(10000, -exponents)

    pos_encoding = torch.zeros((positions, d_model), dtype=torch.float32)
    pos_encoding[:, 0::2] = torch.sin(angle_rads)
    # With an odd d_model there is one fewer odd column than even columns, so
    # drop the surplus cosine column instead of failing on a shape mismatch.
    pos_encoding[:, 1::2] = torch.cos(angle_rads)[:, : d_model // 2]

    return pos_encoding.unsqueeze(0)

In [None]:
def relu_feature_map(x):
    """Element-wise ReLU, used as the non-negative feature map for Q and K."""
    return F.relu(x)


class PerformerAttention(nn.Module):
    """Multi-head self-attention with a ReLU feature map on queries and keys.

    Despite the name, no random-feature approximation is used: the full
    ``N x N`` score matrix is computed and normalised with a softmax.

    Args:
        d_model: Width of the input and output token vectors.
        n_heads: Number of attention heads; must divide ``d_model``.
        causal: When True, a token may only attend to itself and to earlier
            tokens in the sequence.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, causal=False):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.causal = causal

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over the token axis of ``x``.

        Args:
            x: Tensor of shape ``(B, N, d_model)``.

        Returns:
            Tensor of shape ``(B, N, d_model)``.
        """
        B, N, _ = x.shape

        def split_heads(t):
            # (B, N, d_model) -> (B, H, N, head_dim). The heads axis must come
            # before the token axis so the 'bhnd' einsum labels below line up
            # and attention mixes tokens rather than heads.
            return t.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        Q = relu_feature_map(split_heads(self.query(x)))  # ReLU feature map on queries
        K = relu_feature_map(split_heads(self.key(x)))    # ReLU feature map on keys
        V = split_heads(self.value(x))

        # Plain scaled dot-product scores (no random features).
        scores = torch.einsum('bhnd,bhmd->bhnm', Q, K) / math.sqrt(self.head_dim)
        if self.causal:
            # future[i, j] is True when key j lies after query i.
            idx = torch.arange(N, device=x.device)
            future = idx.unsqueeze(0) > idx.unsqueeze(1)
            scores = scores.masked_fill(future, float('-inf'))
        attn = torch.softmax(scores, dim=-1)

        y = torch.einsum('bhnm,bhmd->bhnd', attn, V)
        # (B, H, N, head_dim) -> (B, N, H * head_dim)
        return y.transpose(1, 2).reshape(B, N, -1)

class MyPerformerViT(nn.Module):
    """Small ViT-style classifier built from ``PerformerAttention`` blocks.

    The image is cut into a grid of non-overlapping patches, each patch is
    linearly embedded, a fixed sinusoidal position encoding is added, the
    tokens go through ``n_blocks`` attention layers, and the mean token is
    fed to a linear classifier.

    Args:
        input_shape: ``(C, H, W)`` of the input images.
        n_patches: Number of patches along each spatial axis.
        n_blocks: Number of stacked attention blocks.
        hidden_d: Token embedding width.
        n_heads: Attention heads per block.
        out_d: Number of output classes (logits).
    """

    def __init__(self, input_shape, n_patches, n_blocks, hidden_d, n_heads, out_d):
        super().__init__()
        self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
        self.d_model = hidden_d
        num_pixels_per_patch = self.patch_size[0] * self.patch_size[1] * input_shape[0]
        self.embedding = nn.Linear(num_pixels_per_patch, self.d_model)
        # Registered as a buffer so `.to(device)` / `.half()` move it together
        # with the parameters; non-persistent so state_dict keys are unchanged.
        self.register_buffer(
            "position_embedding",
            positional_encoding(n_patches * n_patches, self.d_model),
            persistent=False,
        )
        self.blocks = nn.ModuleList([
            PerformerAttention(self.d_model, n_heads) for _ in range(n_blocks)
        ])
        self.to_cls_token = nn.Identity()
        self.classifier = nn.Linear(self.d_model, out_d)

    @staticmethod
    def _patchify(x, patch_size):
        """Split ``(B, C, H, W)`` images into flattened patch tokens.

        Returns a tensor of shape ``(B, n_patches_total, C * ph * pw)`` in
        which every token holds all channels of exactly one spatial patch.
        """
        B, C, _, _ = x.shape
        ph, pw = patch_size
        # (B, C, nH, nW, ph, pw)
        patches = x.unfold(2, ph, ph).unfold(3, pw, pw)
        # Move channels next to the pixel axes before flattening; otherwise
        # multi-channel inputs would stitch different patches into one token.
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        return patches.reshape(B, -1, C * ph * pw)

    def forward(self, x):
        """Return class logits of shape ``(B, out_d)`` for images ``x`` of shape ``(B, C, H, W)``."""
        x = self._patchify(x, self.patch_size)
        x = self.embedding(x)
        x = x + self.position_embedding  # Add absolute position embedding
        x = self.to_cls_token(x)
        # Apply attention blocks
        for block in self.blocks:
            x = block(x)
        x = x.mean(dim=1)  # average over tokens
        return self.classifier(x)

In [None]:
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss."""
    model.train()
    loss_sum, n_seen = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        loss = criterion(model(x), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch is not over-counted.
        loss_sum += loss.item() * len(x)
        n_seen += len(x)
    return loss_sum / max(n_seen, 1)


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader``; return ``(mean per-sample loss, accuracy in [0, 1])``."""
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in tqdm(loader, desc="Testing"):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss_sum += criterion(y_hat, y).item() * len(x)
            correct += (torch.argmax(y_hat, dim=1) == y).sum().item()
            total += len(x)
    return loss_sum / max(total, 1), correct / max(total, 1)


def main():
    """Train ``MyPerformerViT`` on MNIST, then report test loss, accuracy and timings."""
    # Load data
    transform = ToTensor()

    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=32)

    # Define model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"{torch.cuda.get_device_name(device)}" if torch.cuda.is_available() else "")
    model = MyPerformerViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10)
    model = model.to(device)
    N_EPOCHS = 10
    LR = 0.02

    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()

    # Training loop
    train_start_time = time.time()
    for epoch in tqdm(range(N_EPOCHS)):
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        print(f"Epoch {epoch+1}/{N_EPOCHS} | loss: {train_loss:.2f}")

    train_duration = time.time() - train_start_time
    print(f"Training time is: {train_duration:.2f} sec")

    # Testing loop
    test_start_time = time.time()
    test_loss, accuracy = _evaluate(model, test_loader, criterion, device)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {accuracy*100:.2f}%")

    test_duration = time.time() - test_start_time
    print(f"Test time is: {test_duration:.2f} sec")

In [None]:
# Entry point: run the full train/evaluate pipeline only when this code is
# executed as the top-level program, not when it is imported.
if __name__ == "__main__":
    main()

Using device:  cpu 


 10%|█         | 1/10 [00:16<02:26, 16.31s/it]

Epoch 1/10 | loss: 1.73


 20%|██        | 2/10 [00:32<02:07, 15.96s/it]

Epoch 2/10 | loss: 1.07


 30%|███       | 3/10 [00:47<01:50, 15.83s/it]

Epoch 3/10 | loss: 0.87


 40%|████      | 4/10 [01:03<01:35, 15.84s/it]

Epoch 4/10 | loss: 0.78


 50%|█████     | 5/10 [01:19<01:19, 15.94s/it]

Epoch 5/10 | loss: 0.71


 60%|██████    | 6/10 [01:35<01:03, 15.88s/it]

Epoch 6/10 | loss: 0.68


 70%|███████   | 7/10 [01:51<00:47, 15.89s/it]

Epoch 7/10 | loss: 0.65


 80%|████████  | 8/10 [02:07<00:32, 16.04s/it]

Epoch 8/10 | loss: 0.63


 90%|█████████ | 9/10 [02:23<00:15, 15.88s/it]

Epoch 9/10 | loss: 0.62


100%|██████████| 10/10 [02:38<00:00, 15.87s/it]


Epoch 10/10 | loss: 0.58
Training time is: 158.73 sec


Testing: 100%|██████████| 313/313 [00:02<00:00, 143.53it/s]

Test loss: 0.578677
Test accuracy: 81.32%
Test time is: 2.19 sec



