In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the CIFAR-10 dataset with normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [2]:
def apply_axial_rope(embeddings, height, width, dim, trainable=False, theta_h=None, theta_w=None):
    batch_size_heads = embeddings.size(0)
    seq_len = embeddings.size(1)  # Ensure seq_len is derived from embeddings
    quarter_dim = dim // 4

    x1 = embeddings[..., ::2]  # Even-indexed dimensions
    x2 = embeddings[..., 1::2]  # Odd-indexed dimensions

    if trainable:
        assert theta_h is not None and theta_w is not None, "Trainable theta values must be provided."
    else:
        theta_h = torch.tensor([100 ** (-i / quarter_dim) for i in range(quarter_dim)], device=embeddings.device)
        theta_w = torch.tensor([100 ** (-i / quarter_dim) for i in range(quarter_dim)], device=embeddings.device)

    pos_h = torch.arange(height, device=embeddings.device).unsqueeze(1)
    pos_w = torch.arange(width, device=embeddings.device).unsqueeze(1)

    sin_h = torch.sin(pos_h * theta_h)
    cos_h = torch.cos(pos_h * theta_h)
    sin_w = torch.sin(pos_w * theta_w)
    cos_w = torch.cos(pos_w * theta_w)

    rot_matrix_h = torch.zeros(height, quarter_dim, quarter_dim, device=embeddings.device)
    rot_matrix_w = torch.zeros(width, quarter_dim, quarter_dim, device=embeddings.device)

    for i in range(quarter_dim):
        rot_matrix_h[:, i, i] = cos_h[:, i]
        rot_matrix_h[:, (i + 1) % quarter_dim, i] = -sin_h[:, i]
        rot_matrix_h[:, i, (i + 1) % quarter_dim] = sin_h[:, i]
        rot_matrix_h[:, (i + 1) % quarter_dim, (i + 1) % quarter_dim] = cos_h[:, i]

        rot_matrix_w[:, i, i] = cos_w[:, i]
        rot_matrix_w[:, (i + 1) % quarter_dim, i] = -sin_w[:, i]
        rot_matrix_w[:, i, (i + 1) % quarter_dim] = sin_w[:, i]
        rot_matrix_w[:, (i + 1) % quarter_dim, (i + 1) % quarter_dim] = cos_w[:, i]

    x1_h_rotated = torch.einsum('...q,hij->...ij', x1, rot_matrix_h)
    x1_w_rotated = torch.einsum('...q,wij->...ij', x1, rot_matrix_w)
    x2_h_rotated = torch.einsum('...q,hij->...ij', x2, rot_matrix_h)
    x2_w_rotated = torch.einsum('...q,wij->...ij', x2, rot_matrix_w)

    rotated_embeddings = torch.empty(x1_h_rotated.size()[:-1] + (x1_h_rotated.size(-1) * 2,), device=x1_h_rotated.device)
    rotated_embeddings[..., ::2] = x1_h_rotated + x1_w_rotated
    rotated_embeddings[..., 1::2] = x2_h_rotated + x2_w_rotated

    rotated_embeddings = rotated_embeddings.view(batch_size_heads, seq_len, dim)
    return rotated_embeddings

class AxialRoPE(nn.Module):
    def __init__(self, dim):
        super().__init__()
        quarter_dim = dim // 4
        self.theta_h = nn.Parameter(torch.tensor([100 ** (-i / quarter_dim) for i in range(quarter_dim)]))
        self.theta_w = nn.Parameter(torch.tensor([100 ** (-i / quarter_dim) for i in range(quarter_dim)]))

    def forward(self, embeddings, height, width):
        return apply_axial_rope(embeddings, height, width, dim=embeddings.size(-1), trainable=True, theta_h=self.theta_h, theta_w=self.theta_w)


In [None]:
def apply_mixed_rope(embeddings, height, width, dim, theta_h, theta_w):
    batch_size_heads = embeddings.size(0)
    seq_len = embeddings.size(1)  # Ensure seq_len is derived from embeddings

    quarter_dim = dim // 4

    # rename embeddings to x for consistency with axial rope code
    x = embeddings

    pos_h = torch.arange(height, device=embeddings.device).unsqueeze(1)
    pos_w = torch.arange(width, device=embeddings.device).unsqueeze(1)

    # Compute sine and cosine values
    sin_h = torch.sin(pos_h * theta_h)
    cos_h = torch.cos(pos_h * theta_h)
    sin_w = torch.sin(pos_w * theta_w)
    cos_w = torch.cos(pos_w * theta_w)

    # Construct rotational matrices
    rot_matrix_h = torch.zeros(height, quarter_dim, quarter_dim, device=embeddings.device)
    rot_matrix_w = torch.zeros(width, quarter_dim, quarter_dim, device=embeddings.device)

    for i in range(quarter_dim):
        rot_matrix_h[:, i, i] = cos_h[:, i]
        rot_matrix_h[:, (i + 1) % quarter_dim, i] = -sin_h[:, i]
        rot_matrix_h[:, i, (i + 1) % quarter_dim] = sin_h[:, i]
        rot_matrix_h[:, (i + 1) % quarter_dim, (i + 1) % quarter_dim] = cos_h[:, i]

        rot_matrix_w[:, i, i] = cos_w[:, i]
        rot_matrix_w[:, (i + 1) % quarter_dim, i] = -sin_w[:, i]
        rot_matrix_w[:, i, (i + 1) % quarter_dim] = sin_w[:, i]
        rot_matrix_w[:, (i + 1) % quarter_dim, (i + 1) % quarter_dim] = cos_w[:, i]

    # Rotate x1 and x2
    try:
        x_h_rotated = torch.einsum('...q,hij->...ij', x, rot_matrix_h)
        x_w_rotated = torch.einsum('...q,wij->...ij', x, rot_matrix_w)
        rotation_matrix = x_h_rotated + x_w_rotated
    except RuntimeError as e:
        print(f"Error during rotation: {e}")
        raise e

    # Combine and interleave rotated embeddings
    rotated_embeddings = torch.empty(
        x_h_rotated.size()[:-1] + (x_h_rotated.size(-1) * 2,),  # Double the last dimension size
        device=x1_h_rotated.device
    )
    rotated_embeddings = x_h_rotated + x_w_rotated

    # Reshape back to [batch_size_heads, seq_len, dim]
    rotated_embeddings = rotated_embeddings.view(batch_size_heads, seq_len, dim)
    return rotated_embeddings

