# Modern LLM Architectures

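This notebook explores the architectural and systems innovations behind modern LLMs: the LLaMA block design (RMSNorm, rotary position embeddings, SwiGLU), Mixture of Experts as used in Mixtral, FlashAttention, and grouped-query attention. Each topic is illustrated with small PyTorch demos and plots. Some charts use approximate or hypothetical figures for illustration rather than measured benchmarks.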

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple, Dict, List
import math
from dataclasses import dataclass
import pandas as pd

# Set style. matplotlib 3.6 renamed the bundled seaborn styles with a
# "seaborn-v0_8-" prefix; fall back to the old name on earlier versions,
# where the new name raises OSError.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')
sns.set_palette("husl")

# Set random seeds (torch.manual_seed seeds the RNG on all devices)
torch.manual_seed(42)
np.random.seed(42)

# NOTE(review): `device` is reported but no cell below moves tensors to it;
# every demo in this notebook runs on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Evolution of LLM Architectures

Let's visualize how LLM architectures have evolved over time.

In [None]:
# Timeline of LLM innovations; 'Parameters' is in billions.
# NOTE: OpenAI never published parameter counts for ChatGPT or GPT-4 (nor
# confirmed GPT-4's architecture); 175B / 1700B are commonly cited estimates.
# NOTE: Mixtral 8x7B weights were released in Dec 2023 (paper Jan 2024).
timeline_data = dict(
    Year=[2018, 2019, 2020, 2022, 2023, 2023, 2024],
    Model=['GPT-1', 'GPT-2', 'GPT-3', 'ChatGPT', 'GPT-4', 'LLaMA', 'Mixtral'],
    Parameters=[0.117, 1.5, 175, 175, 1700, 65, 47],
)
timeline_data['Key Innovation'] = [
    'Unsupervised pre-training',
    'Zero-shot transfer',
    'In-context learning',
    'RLHF at scale',
    'Multimodal + MoE',
    'Efficient open model',
    'Open MoE',
]

df_timeline = pd.DataFrame(timeline_data)
n_models = len(df_timeline)

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 10))

# --- Top panel: parameter count vs. release year; marker area grows with log(size) ---
years = df_timeline['Year']
sizes = df_timeline['Parameters']
ax1.scatter(years, sizes,
            s=100 * np.log10(sizes + 1),
            alpha=0.6, c=range(n_models))

for model, year, size in zip(df_timeline['Model'], years, sizes):
    ax1.annotate(model, (year, size),
                 xytext=(5, 5), textcoords='offset points', fontsize=10)

ax1.set_yscale('log')
ax1.set_xlabel('Year', fontsize=12)
ax1.set_ylabel('Parameters (Billions)', fontsize=12)
ax1.set_title('Evolution of LLM Scale', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# --- Bottom panel: a one-year-wide bar per model, labelled with its innovation ---
y_positions = range(n_models)
colors = plt.cm.viridis(np.linspace(0, 1, n_models))

for y, year, model, innovation in zip(y_positions, years,
                                      df_timeline['Model'],
                                      df_timeline['Key Innovation']):
    start = year - 2018  # years elapsed since GPT-1
    ax2.barh(y, 1, left=start, height=0.8, color=colors[y], alpha=0.7)
    ax2.text(start + 0.5, y, f"{model}: {innovation}",
             va='center', fontsize=10)

ax2.set_xlim(0, 7)
ax2.set_ylim(-0.5, n_models - 0.5)
ax2.set_xlabel('Years since 2018', fontsize=12)
ax2.set_yticks([])
ax2.set_title('Key Innovations in LLM Development', fontsize=14, fontweight='bold')
ax2.grid(True, axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

## 2. LLaMA Architecture Deep Dive

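LLaMA combined several existing architectural refinements (RMSNorm pre-normalization, rotary position embeddings, SwiGLU activations) with training on far more tokens per parameter. As a result, LLaMA-13B outperforms GPT-3 (175B) on most benchmarks despite being more than 10× smaller.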

In [None]:
# RMSNorm implementation
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Scales each vector along the last dimension by the reciprocal of its
    root-mean-square (no mean subtraction, no bias), then applies a learnable
    per-channel gain. The mean of squares is computed in float32, and the
    result is cast back to the input dtype.

    Args:
        hidden_size: size of the last (normalized) dimension.
        eps: added to the mean of squares before the reciprocal square root.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Learnable gain, initialised to identity
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Keeps rsqrt finite for all-zero vectors
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Statistic in float32 so reduced-precision inputs do not lose accuracy
        mean_square = torch.mean(torch.square(hidden_states.float()), dim=-1, keepdim=True)
        inv_rms = (mean_square + self.variance_epsilon).rsqrt()
        scaled = self.weight * (hidden_states * inv_rms)
        return scaled.to(orig_dtype)

# Compare RMSNorm vs LayerNorm
def compare_normalizations():
    """Plot how LayerNorm and RMSNorm reshape the same random activations.

    Builds a (4, 32, 512) standard-normal batch and normalizes it with an
    untrained ``nn.LayerNorm`` and an untrained ``RMSNorm``. Draws a 2x3 grid:
    the top row is LayerNorm and the bottom row is RMSNorm. The columns show
    the raw histogram, the normalized histogram, and per-token mean vs. std.
    Finally prints the global mean/std of each output.
    """
    hidden_size, batch_size, seq_len = 512, 4, 32

    x = torch.randn(batch_size, seq_len, hidden_size)

    rmsnorm = RMSNorm(hidden_size)
    layernorm = nn.LayerNorm(hidden_size)
    rms_out = rmsnorm(x)
    ln_out = layernorm(x)

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))

    raw_values = x.flatten().numpy()
    rows = (
        (ln_out, 'LayerNorm Row', 'blue'),
        (rms_out, 'RMSNorm Row', 'green'),
    )
    for row_axes, (normed, row_label, color) in zip(axes, rows):
        raw_ax, normed_ax, stats_ax = row_axes

        # Column 1: input distribution (identical in both rows, for reference)
        raw_ax.hist(raw_values, bins=50, alpha=0.7, color='gray')
        raw_ax.set_ylabel(row_label)

        # Column 2: distribution after normalization
        normed_ax.hist(normed.flatten().detach().numpy(), bins=50, alpha=0.7, color=color)

        # Column 3: one point per token (mean and std over the hidden dimension)
        per_token = normed.detach()
        stats_ax.scatter(per_token.mean(dim=-1).flatten(),
                         per_token.std(dim=-1).flatten(), alpha=0.5)
        stats_ax.set_xlabel('Mean')
        stats_ax.set_ylabel('Std')

    axes[0, 0].set_title('Original Distribution')
    axes[0, 1].set_title('After Normalization')
    axes[0, 2].set_title('Mean vs Std')

    plt.tight_layout()
    plt.show()

    print("Key differences:")
    for name, out in (("LayerNorm", ln_out), ("RMSNorm", rms_out)):
        print(f"{name} - Mean: {out.mean():.4f}, Std: {out.std():.4f}")
    # Speed claim is the commonly quoted figure; it is not benchmarked here.
    print("\nRMSNorm is ~10% faster as it doesn't compute/subtract mean!")

compare_normalizations()

### Rotary Position Embeddings (RoPE)

In [None]:
def visualize_rope():
    """Visualize how Rotary Position Embeddings (RoPE) rotate vectors by position.

    A toy 2-D embedding has a single rotation frequency (inv_freq = [1.0]),
    so position m is rotated by m radians. Draws a 2x2 figure:
      1. rotation angle per position,
      2. a unit vector rotated to a few positions,
      3. the cos/sin components per position,
      4. two rotation properties RoPE relies on:
         - rotating two vectors by the same angle preserves their distance;
         - the dot product of a query rotated to position m with a key
           rotated to m + offset does not depend on m.
    """
    # Simple 2D example for visualization
    dim = 2
    max_len = 16
    base = 10000
    offset = 3  # relative distance used for the q·k demonstration in plot 4

    # Frequencies theta_i = base^(-2i/dim); position m, pair i -> angle m * theta_i
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_len).float()
    freqs = torch.outer(t, inv_freq)  # (max_len, dim // 2)

    # Rotation components per position
    cos_m = freqs.cos()
    sin_m = freqs.sin()

    def rotate(v, angle):
        """Rotate 2-D vector v by every angle in 1-D tensor angle -> (len(angle), 2)."""
        c, s = angle.cos(), angle.sin()
        return torch.stack([v[0] * c - v[1] * s, v[0] * s + v[1] * c], dim=-1)

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Plot 1: Rotation angles
    ax = axes[0, 0]
    positions = np.arange(max_len)
    angles = freqs.numpy()
    ax.plot(positions, angles, marker='o')
    ax.set_xlabel('Position')
    ax.set_ylabel('Rotation Angle (radians)')
    ax.set_title('Rotation Angles by Position')
    ax.grid(True, alpha=0.3)

    # Plot 2: A unit vector rotated to several positions
    ax = axes[0, 1]
    v = np.array([1, 0])
    for pos in [0, 4, 8, 12]:
        angle = angles[pos, 0]
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        rot_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
        v_rot = rot_matrix @ v
        color = plt.cm.viridis(pos / 12)
        ax.arrow(0, 0, v_rot[0], v_rot[1],
                 head_width=0.05, head_length=0.05,
                 fc=color, ec=color, label=f'Pos {pos}')

    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_title('2D Rotation Visualization')

    # Plot 3: Sine and Cosine components
    ax = axes[1, 0]
    ax.plot(positions, cos_m.numpy(), label='cos', marker='o')
    ax.plot(positions, sin_m.numpy(), label='sin', marker='s')
    ax.set_xlabel('Position')
    ax.set_ylabel('Value')
    ax.set_title('Sine and Cosine Components')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 4: Distance preservation and relative-position dot products
    ax = axes[1, 1]
    v1 = torch.tensor([1.0, 0.5])
    v2 = torch.tensor([0.7, 0.8])
    angle_m = freqs[:, 0]

    # Vectorized over positions: torch.cos/sin require tensors, not Python floats
    distances = torch.linalg.norm(rotate(v1, angle_m) - rotate(v2, angle_m), dim=-1)
    # Query rotated to position m, key rotated to position m + offset
    rel_dots = (rotate(v1, angle_m) * rotate(v2, (t + offset) * inv_freq[0])).sum(dim=-1)

    ax.plot(positions, distances.numpy(), marker='o',
            label='Distance (both rotated to m)')
    ax.axhline(y=torch.norm(v1 - v2).item(), color='r', linestyle='--',
               label='Original distance')
    ax.plot(positions, rel_dots.numpy(), marker='s',
            label=f'q(m) · k(m+{offset})')
    ax.set_xlabel('Position m')
    ax.set_ylabel('Value')
    ax.set_title('Rotation Preserves Distances; q·k Depends Only on Offset')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("Key insights about RoPE:")
    print("1. Each position gets a unique rotation angle")
    print("2. Relative positions are encoded in the phase difference")
    print("3. Distance between vectors is preserved (rotation is orthogonal)")
    print("4. No learned parameters - purely deterministic")

visualize_rope()

### SwiGLU Activation Function

In [None]:
def compare_activations():
    """Compare activation functions used in transformer FFNs, including SwiGLU.

    Draws a 2x2 figure:
      1. ReLU / GELU / SiLU curves,
      2. their derivatives (GELU and SiLU via autograd),
      3. a heatmap of SwiGLU's gated product SiLU(x1) * x2 over a 2-D grid,
      4. an illustrative bar chart of FFN variants (not measured).
    """
    x = torch.linspace(-3, 3, 1000)

    # Different activations
    relu = F.relu(x)
    gelu = F.gelu(x)
    silu = F.silu(x)  # Swish with beta = 1

    # Derivatives: autograd needs a leaf that requires grad. Use a separate copy so
    # the tensors plotted above stay grad-free (matplotlib cannot convert grad tensors).
    x_leaf = x.clone().requires_grad_(True)
    gelu_grad, = torch.autograd.grad(F.gelu(x_leaf).sum(), x_leaf)
    silu_grad, = torch.autograd.grad(F.silu(x_leaf).sum(), x_leaf)
    relu_grad = (x > 0).float()

    # SwiGLU(x) = SiLU(x W_gate) * (x W_up). Evaluate the gated product over a 2-D
    # grid of the two projections' outputs: swiglu[i, j] = SiLU(x1_i) * x2_j.
    grid = torch.linspace(-3, 3, 200)
    x1 = grid.unsqueeze(1)   # (200, 1): gate branch, varies along dim 0
    x2 = grid.unsqueeze(0)   # (1, 200): linear branch, varies along dim 1
    swiglu = F.silu(x1) * x2  # (200, 200)

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Plot activations
    ax = axes[0, 0]
    ax.plot(x, relu, label='ReLU', linewidth=2)
    ax.plot(x, gelu, label='GELU', linewidth=2)
    ax.plot(x, silu, label='SiLU (Swish)', linewidth=2)
    ax.set_xlabel('Input')
    ax.set_ylabel('Output')
    ax.set_title('Standard Activation Functions')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot derivatives
    ax = axes[0, 1]
    ax.plot(x, relu_grad, label='ReLU\'', linewidth=2)
    ax.plot(x, gelu_grad, label='GELU\'', linewidth=2)
    ax.plot(x, silu_grad, label='SiLU\'', linewidth=2)
    ax.set_xlabel('Input')
    ax.set_ylabel('Gradient')
    ax.set_title('Activation Derivatives')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # SwiGLU heatmap: transpose so x1 runs along the x-axis and x2 along the y-axis
    ax = axes[1, 0]
    im = ax.imshow(swiglu.T.numpy(), aspect='auto', origin='lower',
                   extent=[-3, 3, -3, 3], cmap='RdBu_r')
    ax.set_xlabel('x1 (gate branch, SiLU input)')
    ax.set_ylabel('x2 (linear / up-projection branch)')
    ax.set_title('SwiGLU: SiLU(x1) * x2')
    plt.colorbar(im, ax=ax)

    # Comparison of FFN architectures
    ax = axes[1, 1]
    architectures = ['Standard\n(2 layers)', 'GLU\n(3 matrices)', 'SwiGLU\n(3 matrices)']
    performance = [1.0, 1.08, 1.12]  # Illustrative relative performance, not measured here
    colors = ['blue', 'green', 'orange']

    bars = ax.bar(architectures, performance, color=colors, alpha=0.7)
    ax.set_ylabel('Relative Performance')
    ax.set_title('FFN Architecture Comparison')
    ax.set_ylim(0.9, 1.15)

    # Add value labels
    for bar, perf in zip(bars, performance):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height,
                f'{perf:.2f}x', ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

    print("SwiGLU advantages:")
    print("1. Smoother gradients than ReLU")
    print("2. Gating mechanism allows selective information flow")
    # Same parameter count is achieved by shrinking the FFN width to ~2/3 to pay
    # for the third matrix.
    print("3. Better performance despite using same parameter count")

compare_activations()

## 3. Mixture of Experts (MoE)

Let's explore how Mixture of Experts enables massive scale while keeping compute manageable.

In [None]:
class SimpleMoE(nn.Module):
    """Simplified sparse Mixture-of-Experts FFN layer with top-k routing.

    A linear router scores every token. Each token is sent to its
    ``num_experts_per_tok`` highest-probability experts. The expert outputs are
    summed, weighted by the selected router probabilities renormalized to sum
    to 1 per token.

    Args:
        hidden_size: model width (last dimension of the input).
        num_experts: number of expert FFNs (E).
        num_experts_per_tok: experts evaluated per token (k).
    """

    def __init__(self, hidden_size=256, num_experts=8, num_experts_per_tok=2):
        super().__init__()
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok

        # Router: one logit per expert for every token
        self.gate = nn.Linear(hidden_size, num_experts)

        # Experts: independent two-layer FFNs with 4x expansion
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 4),
                nn.ReLU(),
                nn.Linear(hidden_size * 4, hidden_size)
            ) for _ in range(num_experts)
        ])

    def forward(self, x):
        """Route each token through its top-k experts.

        Args:
            x: tensor of shape (batch, seq_len, hidden_size); may be non-contiguous.

        Returns:
            output: (batch, seq_len, hidden_size) weighted sum of expert outputs.
            router_probs: (batch, seq_len, num_experts) full softmax router probabilities.
            expert_usage: (num_experts,) CPU tensor, fraction of tokens routed to each
                expert; sums to num_experts_per_tok for non-empty input.
        """
        batch_size, seq_len, hidden_size = x.shape
        # reshape (not view) so non-contiguous inputs such as transposed tensors work
        x_flat = x.reshape(-1, hidden_size)
        num_tokens = x_flat.shape[0]

        # Compute router scores
        router_logits = self.gate(x_flat)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts and renormalize their probabilities per token
        topk_probs, topk_indices = torch.topk(router_probs, self.num_experts_per_tok, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(x_flat)
        # Track expert usage for visualization
        expert_usage = torch.zeros(self.num_experts)

        for i, expert in enumerate(self.experts):
            # (token, slot) pairs that selected expert i. Top-k indices are distinct,
            # so a token appears at most once per expert.
            token_idx, slot_idx = torch.nonzero(topk_indices == i, as_tuple=True)
            expert_usage[i] = token_idx.numel() / max(num_tokens, 1)
            if token_idx.numel() == 0:
                continue

            weights = topk_probs[token_idx, slot_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, weights * expert(x_flat[token_idx]))

        output = output.view(batch_size, seq_len, hidden_size)

        return output, router_probs.view(batch_size, seq_len, -1), expert_usage

# Demonstrate MoE
def visualize_moe(hidden_size=256, batch_size=2, seq_len=16,
                  num_experts=8, num_experts_per_tok=2):
    """Run a randomly initialised SimpleMoE layer on random tokens and plot its routing.

    Args:
        hidden_size: model width of the toy layer.
        batch_size: number of sequences in the random input.
        seq_len: tokens per sequence.
        num_experts: number of experts (E).
        num_experts_per_tok: experts evaluated per token (k).

    Draws a 2x2 figure:
      - router probabilities for batch 0;
      - the fraction of tokens sent to each expert vs. the uniform-load target k/E;
      - the top-k selections;
      - parameter and per-token FLOP counts of the MoE layer vs. one dense FFN
        the size of a single expert.
    Prints the resulting ratios.
    """
    moe = SimpleMoE(hidden_size=hidden_size, num_experts=num_experts,
                    num_experts_per_tok=num_experts_per_tok)
    k = moe.num_experts_per_tok
    n_experts = moe.num_experts

    # Create input with different patterns
    x = torch.randn(batch_size, seq_len, hidden_size)

    # Inference only: no autograd graph is needed for plotting
    with torch.no_grad():
        _, router_probs, expert_usage = moe(x)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: Router probabilities heatmap
    ax = axes[0, 0]
    im = ax.imshow(router_probs[0].numpy(), aspect='auto', cmap='YlOrRd')
    ax.set_xlabel('Expert ID')
    ax.set_ylabel('Token Position')
    ax.set_title('Router Probabilities (Batch 0)')
    plt.colorbar(im, ax=ax)

    # Plot 2: Expert usage. Every token picks k experts, so the fractions sum to k
    # and a perfectly balanced router gives each expert k / E of the tokens.
    ax = axes[0, 1]
    ax.bar(range(n_experts), expert_usage.numpy(), alpha=0.7)
    ax.axhline(y=k / n_experts, color='r', linestyle='--',
               label=f'Ideal uniform usage (k/E = {k / n_experts:.2f})')
    ax.set_xlabel('Expert ID')
    ax.set_ylabel('Fraction of Tokens')
    ax.set_title('Expert Usage Distribution')
    ax.legend()

    # Plot 3: Top-k expert selection
    ax = axes[1, 0]
    topk_probs, topk_indices = torch.topk(router_probs[0], k=k, dim=-1)
    # (seq_len, E) matrix holding only each token's selected (raw) router probabilities
    selection_matrix = torch.zeros_like(router_probs[0]).scatter(-1, topk_indices, topk_probs)

    im = ax.imshow(selection_matrix.numpy(), aspect='auto', cmap='Blues')
    ax.set_xlabel('Expert ID')
    ax.set_ylabel('Token Position')
    ax.set_title(f'Top-{k} Expert Selection')
    plt.colorbar(im, ax=ax)

    # Plot 4: MoE vs Dense comparison (biases ignored)
    ax = axes[1, 1]

    # One expert (= the dense FFN): hidden -> 4*hidden -> hidden
    ffn_params = 2 * hidden_size * (4 * hidden_size)
    router_params = hidden_size * n_experts
    # ~2 FLOPs (multiply + add) per weight per token
    dense_params = ffn_params
    dense_flops = 2 * ffn_params
    # MoE stores all E experts, but each token only runs the router plus k experts
    moe_params = n_experts * ffn_params + router_params
    moe_flops = 2 * (k * ffn_params + router_params)

    categories = ['Parameters', 'FLOPs/token']
    dense_values = [dense_params / 1e6, dense_flops / 1e6]
    moe_values = [moe_params / 1e6, moe_flops / 1e6]

    x_pos = np.arange(len(categories))
    width = 0.35

    ax.bar(x_pos - width / 2, dense_values, width, label='Dense', alpha=0.7)
    ax.bar(x_pos + width / 2, moe_values, width,
           label=f'MoE ({n_experts} experts, top-{k})', alpha=0.7)

    ax.set_ylabel('Millions')
    ax.set_title('MoE vs Dense FFN Comparison')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(categories)
    ax.legend()

    # Add value labels
    for i, (d, m) in enumerate(zip(dense_values, moe_values)):
        ax.text(i - width / 2, d + 0.1, f'{d:.1f}M', ha='center')
        ax.text(i + width / 2, m + 0.1, f'{m:.1f}M', ha='center')

    plt.tight_layout()
    plt.show()

    print("MoE advantages:")
    print(f"1. {moe_params / dense_params:.1f}x more parameters")
    print(f"2. Only {moe_flops / dense_flops:.1f}x compute per token")
    print("3. Specialization: Different experts can learn different patterns")

visualize_moe()

## 4. Flash Attention Concepts

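FlashAttention computes exact attention faster and with far less memory by being IO-aware. It tiles the computation so that blocks stay in fast on-chip SRAM, and it never writes the full N×N attention matrix to slower GPU HBM. Let's look at the key concepts.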

In [None]:
def visualize_flash_attention_concepts():
    """Visualize the key ideas behind FlashAttention.

    Draws a 2x2 figure:
      1. an approximate GPU memory hierarchy (bandwidth and relative latency);
      2. the extra memory needed by standard vs. FlashAttention as sequence
         length grows;
      3. a tiling diagram;
      4. an illustrative speedup curve.
    Hardware and speedup figures are approximate and are not measured here.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: Memory hierarchy (approximate figures)
    ax = axes[0, 0]
    memory_levels = ['SRAM\n(~20MB)', 'HBM\n(~80GB)', 'CPU RAM\n(~1TB)']
    bandwidth = [19000, 1500, 100]  # GB/s, approximate
    latency = [1, 100, 10000]  # Relative, illustrative

    x = np.arange(len(memory_levels))
    ax.bar(x, bandwidth, alpha=0.7, color='blue')
    ax.set_yscale('log')  # bandwidths span ~200x; on a linear axis the CPU bar vanishes
    ax.set_ylabel('Bandwidth (GB/s)', color='blue')
    ax.tick_params(axis='y', labelcolor='blue')
    ax.set_xticks(x)
    ax.set_xticklabels(memory_levels)

    ax2 = ax.twinx()
    ax2.plot(x, latency, 'ro-', markersize=10)
    ax2.set_ylabel('Relative Latency', color='red')
    ax2.tick_params(axis='y', labelcolor='red')
    ax2.set_yscale('log')

    ax.set_title('GPU Memory Hierarchy')
    ax.grid(True, alpha=0.3)

    # Plot 2: Extra attention memory per head, fp32 (4 bytes), in MB for both curves
    ax = axes[0, 1]

    seq_lengths = [512, 1024, 2048, 4096, 8192]
    bytes_fp32 = 4
    # Standard attention materializes the N x N score/probability matrix: O(N^2)
    standard_mem = [s**2 * bytes_fp32 / 1e6 for s in seq_lengths]
    # FlashAttention keeps only O(N) softmax statistics (per query row) besides the output
    flash_mem = [s * bytes_fp32 / 1e6 for s in seq_lengths]

    ax.plot(seq_lengths, standard_mem, 'o-', label='Standard Attention', linewidth=2)
    ax.plot(seq_lengths, flash_mem, 's-', label='Flash Attention', linewidth=2)
    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('Extra Memory per Head (MB, fp32)')
    ax.set_title('Attention Memory Footprint')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 3: Tiling visualization
    ax = axes[1, 0]

    # Create a mock attention matrix
    n = 16
    attention_matrix = np.random.rand(n, n)

    # Outline each block; -0.5 aligns rectangles with imshow pixel edges
    block_size = 4
    for i in range(0, n, block_size):
        for j in range(0, n, block_size):
            rect = plt.Rectangle((j - 0.5, i - 0.5), block_size, block_size,
                                 fill=False, edgecolor='red', linewidth=2)
            ax.add_patch(rect)

    ax.imshow(attention_matrix, cmap='Blues', alpha=0.7)
    ax.set_title('Flash Attention Tiling (4x4 blocks)')
    ax.set_xlabel('Keys')
    ax.set_ylabel('Queries')

    # Add annotations
    ax.text(2, 2, 'Block fits\nin SRAM', ha='center', va='center',
            bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))

    # Plot 4: Speedup comparison
    ax = axes[1, 1]

    seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]
    speedup = [1.5, 2.1, 3.2, 4.8, 7.2, 10.5]  # Illustrative values, not benchmarked here

    ax.plot(seq_lengths, speedup, 'go-', markersize=10, linewidth=2)
    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('Speedup vs Standard Attention')
    ax.set_title('Flash Attention Speedup')
    ax.grid(True, alpha=0.3)

    # Label every other point to avoid clutter
    for i, (seq, speed) in enumerate(zip(seq_lengths, speedup)):
        if i % 2 == 0:
            ax.annotate(f'{speed:.1f}x', (seq, speed),
                        xytext=(0, 10), textcoords='offset points', ha='center')

    plt.tight_layout()
    plt.show()

    print("Flash Attention key innovations:")
    print("1. Tiling: Process attention in blocks that fit in fast SRAM")
    print("2. Recomputation: Trade compute for memory bandwidth")
    print("3. Online softmax: Compute softmax without materializing full matrix")
    print("4. IO-aware: Minimize slow HBM accesses")

visualize_flash_attention_concepts()

## 5. Grouped Query Attention (GQA)

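Grouped-query attention (GQA) lets groups of query heads share a single key/value head. This shrinks the KV cache, and the memory bandwidth needed to read it during decoding. Multi-query attention (MQA) is the extreme case, with one KV head shared by all query heads.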

In [None]:
def visualize_attention_variants():
    """Compare multi-head (MHA), multi-query (MQA) and grouped-query (GQA) attention.

    Uses a toy configuration (hidden_size=1024, 8 query heads, GQA with 4 KV
    heads) and draws a 2x2 figure:
      1. a head-sharing schematic;
      2. KV-cache size vs. sequence length;
      3. an illustrative quality/efficiency trade-off;
      4. the KV cost of each variant relative to MHA.
    """
    from matplotlib.patches import Patch

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Common parameters
    hidden_size = 1024
    num_heads = 8
    num_kv_heads = 4  # GQA: num_heads // num_kv_heads query heads share each KV head
    head_dim = hidden_size // num_heads
    group_size = num_heads // num_kv_heads

    # Plot 1: Architecture comparison
    ax = axes[0, 0]
    head_width = 0.8
    head_height = 0.15

    def draw_variant(target_ax, y, name, kv_positions, kv_color):
        """Draw one row of query heads at height y with its KV heads just below."""
        for i in range(num_heads):
            target_ax.add_patch(plt.Rectangle((i, y), head_width, head_height,
                                              facecolor='blue', edgecolor='black', alpha=0.7))
        for kv_x in kv_positions:
            target_ax.add_patch(plt.Rectangle((kv_x, y - 0.2), head_width, head_height,
                                              facecolor=kv_color, edgecolor='black', alpha=0.7))
        # Right-aligned to the left of x=0 so the label does not cover the first head
        target_ax.text(-0.2, y, name, fontsize=12, va='center', ha='right')

    # MHA: one KV head per query head
    draw_variant(ax, 2.5, 'MHA', range(num_heads), 'green')
    # MQA: a single KV head shared by all query heads, centred under them
    draw_variant(ax, 1.5, 'MQA', [(num_heads - 1) / 2], 'red')
    # GQA: one KV head centred under each group of query heads
    draw_variant(ax, 0.5, 'GQA',
                 [g * group_size + (group_size - 1) / 2 for g in range(num_kv_heads)],
                 'orange')

    ax.set_xlim(-1.5, num_heads)
    ax.set_ylim(0, 3)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('Attention Head Architectures')

    legend_elements = [
        Patch(facecolor='blue', alpha=0.7, label='Query heads'),
        Patch(facecolor='green', alpha=0.7, label='KV heads (MHA)'),
        Patch(facecolor='red', alpha=0.7, label='KV head (MQA)'),
        Patch(facecolor='orange', alpha=0.7, label='KV heads (GQA)')
    ]
    # Placed below the drawing so it does not cover the MHA row
    ax.legend(handles=legend_elements, loc='upper center',
              bbox_to_anchor=(0.5, 0.0), ncol=2)

    # Plot 2: KV cache per layer for one sequence: K and V (x2), fp32 (4 bytes), in MB
    ax = axes[0, 1]

    seq_lengths = [512, 1024, 2048, 4096, 8192]

    mha_memory = [2 * s * num_heads * head_dim * 4 / 1e6 for s in seq_lengths]
    mqa_memory = [2 * s * 1 * head_dim * 4 / 1e6 for s in seq_lengths]
    gqa_memory = [2 * s * num_kv_heads * head_dim * 4 / 1e6 for s in seq_lengths]

    ax.plot(seq_lengths, mha_memory, 'o-', label=f'MHA ({num_heads} heads)', linewidth=2)
    ax.plot(seq_lengths, gqa_memory, 's-', label=f'GQA ({num_kv_heads} KV heads)', linewidth=2)
    ax.plot(seq_lengths, mqa_memory, '^-', label='MQA (1 KV head)', linewidth=2)

    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('KV Cache Memory (MB)')
    ax.set_title('Memory Usage Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 3: Quality vs efficiency trade-off
    ax = axes[1, 0]

    methods = ['MHA', 'GQA-4', 'GQA-2', 'MQA']
    kv_head_counts = [num_heads, 4, 2, 1]
    quality = [100, 99.5, 99, 98.5]  # Illustrative relative quality, not measured
    efficiency = [num_heads // n_kv for n_kv in kv_head_counts]  # KV-cache reduction vs MHA

    ax.scatter(efficiency, quality, s=200, alpha=0.7, c=range(len(methods)))

    for i, method in enumerate(methods):
        ax.annotate(method, (efficiency[i], quality[i]),
                    xytext=(5, 5), textcoords='offset points')

    ax.set_xlabel('Relative Efficiency')
    ax.set_ylabel('Relative Quality (%)')
    ax.set_title('Quality vs Efficiency Trade-off')
    ax.grid(True, alpha=0.3)

    # Plot 4: KV cost relative to MHA. The ratio is num_kv_heads / num_heads, so it
    # is the same at every model size for a fixed head configuration.
    ax = axes[1, 1]

    model_sizes = ['7B', '13B', '30B', '65B', '175B']
    x_pos = np.arange(len(model_sizes))

    mha_rel = [1.0] * len(model_sizes)
    gqa_rel = [num_kv_heads / num_heads] * len(model_sizes)
    mqa_rel = [1 / num_heads] * len(model_sizes)

    width = 0.25
    ax.bar(x_pos - width, mha_rel, width, label='MHA', alpha=0.7)
    ax.bar(x_pos, gqa_rel, width, label=f'GQA ({num_kv_heads} KV heads)', alpha=0.7)
    ax.bar(x_pos + width, mqa_rel, width, label='MQA', alpha=0.7)

    ax.set_xlabel('Model Size')
    ax.set_ylabel('Relative KV size (cache & K/V projections)')
    ax.set_title('KV Cost vs MHA (set by head ratio, not model size)')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(model_sizes)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

    print("Key insights:")
    print("1. GQA offers a sweet spot between quality and efficiency")
    print("2. Memory savings are crucial for long sequences")
    print("3. Quality degradation is minimal with careful tuning")
    print("4. Particularly beneficial for inference (KV cache)")

visualize_attention_variants()

## 6. Performance Analysis

Let's analyze the performance characteristics of modern LLM architectures.

In [None]:
def analyze_llm_performance():
    """Plot size, training-compute, throughput and memory comparisons of several LLMs.

    Training compute is estimated with the standard C ~= 6 * N * D rule, where N is
    the number of parameters active per token and D the number of training tokens.
    Benchmark scores and throughput numbers are hypothetical placeholders.
    """
    # Model specifications.
    # NOTE: training-token counts for Mistral-7B and Mixtral-8x7B were not
    # published; those 'tokens' values are placeholders for illustration.
    models = {
        'GPT-3': {'params': 175e9, 'active_params': 175e9, 'tokens': 300e9},
        'LLaMA-7B': {'params': 7e9, 'active_params': 7e9, 'tokens': 1e12},
        'LLaMA-65B': {'params': 65e9, 'active_params': 65e9, 'tokens': 1.4e12},
        'Mistral-7B': {'params': 7e9, 'active_params': 7e9, 'tokens': 8e12},
        # Mixtral routes each token through 2 of 8 experts: ~13B of ~47B params active
        'Mixtral-8x7B': {'params': 47e9, 'active_params': 13e9, 'tokens': 1e12},
    }
    for spec in models.values():
        spec['flops'] = 6 * spec['active_params'] * spec['tokens']

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: Parameters vs Performance
    ax = axes[0, 0]

    names = list(models.keys())
    params = [models[m]['params'] / 1e9 for m in names]
    performance = [85, 82, 88, 84, 86]  # Hypothetical performance scores

    ax.scatter(params, performance, s=200, alpha=0.7, c=range(len(names)))

    for i, name in enumerate(names):
        ax.annotate(name, (params[i], performance[i]),
                    xytext=(5, 5), textcoords='offset points', fontsize=9)

    ax.set_xscale('log')
    ax.set_xlabel('Parameters (Billions)')
    ax.set_ylabel('Performance Score')
    ax.set_title('Model Size vs Performance')
    ax.grid(True, alpha=0.3)

    # Plot 2: Training efficiency. FLOPs/token ~= 6 x active params; tokens per
    # parameter uses total parameters.
    ax = axes[0, 1]

    flops_per_token = [models[m]['flops'] / models[m]['tokens'] for m in names]
    tokens_per_param = [models[m]['tokens'] / models[m]['params'] for m in names]

    ax.scatter(tokens_per_param, flops_per_token, s=200, alpha=0.7, c=range(len(names)))

    for i, name in enumerate(names):
        ax.annotate(name, (tokens_per_param[i], flops_per_token[i]),
                    xytext=(5, 5), textcoords='offset points', fontsize=9)

    # Both quantities span orders of magnitude
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('Tokens per Parameter')
    ax.set_ylabel('Training FLOPs per Token')
    ax.set_title('Training Efficiency')
    ax.grid(True, alpha=0.3)

    # Plot 3: Inference speed comparison
    ax = axes[1, 0]

    batch_sizes = [1, 8, 32, 128]
    # Hypothetical throughput (tokens/second)
    throughput = {
        'LLaMA-7B': [50, 350, 1200, 4000],
        'Mistral-7B': [55, 380, 1350, 4500],
        'Mixtral-8x7B': [30, 200, 700, 2400],
    }

    for model, values in throughput.items():
        ax.plot(batch_sizes, values, 'o-', label=model, linewidth=2, markersize=8)

    ax.set_xlabel('Batch Size')
    ax.set_ylabel('Throughput (tokens/second)')
    ax.set_title('Inference Speed Comparison')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 4: Memory requirements (GB = 1e9 bytes, fp16 = 2 bytes)
    ax = axes[1, 1]

    categories = ['Weights', 'Activations\n(BS=1)', 'KV Cache\n(2K ctx)', 'Total']
    ctx_len = 2048
    fp16_bytes = 2

    def kv_cache_gb(n_layers, d_model):
        """K and V (x2) for every layer and context position, fp16, batch 1, in GB."""
        return 2 * ctx_len * n_layers * d_model * fp16_bytes / 1e9

    # LLaMA-7B: 32 layers, d_model 4096 (32 heads x 128 head_dim)
    weights = 7 * fp16_bytes
    activations = 1  # Rough estimate
    kv_cache = kv_cache_gb(n_layers=32, d_model=4096)
    total = weights + activations + kv_cache

    values_7b = [weights, activations, kv_cache, total]

    # LLaMA-65B: 80 layers, d_model 8192 (64 heads x 128 head_dim)
    weights_65b = 65 * fp16_bytes
    activations_65b = 8  # Rough estimate
    kv_cache_65b = kv_cache_gb(n_layers=80, d_model=8192)
    total_65b = weights_65b + activations_65b + kv_cache_65b

    values_65b = [weights_65b, activations_65b, kv_cache_65b, total_65b]

    x = np.arange(len(categories))
    width = 0.35

    ax.bar(x - width / 2, values_7b, width, label='7B Model', alpha=0.7)
    ax.bar(x + width / 2, values_65b, width, label='65B Model', alpha=0.7)

    ax.set_ylabel('Memory (GB)')
    ax.set_title('Memory Requirements')
    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels
    for i, (v7, v65) in enumerate(zip(values_7b, values_65b)):
        ax.text(i - width / 2, v7 + 1, f'{v7:.1f}', ha='center')
        ax.text(i + width / 2, v65 + 1, f'{v65:.1f}', ha='center')

    plt.tight_layout()
    plt.show()

    print("Performance insights:")
    print("1. Mistral-7B achieves similar performance to larger models through better training")
    print("2. MoE models like Mixtral offer more parameters with manageable compute")
    print("3. Memory bandwidth is often the bottleneck, not compute")
    print("4. Efficient architectures are crucial for deployment")

analyze_llm_performance()

## Summary: The Future of LLM Architectures

### Key Innovations We've Explored:

1. **LLaMA Family**:
   - RMSNorm for faster normalization
   - RoPE for better position encoding
   - SwiGLU activation for improved performance
   - GQA for efficient inference (adopted in LLaMA 2 34B/70B and later; the original LLaMA used standard multi-head attention)

2. **Mixture of Experts**:
   - Sparse computation for massive scale
   - Specialization through routing
   - Load balancing challenges

3. **Flash Attention**:
   - Hardware-aware algorithm design
   - IO optimization over compute optimization
   - Enables longer context lengths

4. **Efficiency Techniques**:
   - GQA/MQA for memory reduction
   - Sliding window attention (used by Mistral 7B; not demonstrated in this notebook)
   - Better training recipes

### Future Directions:

1. **Even Longer Context**: 1M+ token context windows
2. **More Efficient MoE**: Better routing, dynamic experts
3. **Multimodal Native**: Built-in vision/audio understanding
4. **Continual Learning**: Models that can update without retraining
5. **Edge Deployment**: Efficient models for local inference

The race is on to build models that are not just bigger, but fundamentally better!