In [None]:
!pip install matplotlib seaborn pandas

In [None]:
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn

from linear_attention import LinearAttentionFunction
class LinearAttention(nn.Module):
    """Multi-head self-attention backed by the custom ``LinearAttentionFunction``.

    NOTE(review): a later cell redefines ``LinearAttention``; once that cell
    runs, this definition is shadowed.

    Args:
        dim: Model dimension ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads ``H``.
        qkv_bias: Whether the fused Q/K/V projection has a bias term.
        attn_drop: Probability for ``self.attn_drop``. It is not applied in
            ``forward`` because the kernel never exposes attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): unused in forward; presumably any scaling happens inside
        # LinearAttentionFunction — confirm.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, N, C) and return a (B, N, C) tensor."""
        B, N, C = x.shape

        # (B, N, 3C) -> (B, N, 3, H, D) -> (3, B, H, N, D), then split into q, k, v.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # assumes the kernel returns (B, H, N, D) — TODO confirm
        attn = LinearAttentionFunction.apply(q, k, v)

        # Merge heads back: (B, H, N, D) -> (B, N, H, D) -> (B, N, C).
        x = attn.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [None]:
import torch.nn as nn
from linear_attention import LinearAttentionFunction
class LinearAttention(nn.Module):
    """Multi-head self-attention computed by the custom ``LinearAttentionFunction``.

    The input is projected to fused Q/K/V, split into heads, passed through the
    kernel, merged back, and sent through an output projection and dropout.

    Args:
        dim: Model dimension ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads ``H``.
        qkv_bias: Whether the Q/K/V projection has a bias term.
        attn_drop: Probability for ``self.attn_drop``. It is kept for interface
            compatibility but never applied in ``forward``, because the kernel
            does not expose attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'The dimension must be divisible by the number of heads'

        # Store dimensions for the input check and reshapes in forward.
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): unused in forward; presumably any scaling happens inside
        # LinearAttentionFunction — confirm.
        self.scale = self.head_dim ** -0.5

        # Use ``nn.init`` rather than ``torch.nn.init``: ``import torch.nn as nn``
        # binds only ``nn``, so ``torch`` itself may not be defined.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        nn.init.normal_(self.qkv.weight, std=0.02)
        if qkv_bias:
            nn.init.zeros_(self.qkv.bias)

        self.proj = nn.Linear(dim, dim)
        nn.init.normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, N, C) and return a (B, N, C) tensor.

        Raises:
            AssertionError: if ``C`` differs from the ``dim`` given at construction.
            RuntimeError: if the kernel call, its output-shape check, or the
                output projection fails. The message includes the Q/K/V shapes
                (and the kernel output shape when one was produced), and the
                original exception is chained as ``__cause__``.
        """
        B, N, C = x.shape
        assert C == self.dim, f'Input dimension {C} does not match expected {self.dim}'

        # (B, N, 3C) -> (B, N, 3, H, D) -> (3, B, H, N, D), then split into q, k, v,
        # each (B, H, N, D).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # The permute leaves strided views; make them contiguous before the kernel.
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        attn = None
        try:
            attn = LinearAttentionFunction.apply(q, k, v)

            # Explicit check (not ``assert``) so it survives ``python -O``.
            expected_shape = (B, self.num_heads, N, self.head_dim)
            if attn.shape != expected_shape:
                raise ValueError(
                    f"Attention output shape {attn.shape} does not match expected {expected_shape}"
                )

            # Merge heads back: (B, H, N, D) -> (B, N, H, D) -> (B, N, C).
            x = attn.transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
        except Exception as e:
            # Carry the diagnostics in the exception instead of printing them.
            attn_shape = tuple(attn.shape) if hasattr(attn, 'shape') else None
            raise RuntimeError(
                "Attention computation failed "
                f"(query={tuple(q.shape)}, key={tuple(k.shape)}, value={tuple(v.shape)}, "
                f"attention output={attn_shape})"
            ) from e
        return x

In [None]:
def test_shapes(batch_size=32, seq_len=128, dim=256, num_heads=8):
    model = LinearAttention(dim=dim, num_heads=num_heads).cuda()
    x = torch.randn(batch_size, seq_len, dim).cuda()
    try:
        output = model(x)
        print(f"Success! Output shape: {output.shape}")
    except Exception as e:
        print(f"Error: {str(e)}")

# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so CUDA errors are
# reported at the offending call. CUDA reads it once at initialisation, and it
# then stays in effect for the whole kernel session. That would also serialise
# every launch in the benchmark below and distort its timings, so it is opt-in.
# To use it: set the flag to True, restart the kernel, and run this cell before
# any other CUDA work.
import os

DEBUG_CUDA_LAUNCH_BLOCKING = False
if DEBUG_CUDA_LAUNCH_BLOCKING:
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

test_shapes()