class RoPEMixed(nn.Module):
    def __init__(self, dim):
        super().__init__()
        quarter_dim = dim // 4
        self.theta_h = nn.Parameter(torch.tensor([100 ** (-i / quarter_dim) for i in range(quarter_dim)]))
        self.theta_w = nn.Parameter(torch.tensor([100 ** (-i / quarter_dim) for i in range(quarter_dim)]))

    def forward(self, embeddings, height, width):
        return apply_axial_rope_mixed(
            embeddings, height, width, dim=embeddings.size(-1),
            theta_h=self.theta_h, theta_w=self.theta_w)

In [3]:
# SelfAttention Module
class SelfAttention(nn.Module):
    def __init__(self, dim, heads=8, height=4, width=4, axialRoPE=False, trainable_theta=False):
        super().__init__()
        self.heads = heads
        self.dim = dim
        self.scale = (dim // heads) ** -0.5
        self.height = height
        self.width = width
        self.axialRoPE = axialRoPE
        self.trainable_theta = trainable_theta
        self.axial_rope = AxialRoPE(dim // heads) if (axialRoPE and trainable_theta) else None

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.fc = nn.Linear(dim, dim)
        

    def forward(self, x):
        batch_size, seq_len, embedding_dim = x.shape
        head_dim = embedding_dim // self.heads

        qkv = self.qkv(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.heads, head_dim)
        q, k, v = qkv.unbind(dim=2)

        q = q.permute(0, 2, 1, 3).reshape(batch_size * self.heads, seq_len, head_dim)
        k = k.permute(0, 2, 1, 3).reshape(batch_size * self.heads, seq_len, head_dim)
        v = v.permute(0, 2, 1, 3).reshape(batch_size * self.heads, seq_len, head_dim)

        if self.axialRoPE:
            if self.trainable_theta:
                q = self.axial_rope(q, self.height, self.width)
                k = self.axial_rope(k, self.height, self.width)
            else:
                q = apply_axial_rope(q, self.height, self.width, head_dim)
                k = apply_axial_rope(k, self.height, self.width, head_dim)

        attn_scores = torch.bmm(q, k.transpose(1, 2)) * self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.bmm(attn_probs, v)
        attn_output = attn_output.view(batch_size, self.heads, seq_len, head_dim)
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, embedding_dim)

        return self.fc(attn_output)

class VisionTransformer(nn.Module):
    def __init__(self, image_size=32, patch_size=4, num_classes=10, dim=64, depth=6, heads=8, mlp_dim=128, 
                 axialRoPE=False, trainable_theta=False, basic=False):
        super().__init__()
        self.patch_size = patch_size  
        self.num_patches = (image_size // patch_size) ** 2  # Calculate number of patches
        self.patch_dim = patch_size * patch_size * 3  # Account for RGB channels
        self.axialRoPE = axialRoPE 
        self.trainable_theta = trainable_theta 
        self.patch_embedding = nn.Linear(self.patch_dim, dim)  # Embedding for each patch
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))  # Class token as a learnable parameter
        self.positional_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))  # Positional embeddings

        self.basic = basic  # Use of basic flag to control the inclusion of RoPE or Performers

        if not self.basic:
            # Define transformer blocks with additional features if not basic
            self.transformer_layers = nn.ModuleList([
                TransformerBlock(dim, heads, mlp_dim, self.patch_size, self.patch_size, 
                                 axialRoPE=axialRoPE, trainable_theta=trainable_theta) for _ in range(depth)
            ])
        else:
            # Define simpler transformer blocks without RoPE or Performers
            self.transformer_layers = nn.ModuleList([
                TransformerBlock(dim, heads, mlp_dim, height=0, width=0, 
                                 axialRoPE=False, trainable_theta=False) for _ in range(depth)  
                # Assuming simpler block configuration
            ])

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),  # Normalization
            nn.Linear(dim, num_classes)  # Output layer
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        x = x.contiguous().view(batch_size, self.num_patches, -1)
        x = self.patch_embedding(x)

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x += self.positional_embedding

        for layer in self.transformer_layers:
            x = layer(x)

        return self.mlp_head(x[:, 0])  # Output from the class token position


class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim, height, width, axialRoPE=False, trainable_theta=False):
        """
        Transformer block with optional RoPE.
        - `height` and `width` are ignored if this is a basic transformer (height = width = 0).
        - `trainable`: Determines if the RoPE is trainable.
        """
        super().__init__()
        self.msa_norm = nn.LayerNorm(dim)
        self.axialRoPE = axialRoPE 
        self.trainable_theta = trainable_theta 
        
        self.msa = SelfAttention(dim, heads, height, width, axialRoPE=axialRoPE, trainable_theta=trainable_theta)

        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim)
        )

    def forward(self, x):
        # Apply LayerNorm before the self-attention mechanism
        x = x + self.msa(self.msa_norm(x))

        # Apply a feed-forward network defined in MLP block
        x = x + self.mlp(self.mlp_norm(x))

        return x

