In [4]:
!pip install performer-pytorch
!pip install torch
!pip install numpy
!pip install tqdm

Collecting performer-pytorch
  Downloading performer_pytorch-1.1.4-py3-none-any.whl.metadata (763 bytes)
Collecting local-attention>=1.1.1 (from performer-pytorch)
  Downloading local_attention-1.9.15-py3-none-any.whl.metadata (683 bytes)
Collecting axial-positional-embedding>=0.1.0 (from performer-pytorch)
  Downloading axial_positional_embedding-0.2.1.tar.gz (2.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading performer_pytorch-1.1.4-py3-none-any.whl (13 kB)
Downloading local_attention-1.9.15-py3-none-any.whl (9.0 kB)
Building wheels for collected packages: axial-positional-embedding
  Building wheel for axial-positional-embedding (setup.py) ... [?25l[?25hdone
  Created wheel for axial-positional-embedding: filename=axial_positional_embedding-0.2.1-py3-none-any.whl size=2887 sha256=0526021ccd041d0e8ec84eaad899f8c2caefabedd0c30303839019a6a366712e
  Stored in directory: /root/.cache/pip/wheels/b1/cb/39/7ce7ff2d2fd37cfe1fe7b3a3c43cf410632b2ad3b3f3986d73
Successful

In [5]:
def performer_exponential_kernel(data, is_query=True, normalize=False, eps=1e-6):
    """
    Exponential kernel for Performer attention mechanism.

    Args:
        data: Input tensor
        is_query: Boolean indicating if input is query (True) or key (False)
        normalize: Whether to normalize the output
        eps: Small constant for numerical stability
    """
    data_norm = torch.norm(data, p=2, dim=-1, keepdim=True)
    data_normalized = data / (data_norm + eps)

    if normalize:
        return data_normalized

    return torch.exp(-data_norm) * data_normalized if is_query else torch.exp(data_norm) * data_normalized

In [None]:
import sys
import os
sys.path.append('/content')
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import tqdm
from performers_pytorch import PerformerLM
from autoregressive_wrapper import AutoregressiveWrapper

# Define the exponential kernel function
def performer_exponential_kernel(data, is_query=True, normalize=False, eps=1e-6):
    """
    Exponential kernel for Performer attention mechanism.

    Args:
        data: Input tensor
        is_query: Boolean indicating if input is query (True) or key (False)
        normalize: Whether to normalize the output
        eps: Small constant for numerical stability
    """
    data_norm = torch.norm(data, p=2, dim=-1, keepdim=True)
    data_normalized = data / (data_norm + eps)

    if normalize:
        return data_normalized

    return torch.exp(-data_norm) * data_normalized if is_query else torch.exp(data_norm) * data_normalized

# Constants
BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 3e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load and preprocess MNIST
data = np.load('mnist.npz')
x_train = torch.from_numpy(data['x_train']).float()
y_train = torch.from_numpy(data['y_train']).long()
x_test = torch.from_numpy(data['x_test']).float()
y_test = torch.from_numpy(data['y_test']).long()

# Normalize and reshape
x_train = x_train / 255.0
x_test = x_test / 255.0
x_train = x_train.view(-1, 784)  # Flatten 28x28 to 784
x_test = x_test.view(-1, 784)

# Create dataloaders
train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = torch.utils.data.TensorDataset(x_test, y_test)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Initialize model
model = PerformerLM(
    num_tokens=256,  # Number of unique tokens (pixel values)
    dim=256,
    depth=4,
    max_seq_len=784,  # MNIST flattened size
    heads=4,
    causal=False,
    reversible=True,
    use_scalenorm=True,
    generalized_attention=True,
    kernel_fn=performer_exponential_kernel,  # Use the defined kernel function
    local_attn_heads=(4, 4, 2, 2),
    no_projection=True  # Disable projection if not needed
).to(DEVICE)

# Add classification head
classifier = nn.Sequential(
    nn.Linear(256, 10),  # 10 classes for MNIST
    nn.LogSoftmax(dim=1)
).to(DEVICE)

# Optimizer and loss
optimizer = torch.optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(EPOCHS):  # Start from 0
    model.train()
    classifier.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    # Training
    for batch_idx, (data, target) in enumerate(tqdm.tqdm(train_loader, desc=f'Epoch {epoch+1} Training')):
        data, target = data.to(DEVICE), target.to(DEVICE)

        # Convert to indices for Performer
        data = (data * 255).clamp(0, 255).long()  # Ensure data is in the correct range

        # Forward pass
        features = model(data)
        features = features.mean(dim=1)  # Global average pooling
        output = classifier(features)

        loss = criterion(output, target)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        pred = output.argmax(dim=1)
        train_correct += pred.eq(target).sum().item()
        train_total += target.size(0)
        train_loss += loss.item()

    # Validation
    model.eval()
    classifier.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for data, target in tqdm.tqdm(val_loader, desc='Validation'):
            data, target = data.to(DEVICE), target.to(DEVICE)
            data = (data * 255).clamp(0, 255).long()  # Ensure data is in the correct range

            features = model(data)
            features = features.mean(dim=1)
            output = classifier(features)

            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            val_correct += pred.eq(target).sum().item()
            val_total += target.size(0)

    # Print metrics
    train_loss /= len(train_loader)
    train_acc = train_correct / train_total
    val_loss /= len(val_loader)
    val_acc = val_correct / val_total

    print(f'Epoch: {epoch+1}')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}\n')

Epoch 1 Training: 100%|██████████| 469/469 [19:59<00:00,  2.56s/it]
Validation: 100%|██████████| 79/79 [00:52<00:00,  1.50it/s]


Epoch: 1
Train Loss: 1.9722, Train Acc: 0.2536
Val Loss: 1.7660, Val Acc: 0.3025



Epoch 2 Training: 100%|██████████| 469/469 [20:06<00:00,  2.57s/it]
Validation: 100%|██████████| 79/79 [00:52<00:00,  1.50it/s]


Epoch: 2
Train Loss: 1.4407, Train Acc: 0.4632
Val Loss: 1.0084, Val Acc: 0.6402



Epoch 3 Training: 100%|██████████| 469/469 [20:07<00:00,  2.57s/it]
Validation: 100%|██████████| 79/79 [00:52<00:00,  1.50it/s]


Epoch: 3
Train Loss: 0.8738, Train Acc: 0.6902
Val Loss: 0.7575, Val Acc: 0.7262



Epoch 4 Training: 100%|██████████| 469/469 [20:08<00:00,  2.58s/it]
Validation: 100%|██████████| 79/79 [00:52<00:00,  1.50it/s]


Epoch: 4
Train Loss: 0.6615, Train Acc: 0.7702
Val Loss: 0.5298, Val Acc: 0.8235



Epoch 5 Training: 100%|██████████| 469/469 [20:08<00:00,  2.58s/it]
Validation: 100%|██████████| 79/79 [00:52<00:00,  1.49it/s]


Epoch: 5
Train Loss: 0.4927, Train Acc: 0.8414
Val Loss: 0.4200, Val Acc: 0.8712



Epoch 6 Training: 100%|██████████| 469/469 [20:10<00:00,  2.58s/it]
Validation: 100%|██████████| 79/79 [00:52<00:00,  1.49it/s]


Epoch: 6
Train Loss: 0.4049, Train Acc: 0.8759
Val Loss: 0.3351, Val Acc: 0.9001



Epoch 7 Training: 100%|██████████| 469/469 [20:09<00:00,  2.58s/it]
Validation: 100%|██████████| 79/79 [00:53<00:00,  1.49it/s]


Epoch: 7
Train Loss: 0.3431, Train Acc: 0.8982
Val Loss: 0.2930, Val Acc: 0.9157



Epoch 8 Training: 100%|██████████| 469/469 [20:09<00:00,  2.58s/it]
Validation: 100%|██████████| 79/79 [00:52<00:00,  1.49it/s]


Epoch: 8
Train Loss: 0.3024, Train Acc: 0.9103
Val Loss: 0.2970, Val Acc: 0.9100



Epoch 9 Training: 100%|██████████| 469/469 [20:04<00:00,  2.57s/it]
Validation: 100%|██████████| 79/79 [00:52<00:00,  1.51it/s]


Epoch: 9
Train Loss: 0.2697, Train Acc: 0.9199
Val Loss: 0.2505, Val Acc: 0.9260



Epoch 10 Training:  28%|██▊       | 129/469 [05:30<14:29,  2.56s/it]

In [None]:
L_loss = [1.9722, 1.4407, 0.8738, 0.6615, 0.4927, 0.4049, 0.3431, 0.3024, 0.2697, 0.2431]
L_accuracy = [0.2536, 0.4632, 0.6902, 0.7702, 0.8414, 0.8759, 0.8982, 0.9103, 0.9199, 0.9267]
L_val_loss = [1.7660, 1.0084, 0.7575, 0.5298, 0.4200, 0.3351, 0.2930, 0.2970, 0.2505, 0.2628]
L_val_accuracy = [0.3025, 0.6402, 0.7262, 0.8235, 0.8712, 0.9001, 0.9157, 0.9100, 0.9260, 0.9394]

In [None]:
# Assuming you have access to the FastAttention class
from performers_pytorch import FastAttention

# Create an instance of FastAttention for debugging
fast_attention_instance = FastAttention(
    dim_heads=4,
    nb_features=128,  # Adjust as necessary
    ortho_scaling=0,
    causal=False,
    generalized_attention=False,
    kernel_fn=performer_exponential_kernel,
    no_projection=True
)

# Print available methods and attributes
print("Available methods and attributes in FastAttention:")
print(dir(fast_attention_instance))

# Check specifically for the redraw_projection_matrix method
if 'redraw_projection_matrix' in dir(fast_attention_instance):
    print("redraw_projection_matrix method is available.")
else:
    print("redraw_projection_matrix method is NOT available.")

Available methods and attributes in FastAttention:
['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set'