In [None]:
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import time
from tqdm import tqdm

In [None]:
# Dataset wrapper: applies the preprocessing transform and (optionally) flattens each image
class CustomCIFAR10Dataset(Dataset):
    """Map-style wrapper around an indexable dataset of ``(image, label)`` pairs.

    Each item is fetched from the wrapped dataset, passed through ``transform``
    (when one is given) and, if ``flatten`` is set, collapsed to a 1-D tensor.

    Args:
        cifar_dataset: The original CIFAR-10 dataset (loaded with
            torchvision.datasets); anything supporting ``len()`` and integer
            indexing that yields ``(image, label)`` works.
        transform: Callable applied to every image, or ``None`` to leave the
            image untouched.
        flatten: Whether to flatten the (transformed) image to 1-D. This
            requires the image to be a tensor at that point, so it is only
            meaningful when ``transform`` produces tensors (or the wrapped
            dataset already yields tensors).
    """

    def __init__(self, cifar_dataset, transform=None, flatten=False):
        self.cifar_dataset = cifar_dataset
        self.transform = transform
        self.flatten = flatten

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.cifar_dataset)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx`` after preprocessing."""
        image, label = self.cifar_dataset[idx]
        # Explicit None check: a callable that happens to be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)
        if self.flatten:
            # reshape() instead of view(): view() raises on non-contiguous
            # tensors (e.g. after a transpose/permute inside the transform),
            # whereas reshape() copies only when it has to.
            image = image.reshape(-1)  # e.g. [1, 32, 32] -> [1024]
        return image, label


# Preprocessing pipeline, applied to each image in the order listed
transform = transforms.Compose(
    [
        # Collapse the colour channels into a single grayscale channel
        transforms.Grayscale(num_output_channels=1),
        # Convert the image to a tensor
        transforms.ToTensor(),
        # (x - 0.5) / 0.5, i.e. normalize to [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)


In [None]:
# Load the raw CIFAR-10 splits (train first, then test), downloading into
# ./data when needed. No transform is attached here; preprocessing is applied
# later by the dataset wrapper.
original_cifar_train, original_cifar_test = (
    CIFAR10(root='./data', train=use_train_split, download=True)
    for use_train_split in (True, False)
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 49.0MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Training split: preprocessed + flattened images, reshuffled on every pass
train_dataset = CustomCIFAR10Dataset(
    cifar_dataset=original_cifar_train,
    transform=transform,
    flatten=True,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

# Test split: same preprocessing, served in a fixed order
test_dataset = CustomCIFAR10Dataset(
    cifar_dataset=original_cifar_test,
    transform=transform,
    flatten=True,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [None]:
def segment_means(input_seq, m):
    """Compute landmarks as the means of ``m`` equal-length contiguous segments.

    The sequence axis is cut into ``m`` consecutive segments of
    ``seq_len // m`` elements each and every segment is averaged. When
    ``seq_len`` is not a multiple of ``m`` the trailing
    ``seq_len - m * (seq_len // m)`` elements are ignored.

    Args:
        input_seq: 2-D floating-point tensor of shape [batch_size, seq_len].
        m: Number of segments (landmarks) to produce per sequence.

    Returns:
        Tensor of shape [batch_size, m] holding the per-segment means.

    Raises:
        AssertionError: If ``m`` exceeds ``seq_len`` (segments would be empty).
        ValueError: If ``input_seq`` is not 2-D (raised by the shape unpacking).
    """
    batch_size, seq_len = input_seq.shape
    segment_size = seq_len // m
    assert segment_size > 0, f"Segment size must be greater than 0. seq_len={seq_len}, m={m}"

    # One slice + reshape + mean replaces the former batch_size * m Python-level
    # slices and torch.stack calls; the elements averaged per segment are the same.
    usable = input_seq[:, :m * segment_size]
    return usable.reshape(batch_size, m, segment_size).mean(dim=-1)

def nystrom_attention(Q, K, V, num_landmarks):
    """Approximate attention using the Nyström method.

    Every one of the ``seq_len`` entries of a row is treated as a token with a
    single scalar feature (the inputs are unsqueezed to [batch, seq_len, 1]
    before the matrix products). ``Q`` and ``K`` are first L2-normalised along
    the sequence axis, landmarks are obtained with ``segment_means``, and the
    output is ``F @ pinv(A) @ B @ V`` with

    * ``A`` = softmax(Q_landmarks K_landmarks^T / scale)   [batch, m, m]
    * ``F`` = softmax(Q K_landmarks^T / scale)             [batch, seq_len, m]
    * ``B`` = softmax(Q_landmarks K^T / scale)             [batch, m, seq_len]

    The product is evaluated right-to-left, so the largest intermediate is
    [batch, seq_len, m]; the [batch, seq_len, seq_len] attention matrix is
    never materialised.

    NOTE(review): ``scale`` is sqrt(num_landmarks), not the sqrt(feature_dim)
    of standard scaled dot-product attention (feature_dim is 1 here). Kept
    as-is to preserve the existing numerical behaviour; confirm the intent.

    Args:
        Q: Query tensor of shape [batch_size, seq_len].
        K: Key tensor of shape [batch_size, seq_len].
        V: Value tensor of shape [batch_size, seq_len].
        num_landmarks: Number of landmarks ``m``; must not exceed ``seq_len``.

    Returns:
        Tensor of shape [batch_size, seq_len].

    Raises:
        ValueError: If ``Q`` is not 2-D.
    """
    if Q.dim() != 2:
        raise ValueError(
            f"Q must have shape [batch_size, seq_len], got {tuple(Q.shape)}"
        )

    # Normalize Q and K for numerical stability
    Q = Q / (Q.norm(dim=-1, keepdim=True) + 1e-6)
    K = K / (K.norm(dim=-1, keepdim=True) + 1e-6)

    # Step 1: landmarks, with a trailing feature axis for batched matmul
    k_landmarks = segment_means(K, num_landmarks).unsqueeze(-1)  # [batch, m, 1]
    q_landmarks = segment_means(Q, num_landmarks).unsqueeze(-1)  # [batch, m, 1]
    q_tokens = Q.unsqueeze(-1)  # [batch, seq_len, 1]
    k_tokens = K.unsqueeze(-1)  # [batch, seq_len, 1]
    v_tokens = V.unsqueeze(-1)  # [batch, seq_len, 1]

    # Step 2: the three softmax kernels of the Nyström factorisation
    scale = q_landmarks.size(-2) ** 0.5
    landmark_kernel = torch.softmax(
        (q_landmarks @ k_landmarks.transpose(-2, -1)) / scale, dim=-1
    )  # A: [batch, m, m]
    token_to_landmark = torch.softmax(
        (q_tokens @ k_landmarks.transpose(-2, -1)) / scale, dim=-1
    )  # F: [batch, seq_len, m]
    landmark_to_token = torch.softmax(
        (q_landmarks @ k_tokens.transpose(-2, -1)) / scale, dim=-1
    )  # B: [batch, m, seq_len]

    # Step 3: regularise the landmark kernel, then pseudo-invert it
    epsilon = 1e-6
    identity = torch.eye(
        landmark_kernel.size(-1),
        device=landmark_kernel.device,
        dtype=landmark_kernel.dtype,
    ).unsqueeze(0)
    landmark_kernel = landmark_kernel + epsilon * identity  # out-of-place on purpose
    landmark_kernel_pinv = torch.linalg.pinv(landmark_kernel)  # [batch, m, m]

    # Step 4: F @ (pinv(A) @ (B @ V)). Multiplying right-to-left keeps every
    # intermediate at m rows or m columns instead of building seq_len x seq_len.
    landmark_values = torch.bmm(landmark_to_token, v_tokens)  # [batch, m, 1]
    mixed = torch.bmm(landmark_kernel_pinv, landmark_values)  # [batch, m, 1]
    output = torch.bmm(token_to_landmark, mixed).squeeze(-1)  # [batch, seq_len]

    return output

class NystromformerLayer(nn.Module):
    """Self-attention block: linear Q/K/V projections, Nyström attention, output projection.

    Args:
        embed_dim: Size of the last input axis; all four projections map
            ``embed_dim -> embed_dim``.
        num_heads: Stored on the instance only; ``forward`` does not use it.
        num_landmarks: Number of landmarks handed to ``nystrom_attention``.
    """

    def __init__(self, embed_dim, num_heads, num_landmarks):
        super().__init__()
        self.embed_dim, self.num_heads, self.num_landmarks = (
            embed_dim,
            num_heads,
            num_landmarks,
        )

        # Created and registered in the order query, key, value, out so the
        # module layout and parameter initialisation order stay the same.
        for projection_name in ("query", "key", "value", "out"):
            setattr(self, projection_name, nn.Linear(embed_dim, embed_dim))

    def forward(self, x):
        """Project ``x`` to Q/K/V, attend with Nyström attention, project back."""
        queries, keys, values = self.query(x), self.key(x), self.value(x)
        attended = nystrom_attention(queries, keys, values, self.num_landmarks)
        return self.out(attended)

# Full Model with Classification Head
class NystromformerModel(nn.Module):
    """Nyström attention block followed by a linear classification head.

    Args:
        embed_dim: Size of the input vectors fed to the attention block.
        num_heads: Forwarded unchanged to ``NystromformerLayer``.
        num_landmarks: Forwarded unchanged to ``NystromformerLayer``.
        num_classes: Number of logits produced by the classifier.
    """

    def __init__(self, embed_dim, num_heads, num_landmarks, num_classes):
        super().__init__()
        # Attention block is registered before the head, as the printed
        # module summary shows.
        self.nystromformer = NystromformerLayer(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_landmarks=num_landmarks,
        )
        self.classifier = nn.Linear(in_features=embed_dim, out_features=num_classes)

    def forward(self, x):
        """Return class logits for ``x``."""
        features = self.nystromformer(x)
        logits = self.classifier(features)
        return logits

In [None]:
# Build the classifier for flattened 32x32 grayscale images
model = NystromformerModel(
    embed_dim=1024,     # length of one flattened image vector
    num_heads=4,
    num_landmarks=64,
    num_classes=10,
)

In [None]:
# Compute device: CUDA when one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Cross-entropy on the class logits, optimised with Adam
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Move the model to the chosen device; kept as the cell's final expression so
# the returned module is echoed as the cell output
model.to(device)

NystromformerModel(
  (nystromformer): NystromformerLayer(
    (query): Linear(in_features=1024, out_features=1024, bias=True)
    (key): Linear(in_features=1024, out_features=1024, bias=True)
    (value): Linear(in_features=1024, out_features=1024, bias=True)
    (out): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (classifier): Linear(in_features=1024, out_features=10, bias=True)
)

In [None]:
def train_model(model, train_loader, criterion, optimizer, device, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    For every batch: forward pass, loss, backward pass, gradient-norm clipping
    at 1.0, optimizer step. A tqdm progress bar shows the latest batch loss and
    the running accuracy; after each epoch the wall-clock duration, the mean
    per-batch loss and the accuracy (in percent) are printed.

    The reported loss is the average of the per-batch losses rather than their
    sum, so it does not scale with the number of batches. If the loader yields
    no batches, loss and accuracy are reported as ``nan`` instead of raising
    ``ZeroDivisionError``.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. Progress and metrics are written to stdout.
    """
    model.train()
    for epoch in range(epochs):
        start_time = time.time()  # Track epoch start time
        loss_sum = 0.0
        num_batches = 0
        correct = 0
        total = 0

        print(f"\nEpoch [{epoch + 1}/{epochs}]")
        # Wrapping the loader lets tqdm infer the total and advance on its own.
        with tqdm(train_loader, desc="Training Progress") as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Optimizer step
                optimizer.step()

                # Metrics
                batch_loss = loss.item()
                loss_sum += batch_loss
                num_batches += 1
                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

                # max(..., 1) guards against a batch with zero samples
                pbar.set_postfix({"loss": batch_loss, "accuracy": 100. * correct / max(total, 1)})

        # Epoch timing
        epoch_duration = time.time() - start_time

        # Metrics are undefined when nothing was seen; report nan rather than crash
        mean_loss = loss_sum / num_batches if num_batches else float("nan")
        accuracy = 100. * correct / total if total else float("nan")

        # Print epoch summary
        print(f"Epoch [{epoch + 1}/{epochs}] completed in {epoch_duration:.2f} seconds.")
        print(f"  Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%")

In [None]:
def evaluate_model(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    The model is switched to evaluation mode (and left in it) and all batches
    are processed with gradient tracking disabled. The printed loss is the mean
    of the per-batch losses rather than their sum, so it does not scale with
    the number of batches; accuracy is printed as a fraction in [0, 1]. If the
    loader yields no batches both values are reported as ``nan`` instead of
    raising ``ZeroDivisionError``.

    Args:
        model: Module to evaluate.
        test_loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        None. The metrics are written to stdout.
    """
    model.eval()
    loss_sum = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            loss_sum += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Metrics are undefined when nothing was seen; report nan rather than crash
    mean_loss = loss_sum / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else float("nan")

    print(f"Test Loss: {mean_loss:.4f}, Test Accuracy: {accuracy:.4f}")


In [None]:
# Run the 10-epoch training loop; progress and per-epoch metrics are printed
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=10,
)



Epoch [1/10]


Training Progress: 100%|██████████| 1563/1563 [14:16<00:00,  1.83it/s, loss=2.07, accuracy=15.4]


Epoch [1/10] completed in 856.39 seconds.
  Loss: 3482.2556, Accuracy: 15.43%

Epoch [2/10]


Training Progress: 100%|██████████| 1563/1563 [14:15<00:00,  1.83it/s, loss=2.33, accuracy=15.3]


Epoch [2/10] completed in 855.68 seconds.
  Loss: 3475.1181, Accuracy: 15.33%

Epoch [3/10]


Training Progress: 100%|██████████| 1563/1563 [14:16<00:00,  1.82it/s, loss=2.19, accuracy=15.4]


Epoch [3/10] completed in 856.70 seconds.
  Loss: 3479.2975, Accuracy: 15.38%

Epoch [4/10]


Training Progress: 100%|██████████| 1563/1563 [14:20<00:00,  1.82it/s, loss=2.48, accuracy=15.3]


Epoch [4/10] completed in 860.07 seconds.
  Loss: 3492.9830, Accuracy: 15.33%

Epoch [5/10]


Training Progress: 100%|██████████| 1563/1563 [14:23<00:00,  1.81it/s, loss=2.29, accuracy=15.5]


Epoch [5/10] completed in 863.59 seconds.
  Loss: 3482.4101, Accuracy: 15.50%

Epoch [6/10]


Training Progress: 100%|██████████| 1563/1563 [14:08<00:00,  1.84it/s, loss=2.1, accuracy=15.5]


Epoch [6/10] completed in 848.98 seconds.
  Loss: 3482.8727, Accuracy: 15.55%

Epoch [7/10]


Training Progress: 100%|██████████| 1563/1563 [14:08<00:00,  1.84it/s, loss=2.2, accuracy=15.3]


Epoch [7/10] completed in 848.84 seconds.
  Loss: 3481.3598, Accuracy: 15.26%

Epoch [8/10]


Training Progress: 100%|██████████| 1563/1563 [14:13<00:00,  1.83it/s, loss=2.31, accuracy=15.3]


Epoch [8/10] completed in 853.09 seconds.
  Loss: 3487.9342, Accuracy: 15.28%

Epoch [9/10]


Training Progress: 100%|██████████| 1563/1563 [14:16<00:00,  1.83it/s, loss=2.26, accuracy=15.6]


Epoch [9/10] completed in 856.16 seconds.
  Loss: 3474.2520, Accuracy: 15.62%

Epoch [10/10]


Training Progress: 100%|██████████| 1563/1563 [14:08<00:00,  1.84it/s, loss=2.22, accuracy=15.3]

Epoch [10/10] completed in 848.76 seconds.
  Loss: 3482.2127, Accuracy: 15.33%





In [None]:
# Score the trained model on the held-out test split; metrics are printed
evaluate_model(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    device=device,
)


Test Loss: 687.5670, Test Accuracy: 0.1505