In [240]:
# Gram-Schmidt Orthogonalization for generating orthogonal random projections
def gram_schmidt(vectors):
    orthogonalized = []
    for v in vectors:
        for u in orthogonalized:
            v -= torch.dot(v, u) * u
        v = v / v.norm()  # Normalize to unit length
        orthogonalized.append(v)
    return torch.stack(orthogonalized)

# Generate Orthogonal Random Projections
def generate_orthogonal_random_projections(dim, m):
    random_matrix = torch.randn((m, dim), device=device)
    orthogonal_matrix = gram_schmidt(random_matrix)
    return orthogonal_matrix.T  # Return as (dim, m) for projection

# Phi+ kernel implementation
def phi_plus(z, m):
    batch_size, seq_len, dim = z.size()
    norm_squared = torch.norm(z, dim=-1, keepdim=True) ** 2
    orthogonal_random_matrix = generate_orthogonal_random_projections(dim, m)
    projected = z @ orthogonal_random_matrix  # [batch_size, seq_len, m]
    phi_plus_features = torch.exp(-norm_squared / 2) * torch.exp(projected) / torch.sqrt(torch.tensor(m, dtype=z.dtype, device=z.device))
    return phi_plus_features

# Phi++ kernel implementation
def phi_plus_plus(z, m):
    batch_size, seq_len, dim = z.size()
    norm_squared = torch.norm(z, dim=-1, keepdim=True) ** 2
    orthogonal_random_matrix = generate_orthogonal_random_projections(dim, m)
    projected = z @ orthogonal_random_matrix  # [batch_size, seq_len, m]
    phi_plus_plus_features = torch.exp(-norm_squared / 2) * torch.cat([
        torch.exp(projected),
        torch.exp(-projected)
    ], dim=-1) / torch.sqrt(torch.tensor(2 * m, dtype=z.dtype, device=z.device))
    return phi_plus_plus_features

