In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn, optim

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from datetime import datetime
from tqdm import tqdm
import time
import numpy as np

# # Mount Google Drive if needed
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# ========== Helper Functions ==========

def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``.

    Lets callers pass either a single size or an explicit pair of sizes.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

# Preprocessing applied to every image: convert to a tensor, then normalize
# with mean 0.5 and std 0.5.
transform = transforms.Compose(transforms=[
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

def convert_to_superclass(targets):
    """Map CIFAR100 fine labels (0-99) to their 20 superclass labels.

    ``targets`` is used directly as an index into a 100-entry NumPy lookup
    table, so it may be a single integer label or a list/array of integer
    labels. The result of that indexing is returned.
    """
    # Row r, column c holds the superclass id of fine class 10 * r + c.
    lookup_rows = [
        [ 4,  1, 14,  8,  0,  6,  7,  7, 18,  3],   # fine classes  0- 9
        [ 3, 14,  9, 18,  7, 11,  3,  9,  7, 11],   # fine classes 10-19
        [ 6, 11,  5, 10,  7,  6, 13, 15,  3, 15],   # fine classes 20-29
        [ 0, 11,  1, 10, 12, 14, 16,  9, 11,  5],   # fine classes 30-39
        [ 5, 19,  8,  8, 15, 13, 14, 17, 18, 10],   # fine classes 40-49
        [16,  4, 17,  4,  2,  0, 17,  4, 18, 17],   # fine classes 50-59
        [10,  3,  2, 12, 12, 16, 12,  1,  9, 19],   # fine classes 60-69
        [ 2, 10,  0,  1, 16, 12,  9, 13, 15, 13],   # fine classes 70-79
        [16, 19,  2,  4,  6, 19,  5,  5,  8, 19],   # fine classes 80-89
        [18,  1,  2, 15,  6,  0, 17,  8, 14, 13],   # fine classes 90-99
    ]
    fine_to_coarse = np.array(lookup_rows).ravel()
    return fine_to_coarse[targets]

# Device configuration: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# ========== Main ViT Logic ==========
# This code is modified from regular ViT implementation: https://github.com/lucidrains/vit-pytorch

class FeedForward(nn.Module):
    """MLP block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: size of the input and output feature dimension.
        hidden_dim: size of the intermediate (expanded) dimension.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),   # expand to the hidden width
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),   # project back to the model width
            nn.Dropout(dropout),
        ]
        # Kept as a positional Sequential named `net` so parameter names stay `net.<index>.*`.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result (no residual is added here)."""
        return self.net(x)

class Attention(nn.Module):
    """Multi-head softmax self-attention with LayerNorm applied to the input.

    Args:
        dim: feature dimension of the input and output tokens.
        heads: number of attention heads.
        dim_head: feature dimension of each head.
        dropout: dropout probability applied to the attention weights and,
            when an output projection exists, after that projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # One projection produces queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # The output projection is skipped only for a single head whose width equals `dim`.
        if heads != 1 or dim_head != dim:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Return the attention output for ``x``; the residual is added by the caller."""
        normed = self.norm(x)

        # Split the joint projection into q, k, v and move the head axis forward.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        merged = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class PerformerAttention(nn.Module):
    """Linear (kernelized) attention modified from the regular ``Attention``.

    Instead of ``softmax(q k^T) v`` this layer applies an element-wise feature
    kernel to the queries and keys and computes
    ``kernel(q) @ (kernel(k)^T @ v) * scale``. No softmax and no normalizing
    denominator are applied.

    Args:
        dim: feature dimension of the input and output tokens.
        heads: number of attention heads.
        dim_head: feature dimension of each head.
        dropout: dropout probability applied after the output projection
            (this layer has no dropout on attention weights).
        model_type: selects the feature kernel:
            ``'performer_relu'``  -> ``relu(x)``
            ``'performer_exp'``   -> ``exp(x)``
            ``'performer_explu'`` -> ``explu_exp_weight * exp(x) + explu_relu_weight * relu(x)``
        explu_exp_weight: weight of the exp term; required for ``'performer_explu'``.
        explu_relu_weight: weight of the relu term; required for ``'performer_explu'``.

    Raises:
        ValueError: if ``model_type`` is not one of the kernels above, or if
            ``'performer_explu'`` is requested without both explu weights.
    """

    # Maps each supported model_type to the name of the method implementing its kernel.
    _KERNEL_METHODS = {
        'performer_relu': 'kernel_relu',
        'performer_exp': 'kernel_exp',
        'performer_explu': 'kernel_explu',
    }

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., model_type='performer_relu',
            explu_exp_weight=None, explu_relu_weight=None
            ):
        super().__init__()

        # Fail at construction instead of leaving `self.kernel` as None, which
        # would only surface later as "'NoneType' object is not callable" in forward().
        if model_type not in self._KERNEL_METHODS:
            raise ValueError(
                f"Unknown model_type {model_type!r} for PerformerAttention; "
                f"expected one of {sorted(self._KERNEL_METHODS)}"
            )
        if model_type == 'performer_explu' and (explu_exp_weight is None or explu_relu_weight is None):
            raise ValueError(
                "model_type 'performer_explu' requires both explu_exp_weight and explu_relu_weight"
            )

        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.model_type = model_type

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

        # When these are nn.Parameter objects, assigning them registers them on this module.
        self.explu_exp_weight, self.explu_relu_weight = explu_exp_weight, explu_relu_weight

        self.kernel = getattr(self, self._KERNEL_METHODS[model_type])

    def kernel_relu(self, x):
        """ReLU feature kernel."""
        return torch.nn.functional.relu(x)

    def kernel_exp(self, x):
        """Exponential feature kernel."""
        return torch.exp(x)

    def kernel_explu(self, x):
        """Weighted sum of the exponential and ReLU kernels."""
        return self.explu_exp_weight * torch.exp(x) + self.explu_relu_weight * torch.nn.functional.relu(x)

    def forward(self, x):
        """Return the kernelized attention output for ``x``; the residual is added by the caller."""
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q = self.kernel(q)
        k = self.kernel(k)
        # Multiply k^T v first so the (n x n) attention matrix is never formed.
        out = torch.matmul(k.transpose(-1, -2), v)
        out = torch.matmul(q, out) * self.scale

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    """Stack of ``depth`` (attention, feed-forward) layers with residual connections.

    ``model_type == 'regular_transformer'`` selects the softmax ``Attention``;
    any other value selects ``PerformerAttention`` and is forwarded to it along
    with the explu weights. A final LayerNorm is applied to the output.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., model_type='regular_transformer',
            explu_exp_weight=None, explu_relu_weight=None
            ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])

        use_regular_attention = model_type == 'regular_transformer'

        for _ in range(depth):
            if use_regular_attention:
                # Regular Attention
                attention = Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
            else:
                # Performer Attention
                attention = PerformerAttention(
                    dim, heads = heads, dim_head = dim_head, dropout = dropout, model_type=model_type,
                    explu_exp_weight=explu_exp_weight, explu_relu_weight=explu_relu_weight)
            feed_forward = FeedForward(dim, mlp_dim, dropout = dropout)
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        """Run ``x`` through every layer, adding a residual around each sub-block, then normalize."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x

        return self.norm(x)


class ViT(nn.Module):
    """Vision Transformer classifier with a selectable attention variant.

    Keyword-only args:
        image_size: image side length, or a (height, width) tuple.
        patch_size: patch side length, or a (height, width) tuple; must divide
            the image size.
        num_classes: number of output logits.
        dim: token embedding dimension.
        depth: number of transformer layers.
        heads: number of attention heads.
        mlp_dim: hidden dimension of each feed-forward block.
        pool: ``'cls'`` to classify from the class token, ``'mean'`` to average
            all tokens.
        channels: number of image channels.
        dim_head: feature dimension of each attention head.
        dropout: dropout probability inside the transformer.
        emb_dropout: dropout probability applied to the embedded tokens.
        model_type: one of ``'regular_transformer'``, ``'performer_relu'``,
            ``'performer_exp'``, ``'performer_explu'``.
        explu_exp_weight_init, explu_relu_weight_init: initial values of the
            two learnable scalar kernel weights, which are created only for
            ``'performer_explu'``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported values.
    """

    def __init__(
            self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls',
            channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.,
            model_type='regular_transformer', explu_exp_weight_init=-0.2, explu_relu_weight_init=0.8):
        super().__init__()

        # The default used to be the misspelled 'regurlar_transformer', which fell through to
        # the performer branch without a kernel and crashed on the first forward pass.
        known_model_types = {'regular_transformer', 'performer_relu', 'performer_exp', 'performer_explu'}
        if model_type not in known_model_types:
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of {sorted(known_model_types)}"
            )

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Cut the image into flattened patches and embed each one into `dim` features.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One learned position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # Learnable kernel weights; the same two parameters are handed to every attention layer.
        self.explu_exp_weight = nn.Parameter(torch.tensor(explu_exp_weight_init)) if model_type == 'performer_explu' else None
        self.explu_relu_weight = nn.Parameter(torch.tensor(explu_relu_weight_init)) if model_type == 'performer_explu' else None

        self.transformer = Transformer(
            dim, depth, heads, dim_head, mlp_dim, dropout, model_type,
            explu_exp_weight=self.explu_exp_weight, explu_relu_weight=self.explu_relu_weight
            )

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """Return class logits for the image batch ``img``."""
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # Prepend the class token to every sequence, then add position embeddings.
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [None]:
# ========== Train and Test ==========