In [None]:
def benchmark_attention(
    seq_lengths: List[int],
    batch_size: int = 32,
    dim: int = 256,
    num_heads: int = 8,
    num_warmup: int = 10,
    num_repeats: int = 100,
    device: str = "cuda"
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Benchmark linear attention against traditional attention across sequence lengths.

    Each model is built once, put in eval mode, and timed under
    ``torch.no_grad()``. The outputs are discarded, so autograd graph
    construction would only add overhead to the measurement. On CUDA devices
    each forward pass is timed with CUDA events; on other devices wall-clock
    ``time.perf_counter`` is used.

    NOTE(review): ``TraditionalAttention`` is not defined in the visible notebook
    cells; it must be defined or imported before this function is called.

    Args:
        seq_lengths: Sequence lengths to benchmark.
        batch_size: Batch size of the random input.
        dim: Model dimension passed to both attention modules.
        num_heads: Number of heads passed to both attention modules.
        num_warmup: Untimed forward passes of each model per sequence length.
        num_repeats: Timed forward passes of each model per sequence length.
        device: Device that the models and inputs are placed on.

    Returns:
        ``(df_results, df_plot)``.
        ``df_results`` has one row per sequence length, with columns
        ``seq_len, linear_mean, linear_std, trad_mean, trad_std, speedup``.
        Times are in milliseconds, and ``speedup = trad_mean / linear_mean``.
        ``df_plot`` is a long-format frame with columns ``seq_len, time, type``
        holding the mean times: all linear rows first, then all traditional rows.
    """
    results = []

    # Initialize models once; eval() so any dropout/train-only layers are off.
    linear_attn = LinearAttention(dim=dim, num_heads=num_heads).to(device).eval()
    trad_attn = TraditionalAttention(dim=dim, num_heads=num_heads).to(device).eval()

    with torch.no_grad():
        for seq_len in seq_lengths:
            print(f"Benchmarking sequence length: {seq_len}")
            x = torch.randn(batch_size, seq_len, dim, device=device)

            # Warm up (kernel selection, allocator caches, lazy init).
            for _ in range(num_warmup):
                linear_attn(x)
                trad_attn(x)

            linear_times = _time_forward(linear_attn, x, num_repeats, device)
            trad_times = _time_forward(trad_attn, x, num_repeats, device)

            linear_mean = np.mean(linear_times)
            trad_mean = np.mean(trad_times)
            results.append({
                'seq_len': seq_len,
                'linear_mean': linear_mean,
                'linear_std': np.std(linear_times),
                'trad_mean': trad_mean,
                'trad_std': np.std(trad_times),
                'speedup': trad_mean / linear_mean
            })

    df_results = pd.DataFrame(results)

    # Long format for seaborn: one row per (seq_len, attention type) mean time.
    df_plot = pd.DataFrame(
        [{'seq_len': r['seq_len'], 'time': r['linear_mean'], 'type': 'Linear Attention'}
         for r in results]
        + [{'seq_len': r['seq_len'], 'time': r['trad_mean'], 'type': 'Traditional Attention'}
           for r in results]
    )

    return df_results, df_plot


def _time_forward(module, x, num_repeats, device):
    """Return a list of ``num_repeats`` forward-pass times of ``module(x)`` in milliseconds."""
    import time

    times = []
    if torch.device(device).type == "cuda":
        # Events measure GPU time; synchronize so each reading is complete.
        torch.cuda.synchronize()
        for _ in range(num_repeats):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            module(x)
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))
    else:
        for _ in range(num_repeats):
            t0 = time.perf_counter()
            module(x)
            times.append((time.perf_counter() - t0) * 1e3)
    return times

In [None]:
def plot_benchmark_results(df_plot: pd.DataFrame, df_results: pd.DataFrame):
    """Draw a two-panel figure summarising the benchmark.

    The top panel shows mean time per attention type against sequence length,
    read from ``df_plot``. The bottom panel shows the traditional/linear
    speedup against sequence length, read from ``df_results``.

    The new figure is left as pyplot's current figure, and the ``pyplot``
    module itself is returned, so callers can call ``.savefig(...)`` and
    ``.close()`` on the result.
    """
    fig, (time_ax, speedup_ax) = plt.subplots(2, 1, figsize=(15, 10))

    # Panel 1: timing comparison, one line per attention type.
    sns.lineplot(data=df_plot, x='seq_len', y='time', hue='type', marker='o', ax=time_ax)
    time_ax.set_title('Attention Computation Time vs Sequence Length')
    time_ax.set_xlabel('Sequence Length')
    time_ax.set_ylabel('Time (ms)')
    time_ax.grid(True)

    # Panel 2: speedup ratio.
    speedup_ax.plot(df_results['seq_len'], df_results['speedup'], marker='o')
    speedup_ax.set_title('Speedup Ratio (Traditional / Linear)')
    speedup_ax.set_xlabel('Sequence Length')
    speedup_ax.set_ylabel('Speedup Factor')
    speedup_ax.grid(True)

    fig.tight_layout()
    return plt

In [None]:
# Benchmark configuration used by the cells below.
# NOTE(review): if TraditionalAttention materialises the full N x N score matrix
# (its definition is not visible here), the 4096 entry at batch_size=32 may
# exceed GPU memory — confirm on the target GPU.
seq_lengths = [2 ** p for p in range(7, 13)]  # 128, 256, ..., 4096
batch_size, dim, num_heads = 32, 256, 8

In [None]:
# Time both attention variants at every configured sequence length.
df_results, df_plot = benchmark_attention(
    seq_lengths,
    batch_size=batch_size,
    dim=dim,
    num_heads=num_heads,
)

In [None]:
print("\nBenchmark Results:")
print(df_results.to_string(index=False))

# Create visualization. Keep the return value under its own name instead of
# assigning it to ``plt``; rebinding the pyplot module name in a notebook is a
# hidden-state hazard for every later cell.
benchmark_plot = plot_benchmark_results(df_plot, df_results)
benchmark_plot.savefig('attention_benchmark.png')
benchmark_plot.close()

print("\nBenchmark visualization saved as 'attention_benchmark.png'")