# Performer Self-Attention (Supports Orthogonal Random Features)
class PerformerSelfAttention(nn.Module):
    def __init__(self, dim, heads=8, height=4, width=4,m=8, use_phi_plus_plus=False, axialRoPE=False, trainable_theta=False):
        super().__init__()
        self.heads = heads
        self.dim = dim
        self.height=height
        self.width=width
        self.m = m
        self.scale = (dim // heads) ** -0.5
        self.use_phi_plus_plus = use_phi_plus_plus
        self.axialRoPE = axialRoPE
        self.trainable_theta = trainable_theta
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.fc = nn.Linear(dim, dim)
        # If Axial RoPE and trainable theta are enabled
        if self.axialRoPE and self.trainable_theta:
            self.axial_rope = AxialRoPE(dim // heads)  # Define AxialRoPE with trainable parameters


    def forward(self, x):
        batch_size, seq_len, embedding_dim = x.shape
        print(f'Input shape: {x.shape}')
        initial_head_dim = embedding_dim // self.heads


        qkv = self.qkv(x).view(batch_size, seq_len, 3, self.heads, initial_head_dim)
        q, k, v = qkv.unbind(dim=2)
        print(f'Q, K, V shapes after split: {q.shape}, {k.shape}, {v.shape}')

        q = q.permute(0, 2, 1, 3).reshape(batch_size * self.heads, seq_len, initial_head_dim)
        k = k.permute(0, 2, 1, 3).reshape(batch_size * self.heads, seq_len, initial_head_dim)
        v = v.permute(0, 2, 1, 3).reshape(batch_size * self.heads, seq_len, initial_head_dim)
        
        print(f'Reshaped Q, K, V: {q.shape}, {k.shape}, {v.shape}')
        # Assuming phi_plus_plus was used
        if self.use_phi_plus_plus:
            q_prime = phi_plus_plus(q, self.m)
            k_prime = phi_plus_plus(k, self.m)
            actual_head_dim = q_prime.size(-1)  
        else:
            q_prime = phi_plus(q, self.m)
            k_prime = phi_plus(k, self.m)
            actual_head_dim = q_prime.size(-1) 
        print(f'Q_prime, K_prime after phi: {q_prime.shape}, {k_prime.shape}')
        actual_head_dim = q_prime.size(-1)  # Get the last dimension size after transformations

       # Optional Axial RoPE transformation
        if self.axialRoPE:
            if self.trainable_theta:
                q_prime = self.axial_rope(q_prime, self.height, self.width)  # Applying trainable RoPE to queries
                k_prime = self.axial_rope(k_prime, self.height, self.width)  # Applying trainable RoPE to keys
            else:
                q_prime = apply_axial_rope(q_prime, self.height, self.width, actual_head_dim)  # Applying non-trainable RoPE
                k_prime = apply_axial_rope(k_prime, self.height, self.width, actual_head_dim)
        print(f'Q_prime, K_prime after RoPE: {q_prime.shape}, {k_prime.shape}')
       
        kv = torch.matmul(k_prime.transpose(-2, -1), v)
        qkv = torch.matmul(q_prime, kv)
        print(f'Intermediate KV and QKV shapes: {kv.shape}, {qkv.shape}')

        attn_output = qkv*self.scale
        attn_output = attn_output.view(batch_size, self.heads, seq_len, qkv.size(-1))
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.dim)    
        print(f'Final output shape: {attn_output.shape}')

        return self.fc(attn_output)    
    
    
# Gaussian Performer Self-Attention
import torch.nn.functional as F
class GaussianPerformerSelfAttention(nn.Module):
    def __init__(self, dim, heads=8, height=4, width=4, axialRoPE=False, trainable_theta=False, learnable_G=False):
        super().__init__()
        self.heads = heads
        self.dim = dim
        self.height=height
        self.width=width
        self.scale = (dim // heads) ** -0.5
        self.device = device
        self.axialRoPE = axialRoPE 
        self.trainable_theta = trainable_theta 
        self.learnable_G = learnable_G
        self.head_dim = dim // heads
        if self.axialRoPE and self.trainable_theta:
            self.axial_rope = AxialRoPE(dim // heads)  # Define AxialRoPE with trainable parameters
        if self.learnable_G:
            # Initializing G as a learnable parameter
            self.G = nn.Parameter(torch.randn(self.head_dim, self.head_dim))
        else:
            # Fixed random G
            self.G = torch.randn((self.head_dim, self.head_dim), requires_grad=False)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        batch_size, seq_len, embedding_dim = x.shape

        qkv = self.qkv(x).view(batch_size, seq_len, 3, self.heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)

        q = q.reshape(batch_size * self.heads, seq_len, self.head_dim)
        k = k.reshape(batch_size * self.heads, seq_len, self.head_dim)
        v = v.reshape(batch_size * self.heads, seq_len, self.head_dim)  

        q = F.relu(q @ self.G)
        k = F.relu(k @ self.G)
        
        if self.axialRoPE:
            if self.trainable_theta:
                q = self.axial_rope(q, self.height, self.width)
                k = self.axial_rope(k, self.height, self.width)
            else:
                q = apply_axial_rope(q, self.height, self.width, head_dim)
                k = apply_axial_rope(k, self.height, self.width, head_dim)


        attn_scores = torch.bmm(q, k.transpose(1, 2)) * self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.bmm(attn_probs, v).view(batch_size, self.heads, seq_len, self.head_dim)
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, embedding_dim)
        return self.fc(attn_output)


# Vision Transformer with Performer Self-Attention
class VisionTransformerPerformer(nn.Module):
    def __init__(self, image_size=32, patch_size=4, num_classes=10, dim=64, depth=6, heads=8, 
                 mlp_dim=128, m=8, 
                 use_phi_plus_plus=False, use_gaussian_performer=False, axialRoPE=False, trainable_theta=False
                 , learnable_G=False ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size  
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = patch_size * patch_size * 3  # Adjust for 3 color channels
        self.axialRoPE = axialRoPE 
        self.trainable_theta = trainable_theta 
        self.patch_embedding = nn.Linear(self.patch_dim, dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.positional_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))

        self.performer_layers = nn.ModuleList([
            TransformerBlockPerformer(dim, heads, mlp_dim, m, use_phi_plus_plus, use_gaussian_performer
                                     , axialRoPE=axialRoPE, trainable_theta=trainable_theta,
                                     learnable_G=learnable_G) for _ in range(depth)
        ])

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    def forward(self, x):
        batch_size = x.size(0)
        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        x = x.contiguous().view(batch_size, self.num_patches, -1)
        x = self.patch_embedding(x)

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x += self.positional_embedding

        for layer in self.performer_layers:
            x = layer(x)

        return self.mlp_head(x[:, 0])


class TransformerBlockPerformer(nn.Module):
    def __init__(self, dim, heads, mlp_dim, m, use_phi_plus_plus, use_gaussian_performer, 
                 axialRoPE=False, trainable_theta=False, learnable_G=False):
        """
        Initializes a Performer Transformer block with the option to use phi_plus or phi_plus_plus kernels.
        
        Parameters:
        - dim (int): Dimension of the input features.
        - heads (int): Number of attention heads.
        - mlp_dim (int): Dimension of the feed-forward network.
        - m (int): Number of random features for the Performer kernel.
        - use_phi_plus_plus (bool): Flag to decide between using phi_plus or phi_plus_plus kernels.
        """
        super().__init__()
        self.axialRoPE = axialRoPE 
        self.trainable_theta = trainable_theta 
        self.msa_norm = nn.LayerNorm(dim)  # Layer normalization before the self-attention
        if use_gaussian_performer:
            self.msa = GaussianPerformerSelfAttention(dim, heads, 
                                                      axialRoPE=axialRoPE, trainable_theta=trainable_theta, 
                                                      learnable_G=learnable_G)
        else:
            self.msa = PerformerSelfAttention(dim, heads, m, use_phi_plus_plus=use_phi_plus_plus,
                                              axialRoPE=axialRoPE, trainable_theta=trainable_theta)

        # Feed-forward network
        self.mlp_norm = nn.LayerNorm(dim)  # Layer normalization before the feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),  # First linear layer
            nn.GELU(),                # GELU activation function
            nn.Linear(mlp_dim, dim)   # Second linear layer to bring back to the original dimension
        )

    def forward(self, x):
        """
        Forward pass of the transformer block using Performer attention.
        
        Parameters:
        - x (Tensor): Input tensor of shape [batch_size, seq_len, dim]
        
        Returns:
        - Tensor: Output tensor of the same shape as input after passing through the transformer block.
        """
        # Apply self-attention and add the result to the input (residual connection)
        x = x + self.msa(self.msa_norm(x))

        # Apply the feed-forward network and add the result to the previous output (residual connection)
        x = x + self.mlp(self.mlp_norm(x))

        return x

In [5]:
import time
import os
import numpy as np

# Directory to save the models
save_dir = "./saved_models_CIFAR10"
os.makedirs(save_dir, exist_ok=True)

# Save function
def save_model(model, model_name):
    save_path = os.path.join(save_dir, f"{model_name}.pth")
    torch.save(model.state_dict(), save_path)
    print(f"Model saved at: {save_path}")

# Load function
def load_model(model_class, model_name):
    model = model_class()
    load_path = os.path.join(save_dir, f"{model_name}.pth")
    model.load_state_dict(torch.load(load_path, map_location=device))
    model.to(device)
    print(f"Model loaded from: {load_path}")
    return model

# Evaluation function
def evaluate_model(model, loader, criterion):
    model.eval()
    total_loss, total_correct = 0, 0
    inference_times = []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            start_time = time.time()
            preds = model(x)
            end_time = time.time()
            inference_times.append(end_time - start_time)
            loss = criterion(preds, y)
            total_loss += loss.item()
            total_correct += (preds.argmax(1) == y).sum().item()
    avg_inference_time = np.mean(inference_times)
    accuracy = total_correct / len(loader.dataset)
    return total_loss / len(loader), accuracy, avg_inference_time

# Training function with speed measurement
def train_with_speed(model, loader, optimizer, criterion):
    model.train()
    total_loss, total_correct = 0, 0
    batch_times = []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        start_time = time.time()
        optimizer.zero_grad()
        preds = model(x)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        end_time = time.time()
        batch_times.append(end_time - start_time)
        total_loss += loss.item()
        total_correct += (preds.argmax(1) == y).sum().item()
    avg_batch_time = np.mean(batch_times)
    accuracy = total_correct / len(loader.dataset)
    return total_loss / len(loader), accuracy, avg_batch_time

# Evaluation framework
def evaluate_framework(model, model_name, train_loader, test_loader, optimizer, criterion):
    print(f"\nEvaluating {model_name}...")
    train_losses, test_losses = [], []
    train_accuracies, test_accuracies = [], []
    train_times, inference_times = [], []

    for epoch in range(10):
        train_loss, train_acc, train_time = train_with_speed(model, train_loader, optimizer, criterion)
        test_loss, test_acc, avg_inference_time = evaluate_model(model, test_loader, criterion)

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)
        train_times.append(train_time)
        inference_times.append(avg_inference_time)

        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Train Time/Batch: {train_time:.4f}s, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, "
              f"Inference Time/Batch: {avg_inference_time:.4f}s")

    # Calculate variance in accuracy
    accuracy_variance = np.var(test_accuracies)

    # Save the model
    save_model(model, model_name)

    # Summary of results
    print(f"\nSummary for {model_name}:")
    print(f"Final Test Accuracy: {test_accuracies[-1]:.4f}")
    print(f"Accuracy Variance: {accuracy_variance:.4f}")
    print(f"Average Training Time per Batch: {np.mean(train_times):.4f}s")
    print(f"Average Inference Time per Batch: {np.mean(inference_times):.4f}s")