def train_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Returns:
        (avg_loss, accuracy): the mean of the per-batch losses, and the number
        of correct predictions divided by ``len(loader.dataset)``.
    """
    model.train()
    progress = tqdm(loader)
    running_loss = 0
    num_correct = 0

    for step, (images, labels) in enumerate(progress, start=1):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        logits = model(images)
        loss = criterion(logits, labels)
        running_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy: the predicted class is the index of the largest logit.
        predictions = logits.max(dim=1).indices
        num_correct += (predictions == labels).sum().item()

        # Show the running mean of the batch losses on the progress bar.
        progress.set_description(f"Loss : {running_loss/step:.4f}")

    return running_loss / len(loader), num_correct / len(loader.dataset)

def test_epoch(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Accuracy
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()

    avg_loss = total_loss / len(loader)
    accuracy = correct / len(loader.dataset)
    return avg_loss, accuracy


In [None]:
# ========== Experiment Logic ==========

def run_experitment(model, train_loader, test_loader, optimizer, criterion, device, run_name, epochs_to_run, checkpoint):
    """Train and evaluate ``model`` for ``epochs_to_run`` epochs, checkpointing along the way.

    Args:
        model, optimizer, criterion: objects passed to ``train_epoch`` / ``test_epoch``.
        train_loader, test_loader: data loaders for the two phases.
        device: device the batches are moved to (a string or ``torch.device``).
        run_name: suffix of the checkpoint file ``./checkpoint_<run_name>``.
        epochs_to_run: number of epochs to run in this call, counted on top of
            the epochs already recorded in ``checkpoint``.
        checkpoint: ``None`` to start fresh, or a dict with the keys
            ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'stat'``
            (as written by this function) to resume from.

    Per-epoch timings, losses and accuracies are printed and stored in the
    checkpoint under ``'stat'``, keyed by epoch number. Returns ``None``.
    """

    epoch_stat, current_epoch = {}, 0
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch_stat = checkpoint['stat']
        # The last recorded epoch is where we resume; empty stats mean nothing was recorded yet.
        current_epoch = list(epoch_stat.keys())[-1] if epoch_stat else 0
        print(f"Loaded checkpoint at epoch {current_epoch}")

    # A torch.device does not compare equal to the string 'cuda', so the previous
    # `device == 'cuda'` check never synchronized. Inspect the device type instead.
    on_cuda = torch.device(device).type == 'cuda'
    last_epoch = current_epoch + epochs_to_run

    # Training process
    for epoch in range(current_epoch + 1, last_epoch + 1):
        print(f"Epoch {epoch}/{last_epoch}")

        # Training
        train_start_time = time.time()
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        if on_cuda: torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        train_time = time.time() - train_start_time

        # Testing
        test_start_time = time.time()
        test_loss, test_acc = test_epoch(model, test_loader, criterion, device)
        if on_cuda: torch.cuda.synchronize()
        test_time = time.time() - test_start_time

        # Display (times use `.2f`; the previous `2f` was a width, not a precision)
        print(f"Train Time: {train_time:.2f}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
        print(f"Test Time: {test_time:.2f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")
        print()

        # Update epoch_stat
        epoch_stat[epoch] = {
            'train_time': train_time, 'train_loss': train_loss, 'train_acc': train_acc,
            'test_time': test_time, 'test_loss': test_loss, 'test_acc': test_acc
        }

        # Write to checkpoint every 5th epoch and after the final epoch of this call
        if (epoch % 5 == 0) or (epoch == last_epoch):
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'stat': epoch_stat,
            }, "./checkpoint_" + run_name)


In [None]:
# ========== Experiment Parameters ==========

# Number of epochs to train in this run (added on top of any loaded checkpoint).
epochs_to_run = 100

# Attention variant to train.
# Possible values: 'regular_transformer', 'performer_relu', 'performer_exp', 'performer_explu'
model_type = 'performer_explu'

# Dataset to train on.
# Possible values: 'MNIST', 'CIFAR10', 'CIFAR100'
dataset_name = 'MNIST'

# Identifies this run; used as the suffix of the checkpoint file name.
run_name = f"{model_type}_{dataset_name}"

# None starts training from scratch.
checkpoint = None
# # Load checkpoint if needed
# checkpoint = torch.load("./checkpoint_" + run_name, weights_only=True)

In [None]:
# ========== Load Dataset, based on Experiment Param ==========

train_dataset = test_dataset = train_loader = test_loader = None

if dataset_name in ('MNIST', 'CIFAR10', 'CIFAR100'):
    # Each supported name is also the name of its torchvision dataset class.
    train_dataset = getattr(datasets, dataset_name)(root='./data', train=True, transform=transform, download=True)
    test_dataset = getattr(datasets, dataset_name)(root='./data', train=False, transform=transform, download=True)

    if dataset_name == 'CIFAR100':
        # Train on the 20 superclasses instead of the 100 fine classes.
        train_dataset.targets = convert_to_superclass(train_dataset.targets)
        test_dataset.targets = convert_to_superclass(test_dataset.targets)

    # Only the training data is shuffled.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [None]:
# ========== Load Model, based on Experiment Param ==========

# Seed before construction so the initial weights are reproducible.
torch.manual_seed(0)
model = None

# Every configuration uses 4x4 patches, 6 layers, a 256-wide MLP and 0.1 dropout;
# only the input geometry, width, head count and class count differ.
if dataset_name == 'MNIST':
    model = ViT(
        model_type=model_type,
        image_size=28, channels=1, patch_size=4,
        dim=8, dim_head=8, heads=8,
        depth=6, mlp_dim=256, num_classes=10,
        dropout=0.1, emb_dropout=0.1,
    ).to(device)
elif dataset_name == 'CIFAR10':
    model = ViT(
        model_type=model_type,
        image_size=32, channels=3, patch_size=4,
        dim=32, dim_head=16, heads=16,
        depth=6, mlp_dim=256, num_classes=10,
        dropout=0.1, emb_dropout=0.1,
    ).to(device)
elif dataset_name == 'CIFAR100':
    # 20 outputs: CIFAR100 labels are converted to superclasses.
    model = ViT(
        model_type=model_type,
        image_size=32, channels=3, patch_size=4,
        dim=32, dim_head=16, heads=32,
        depth=6, mlp_dim=256, num_classes=20,
        dropout=0.1, emb_dropout=0.1,
    ).to(device)

# Load Optimizer and Loss Function
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# ========== Run experiment on selected model and dataset ==========
# Experiment results are also exported to checkpoints.

run_experitment(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    run_name=run_name,
    epochs_to_run=epochs_to_run,
    checkpoint=checkpoint,
)

# The following cell measures inference time.

In [None]:
# ========== Test Inference Time ==========

dataset_name = 'CIFAR100'
# Possible values: 'MNIST', 'CIFAR10', 'CIFAR100'

test_input, test_model, inference_time = None, {}, {}
model_type_ls = ['regular_transformer', 'performer_relu', 'performer_exp', 'performer_explu']

# Per-dataset ViT hyper-parameters and the shape of the single dummy image.
if dataset_name in ['MNIST']:
    vit_kwargs = dict(
        image_size=28, channels=1, num_classes=10,
        patch_size=4, dim=8, dim_head=8,
        depth=6, heads=8, mlp_dim=256)
    input_shape = (1, 1, 28, 28)
elif dataset_name in ['CIFAR10']:
    vit_kwargs = dict(
        image_size=32, channels=3, num_classes=10,
        patch_size=4, dim=32, dim_head=16,
        depth=6, heads=16, mlp_dim=256)
    input_shape = (1, 3, 32, 32)
elif dataset_name in ['CIFAR100']:
    vit_kwargs = dict(
        image_size=32, channels=3, num_classes=20,
        patch_size=4, dim=32, dim_head=16,
        depth=6, heads=32, mlp_dim=256)
    input_shape = (1, 3, 32, 32)
else:
    raise ValueError(f"Unknown dataset_name {dataset_name!r}")

# Create models and dummy image.
# The loop variable is deliberately not called `model_type`, so the experiment's
# `model_type` (read again by a later cell) is not overwritten.
for candidate_type in model_type_ls:
    # eval() disables dropout: inference must not be timed in training mode.
    test_model[candidate_type] = ViT(
        dropout=0.1, emb_dropout=0.1, model_type=candidate_type, **vit_kwargs
    ).to('cpu').eval()
    inference_time[candidate_type] = []
test_input = torch.randn(*input_shape).to('cpu')

num_iterations = 40000

# no_grad() keeps autograd bookkeeping out of the measured time.
with torch.no_grad():
    # Warm-up
    for i in range(100):
        this_model_type = model_type_ls[i % len(model_type_ls)]
        _ = test_model[this_model_type](test_input)

    # Measure Inference Time, interleaving the models so drift affects all of them equally
    for i in range(num_iterations):
        this_model_type = model_type_ls[i % len(model_type_ls)]
        # perf_counter() is a monotonic, high-resolution clock meant for short intervals.
        start_time = time.perf_counter()
        _ = test_model[this_model_type](test_input)
        end_time = time.perf_counter()
        inference_time[this_model_type].append(end_time - start_time)

# Display Result
for candidate_type in model_type_ls:
    print(f"{candidate_type}: {np.mean(inference_time[candidate_type]):.6f} seconds")

# The following cells are for checking number of parameters and parameter values.

In [None]:
# ========== Show number of Parameters ==========

# Count the trainable parameters of `model`.
# numel() is used instead of np.prod(p.size()): for a 0-dim parameter (such as the
# explu kernel weights) the shape is empty and np.prod returns the float 1.0,
# which turned the printed total into a float.
model_parameters = [p for p in model.parameters() if p.requires_grad]
params = sum(p.numel() for p in model_parameters)
print(params)

In [None]:
# ========== Check Learnable Param value for Kernel Feature ==========

# Print the current values of the learnable kernel weights of a 'performer_explu' model.
if 'performer_explu' in model_type:
    for name, param in model.named_parameters():
        if 'explu_exp_weight' in name or 'explu_relu_weight' in name:
            print(name, param.data)