# Initialize criterion
criterion = nn.CrossEntropyLoss()

In [224]:
# Basic Vision Transformer (No RoPE, No Performer)
model_basic = VisionTransformer(basic=True).to(device)
optimizer = optim.Adam(model_basic.parameters(), lr=3e-4)
evaluate_framework(model_basic, "vision_transformer_basic", train_loader, test_loader, optimizer, criterion)


Evaluating vision_transformer_basic...
Epoch 1: Train Loss: 1.8009, Train Acc: 0.3362, Train Time/Batch: 0.0480s, Test Loss: 1.5688, Test Acc: 0.4321, Inference Time/Batch: 0.0678s
Epoch 2: Train Loss: 1.4855, Train Acc: 0.4593, Train Time/Batch: 0.0479s, Test Loss: 1.3998, Test Acc: 0.4873, Inference Time/Batch: 0.0675s
Epoch 3: Train Loss: 1.3352, Train Acc: 0.5157, Train Time/Batch: 0.0482s, Test Loss: 1.3519, Test Acc: 0.5096, Inference Time/Batch: 0.0695s
Epoch 4: Train Loss: 1.2283, Train Acc: 0.5576, Train Time/Batch: 0.0484s, Test Loss: 1.2297, Test Acc: 0.5509, Inference Time/Batch: 0.0674s
Epoch 5: Train Loss: 1.1374, Train Acc: 0.5908, Train Time/Batch: 0.0482s, Test Loss: 1.1830, Test Acc: 0.5765, Inference Time/Batch: 0.0682s
Epoch 6: Train Loss: 1.0584, Train Acc: 0.6207, Train Time/Batch: 0.0488s, Test Loss: 1.1697, Test Acc: 0.5820, Inference Time/Batch: 0.0675s
Epoch 7: Train Loss: 0.9979, Train Acc: 0.6417, Train Time/Batch: 0.0482s, Test Loss: 1.1423, Test Acc: 0.59

In [225]:
# Axial RoPE (Non-trainable Theata) with Vision Transformer
model_non_trainable_axial_rope = VisionTransformer(axialRoPE=True, trainable_theta=False).to(device)
optimizer = optim.Adam(model_non_trainable_axial_rope.parameters(), lr=3e-4)
evaluate_framework(model_non_trainable_axial_rope, "vision_transformer_axial_rope_non_trainable_theta", train_loader, test_loader, optimizer, criterion)



Evaluating vision_transformer_axial_rope_non_trainable_theta...
Epoch 1: Train Loss: 2.1849, Train Acc: 0.1857, Train Time/Batch: 0.1656s, Test Loss: 2.1115, Test Acc: 0.2130, Inference Time/Batch: 0.1518s
Epoch 2: Train Loss: 2.0880, Train Acc: 0.2240, Train Time/Batch: 0.1622s, Test Loss: 2.1007, Test Acc: 0.2285, Inference Time/Batch: 0.1450s
Epoch 3: Train Loss: 2.0673, Train Acc: 0.2355, Train Time/Batch: 0.1588s, Test Loss: 2.0479, Test Acc: 0.2445, Inference Time/Batch: 0.1428s
Epoch 4: Train Loss: 2.0282, Train Acc: 0.2497, Train Time/Batch: 0.1559s, Test Loss: 2.0356, Test Acc: 0.2425, Inference Time/Batch: 0.1409s
Epoch 5: Train Loss: 2.0155, Train Acc: 0.2542, Train Time/Batch: 0.1565s, Test Loss: 2.0091, Test Acc: 0.2557, Inference Time/Batch: 0.1410s
Epoch 6: Train Loss: 2.0060, Train Acc: 0.2625, Train Time/Batch: 0.1545s, Test Loss: 2.0127, Test Acc: 0.2508, Inference Time/Batch: 0.1491s
Epoch 7: Train Loss: 2.0104, Train Acc: 0.2578, Train Time/Batch: 0.1545s, Test Los

In [226]:
# Axial RoPE (Trainable Theata) with Vision Transformer
model_trainable_axial_rope = VisionTransformer(axialRoPE=True, trainable_theta=True).to(device)
optimizer = optim.Adam(model_trainable_axial_rope.parameters(), lr=3e-4)
evaluate_framework(model_trainable_axial_rope, "vision_transformer_axial_rope_trainable_theta", train_loader, test_loader, optimizer, criterion)



Evaluating vision_transformer_axial_rope_trainable_theta...
Epoch 1: Train Loss: 2.1830, Train Acc: 0.1879, Train Time/Batch: 0.1848s, Test Loss: 2.0856, Test Acc: 0.2267, Inference Time/Batch: 0.1477s
Epoch 2: Train Loss: 2.0670, Train Acc: 0.2350, Train Time/Batch: 0.1863s, Test Loss: 2.0272, Test Acc: 0.2529, Inference Time/Batch: 0.1479s
Epoch 3: Train Loss: 2.0583, Train Acc: 0.2411, Train Time/Batch: 0.1843s, Test Loss: 2.0497, Test Acc: 0.2454, Inference Time/Batch: 0.1436s
Epoch 4: Train Loss: 2.0218, Train Acc: 0.2557, Train Time/Batch: 0.1839s, Test Loss: 2.0025, Test Acc: 0.2685, Inference Time/Batch: 0.1443s
Epoch 5: Train Loss: 2.0011, Train Acc: 0.2644, Train Time/Batch: 0.1788s, Test Loss: 2.0000, Test Acc: 0.2660, Inference Time/Batch: 0.1403s
Epoch 6: Train Loss: 2.0061, Train Acc: 0.2601, Train Time/Batch: 0.1780s, Test Loss: 2.0076, Test Acc: 0.2614, Inference Time/Batch: 0.1429s
Epoch 7: Train Loss: 2.0194, Train Acc: 0.2546, Train Time/Batch: 0.1828s, Test Loss: 2

In [227]:
# Performer with Phi+ No ROPE
model_performer_phi_plus = VisionTransformerPerformer(use_phi_plus_plus=False).to(device)
optimizer = optim.Adam(model_performer_phi_plus.parameters(), lr=3e-4)
evaluate_framework(model_performer_phi_plus, "vision_performer_phi_plus_noROPE", train_loader, test_loader, optimizer, criterion)



Evaluating vision_performer_phi_plus_noROPE...
Epoch 1: Train Loss: 1.9038, Train Acc: 0.2924, Train Time/Batch: 0.0548s, Test Loss: 1.7706, Test Acc: 0.3471, Inference Time/Batch: 0.0462s
Epoch 2: Train Loss: 1.6280, Train Acc: 0.3970, Train Time/Batch: 0.0550s, Test Loss: 1.5214, Test Acc: 0.4372, Inference Time/Batch: 0.0465s
Epoch 3: Train Loss: 1.4810, Train Acc: 0.4528, Train Time/Batch: 0.0546s, Test Loss: 1.4534, Test Acc: 0.4700, Inference Time/Batch: 0.0469s
Epoch 4: Train Loss: 1.3746, Train Acc: 0.4976, Train Time/Batch: 0.0547s, Test Loss: 1.3515, Test Acc: 0.5110, Inference Time/Batch: 0.0472s
Epoch 5: Train Loss: 1.2775, Train Acc: 0.5373, Train Time/Batch: 0.0547s, Test Loss: 1.2673, Test Acc: 0.5380, Inference Time/Batch: 0.0467s
Epoch 6: Train Loss: 1.2016, Train Acc: 0.5648, Train Time/Batch: 0.0553s, Test Loss: 1.2282, Test Acc: 0.5542, Inference Time/Batch: 0.0464s
Epoch 7: Train Loss: 1.1350, Train Acc: 0.5897, Train Time/Batch: 0.0546s, Test Loss: 1.2031, Test A

In [228]:
# Performer with Phi++ No ROPE
model_performer_phi_plus_plus = VisionTransformerPerformer(use_phi_plus_plus=True).to(device)
optimizer = optim.Adam(model_performer_phi_plus_plus.parameters(), lr=3e-4)
evaluate_framework(model_performer_phi_plus_plus, "vision_performer_phi_plus_plus_noROPE", train_loader, test_loader, optimizer, criterion)



Evaluating vision_performer_phi_plus_plus_noROPE...
Epoch 1: Train Loss: 1.8609, Train Acc: 0.3095, Train Time/Batch: 0.0617s, Test Loss: 1.6650, Test Acc: 0.3903, Inference Time/Batch: 0.0613s
Epoch 2: Train Loss: 1.5708, Train Acc: 0.4209, Train Time/Batch: 0.0616s, Test Loss: 1.4843, Test Acc: 0.4582, Inference Time/Batch: 0.0594s
Epoch 3: Train Loss: 1.4266, Train Acc: 0.4767, Train Time/Batch: 0.0617s, Test Loss: 1.4016, Test Acc: 0.4954, Inference Time/Batch: 0.0592s
Epoch 4: Train Loss: 1.3236, Train Acc: 0.5147, Train Time/Batch: 0.0617s, Test Loss: 1.3252, Test Acc: 0.5211, Inference Time/Batch: 0.0596s
Epoch 5: Train Loss: 1.2353, Train Acc: 0.5520, Train Time/Batch: 0.0618s, Test Loss: 1.2858, Test Acc: 0.5300, Inference Time/Batch: 0.0591s
Epoch 6: Train Loss: 1.1573, Train Acc: 0.5806, Train Time/Batch: 0.0619s, Test Loss: 1.2651, Test Acc: 0.5474, Inference Time/Batch: 0.0595s
Epoch 7: Train Loss: 1.0896, Train Acc: 0.6072, Train Time/Batch: 0.0619s, Test Loss: 1.1879, T

In [231]:
# Axial RoPE（Non-trainable theata） + Performer Phi+
model_performer_phi_plus_axial_rope_nontrain = VisionTransformerPerformer(axialRoPE=True, trainable_theta=False).to(device)
optimizer = optim.Adam(model_performer_phi_plus_axial_rope_nontrain.parameters(), lr=3e-4)
evaluate_framework(model_performer_phi_plus_axial_rope_nontrain, "vision_performer_phi_plus_axial_nontrain_ROPE", train_loader, test_loader, optimizer, criterion)



Evaluating vision_performer_phi_plus_axial_nontrain_ROPE...
Epoch 1: Train Loss: 1.9692, Train Acc: 0.2680, Train Time/Batch: 0.0941s, Test Loss: 1.7996, Test Acc: 0.3319, Inference Time/Batch: 0.0907s
Epoch 2: Train Loss: 1.6552, Train Acc: 0.3861, Train Time/Batch: 0.0941s, Test Loss: 1.5581, Test Acc: 0.4248, Inference Time/Batch: 0.0908s
Epoch 3: Train Loss: 1.5211, Train Acc: 0.4386, Train Time/Batch: 0.0940s, Test Loss: 1.5272, Test Acc: 0.4433, Inference Time/Batch: 0.0900s
Epoch 4: Train Loss: 1.4366, Train Acc: 0.4754, Train Time/Batch: 0.0947s, Test Loss: 1.4260, Test Acc: 0.4808, Inference Time/Batch: 0.0903s
Epoch 5: Train Loss: 1.3644, Train Acc: 0.5043, Train Time/Batch: 0.0941s, Test Loss: 1.3772, Test Acc: 0.4979, Inference Time/Batch: 0.0905s
Epoch 6: Train Loss: 1.3156, Train Acc: 0.5221, Train Time/Batch: 0.0939s, Test Loss: 1.3227, Test Acc: 0.5224, Inference Time/Batch: 0.0927s
Epoch 7: Train Loss: 1.2664, Train Acc: 0.5419, Train Time/Batch: 0.0934s, Test Loss: 1

In [234]:
# Axial RoPE（Trainable theata） + Performer Phi+
model_performer_phi_plus_axial_rope_train = VisionTransformerPerformer(axialRoPE=True, trainable_theta=True ).to(device)
optimizer = optim.Adam(model_performer_phi_plus_axial_rope_train.parameters(), lr=3e-4)
evaluate_framework(model_performer_phi_plus_axial_rope_train, "vision_performer_phi_plus_axial_trained_ROPE", train_loader, test_loader, optimizer, criterion)



Evaluating vision_performer_phi_plus_axial_trained_ROPE...
Epoch 1: Train Loss: 1.9689, Train Acc: 0.2659, Train Time/Batch: 0.1167s, Test Loss: 1.8169, Test Acc: 0.3226, Inference Time/Batch: 0.0909s
Epoch 2: Train Loss: 1.6856, Train Acc: 0.3758, Train Time/Batch: 0.1187s, Test Loss: 1.5975, Test Acc: 0.4163, Inference Time/Batch: 0.0979s
Epoch 3: Train Loss: 1.5419, Train Acc: 0.4292, Train Time/Batch: 0.1189s, Test Loss: 1.5330, Test Acc: 0.4381, Inference Time/Batch: 0.0925s
Epoch 4: Train Loss: 1.4518, Train Acc: 0.4621, Train Time/Batch: 0.1140s, Test Loss: 1.4539, Test Acc: 0.4686, Inference Time/Batch: 0.0888s
Epoch 5: Train Loss: 1.3868, Train Acc: 0.4908, Train Time/Batch: 0.1169s, Test Loss: 1.3940, Test Acc: 0.4891, Inference Time/Batch: 0.0955s
Epoch 6: Train Loss: 1.3258, Train Acc: 0.5160, Train Time/Batch: 0.1154s, Test Loss: 1.3289, Test Acc: 0.5194, Inference Time/Batch: 0.0918s
Epoch 7: Train Loss: 1.2802, Train Acc: 0.5344, Train Time/Batch: 0.1154s, Test Loss: 1.

In [241]:
# Axial RoPE（Non-trainable theata） + Performer Phi++
model_performer_phi_plusplus_axial_rope_nontrain = VisionTransformerPerformer(axialRoPE=True, trainable_theta=False, use_phi_plus_plus=True).to(device)
optimizer = optim.Adam(model_performer_phi_plusplus_axial_rope_nontrain.parameters(), lr=3e-4)
evaluate_framework(model_performer_phi_plusplus_axial_rope_nontrain, "vision_performer_phi_plusplus_axial_nontrain_ROPE", train_loader, test_loader, optimizer, criterion)



Evaluating vision_performer_phi_plusplus_axial_nontrain_ROPE...
Input shape: torch.Size([64, 65, 64])
Q, K, V shapes after split: torch.Size([64, 65, 8, 8]), torch.Size([64, 65, 8, 8]), torch.Size([64, 65, 8, 8])
Reshaped Q, K, V: torch.Size([512, 65, 8]), torch.Size([512, 65, 8]), torch.Size([512, 65, 8])
Q_prime, K_prime after phi: torch.Size([512, 65, 16]), torch.Size([512, 65, 16])


RuntimeError: shape '[512, 65, 16]' is invalid for input of size 1064960

In [None]:
# Axial RoPE（Trainable theata） + Performer Phi++
model_performer_phi_plusplus_axial_rope_train = VisionTransformerPerformer(axialRoPE=True, trainable_theta=True,use_phi_plus_plus=True).to(device)
optimizer = optim.Adam(model_performer_phi_plusplus_axial_rope_train.parameters(), lr=3e-4)
evaluate_framework(model_performer_phi_plusplus_axial_rope_train, "vision_performer_phi_plusplus_axial_trained_ROPE", train_loader, test_loader, optimizer, criterion)


In [None]:
# ReLU with Gaussian matrix, no RoPE
model_ReLU_Gaussian = VisionTransformerPerformer(use_gaussian_performer=True).to(device)
optimizer = optim.Adam(model_ReLU_Gaussian.parameters(), lr=3e-4)
evaluate_framework(model_ReLU_Gaussian, "model_ReLU_Gaussian_no_RoPE", train_loader, test_loader, optimizer, criterion)


In [None]:
# Axial RoPE（Non-trainable theata）+ ReLU with Gaussian matrix
ReLU_Gaussian_axialrope_nontraintheta = VisionTransformerPerformer(use_gaussian_performer=True,axialRoPE=True, trainable_theta=False).to(device)
optimizer = optim.Adam(ReLU_Gaussian_axialrope_nontraintheta.parameters(), lr=3e-4)
evaluate_framework(ReLU_Gaussian_axialrope_nontraintheta, "model_ReLU_Gaussian_axialRoPE_nontrain_theata", train_loader, test_loader, optimizer, criterion)


In [None]:
# Axial RoPE（trainable theata）+ ReLU with Gaussian matrix
model_ReLU_Gaussian_axialrope_nontraintheta = VisionTransformerPerformer(use_gaussian_performer=True,axialRoPE=True, trainable_theta=True).to(device)
optimizer = optim.Adam(model_ReLU_Gaussian_axialrope_nontraintheta.parameters(), lr=3e-4)
evaluate_framework(model_ReLU_Gaussian_axialrope_nontraintheta, "model_ReLU_Gaussian_axialRoPE_trainable_theta", train_loader, test_loader, optimizer, criterion)


In [None]:
# ReLU with Learnable G matrix, no RoPE
model_ReLU_LearnableG = VisionTransformerPerformer(use_gaussian_performer=True, learnable_G=True).to(device)
optimizer = optim.Adam(model_ReLU_LearnableG.parameters(), lr=3e-4)
evaluate_framework(model_ReLU_LearnableG, "model_ReLU_Learnable_G_no_RoPE", train_loader, test_loader, optimizer, criterion)

In [None]:
# Axial RoPE（Non-trainable theata）+ ReLU with Learnable G
ReLU_LearnableG_axialrope_nontraintheta = VisionTransformerPerformer(use_gaussian_performer=True,axialRoPE=True, trainable_theta=False, learnable_G=True).to(device)
optimizer = optim.Adam(ReLU_LearnableG_axialrope_nontraintheta.parameters(), lr=3e-4)
evaluate_framework(ReLU_LearnableG_axialrope_nontraintheta, "model_LearnableG_axialRoPE_nontrain_theata", train_loader, test_loader, optimizer, criterion)


In [None]:
# Axial RoPE（trainable theata）+ ReLU with  with Learnable G
model_ReLU_LearnableG_axialrope_traintheta = VisionTransformerPerformer(use_gaussian_performer=True,axialRoPE=True, trainable_theta=True, learnable_G=True).to(device)
optimizer = optim.Adam(model_ReLU_LearnableG_axialrope_traintheta.parameters(), lr=3e-4)
evaluate_framework(model_ReLU_LearnableG_axialrope_traintheta, "model_ReLU_LearnableG_axialRoPE_trainable_theta", train_loader, test_loader, optimizer, criterion)