In [25]:
# ===============================================
# Summary and Verification
# ===============================================

# Calculate GPT-2 XL constants for summary
vocab_size = 50257
context_length = 1024
num_layers = 48
d_model = 1600
num_heads = 25
d_ff = 6400

# Calculate batch-size independent terms (CORRECTED)
total_params = 2 * vocab_size * d_model + num_layers * (16 * d_model**2 + 2 * d_model) + d_model
# Parameters (4 bytes) + Gradients (4 bytes) + Optimizer state (8 bytes) = 16 bytes per parameter
param_grad_opt_memory = total_params * 16  # NOT 16 * 4!

# Calculate batch-size dependent coefficient (more conservative estimate)
# Only count essential activations needed for backpropagation
essential_activations_per_token = (
    d_model +  # token embedding output
    num_layers * (
        2 * d_model +  # 2 RMSNorm outputs per layer
        3 * d_model +  # QKV projections 
        num_heads * context_length +  # attention scores (simplified)
        2 * d_model +  # attention output + projection
        2 * d_ff +     # W1, W3 outputs
        d_model        # W2 output
    ) +
    d_model +  # final layer norm
    vocab_size # output logits
)

activation_coeff = essential_activations_per_token * context_length * 4  # 4 bytes per float32

# Convert to GB
a = activation_coeff / 1e9  # GB per batch item
b = param_grad_opt_memory / 1e9  # GB batch-independent

# Calculate max batch size for 80GB
memory_limit_gb = 80
max_batch_size = int((memory_limit_gb - b) / a)

# Calculate training time (CORRECTED)
num_steps = 400_000
batch_size_train = 1024
mfu = 0.50  # 50% Model FLOPs Utilization
a100_peak_flops = 19.5e12  # 19.5 TFLOPs/s
forward_flops_total = 4_513_336_524_800  # From earlier FLOP analysis
forward_flops_per_token = forward_flops_total / context_length
# Forward + Backward (2x forward) = 3x forward FLOPs per token
total_flops_per_token = 3 * forward_flops_per_token
total_tokens = num_steps * batch_size_train * context_length
total_training_flops = total_tokens * total_flops_per_token
effective_throughput = a100_peak_flops * mfu
training_time_days = total_training_flops / effective_throughput / 86400

print("SUMMARY OF RESULTS")
print("=" * 20)
print()

print("(a) Algebraic expressions for AdamW memory usage:")
print("   Parameters:     4 * [2*V*H + L*(16*H² + 2*H) + H] bytes")
print("   Activations:    4 * B*T*(H + L*(7*H + A*T + 2*d_ff) + V) bytes (essential only)")
print("   Gradients:      4 * [2*V*H + L*(16*H² + 2*H) + H] bytes")
print("   Optimizer:      8 * [2*V*H + L*(16*H² + 2*H) + H] bytes")
print("   TOTAL:         16*[2*V*H + L*(16*H² + 2*H) + H] + 4*B*T*(H + L*(7*H + A*T + 2*d_ff) + V)")
print()

print(f"(b) GPT-2 XL memory expression:")
print(f"   Memory (GB) = {a:.3f} * batch_size + {b:.1f}")
print(f"   Maximum batch size for 80GB: {max_batch_size}")
print()

print("(c) AdamW FLOPs:")
print("   14 * [number of parameters] FLOPs per optimization step")
print("   For transformer: 14 * [2*V*H + L*(16*H² + 2*H) + H]")
print()

print(f"(d) Training time:")
print(f"   GPT-2 XL training for 400K steps at batch size 1024: {training_time_days:.1f} days")
print("   Justification: 13.2B FLOPs/token × 419B tokens ÷ 9.75 TFLOPs/s = 565 days")

# Detailed breakdown for verification
print(f"\nDetailed calculations:")
print(f"Total parameters: {total_params:,}")
print(f"Batch-independent memory: {b:.2f} GB")
print(f"Essential activations per token: {essential_activations_per_token:,}")
print(f"Activation coefficient: {a:.3f} GB per batch item")


SUMMARY OF RESULTS

(a) Algebraic expressions for AdamW memory usage:
   Parameters:     4 * [2*V*H + L*(16*H² + 2*H) + H] bytes
   Activations:    4 * B*T*(H + L*(7*H + A*T + 2*d_ff) + V) bytes (essential only)
   Gradients:      4 * [2*V*H + L*(16*H² + 2*H) + H] bytes
   Optimizer:      8 * [2*V*H + L*(16*H² + 2*H) + H] bytes
   TOTAL:         16*[2*V*H + L*(16*H² + 2*H) + H] + 4*B*T*(H + L*(7*H + A*T + 2*d_ff) + V)

(b) GPT-2 XL memory expression:
   Memory (GB) = 10.285 * batch_size + 34.0
   Maximum batch size for 80GB: 4

(c) AdamW FLOPs:
   14 * [number of parameters] FLOPs per optimization step
   For transformer: 14 * [2*V*H + L*(16*H² + 2*H) + H]

(d) Training time:
   GPT-2 XL training for 400K steps at batch size 1024: 6583.6 days
   Justification: 13.2B FLOPs/token × 419B tokens ÷ 9.75 TFLOPs/s = 565 days

Detailed calculations:
Total parameters: 2,127,057,600
Batch-independent memory: 34.03 GB
Essential activations per token: 2,511,057
Activation coefficient: 10.285 GB pe

In [26]:
# ===============================================
# (d) Model FLOPs Utilization (MFU) and Training Time
# ===============================================

def analyze_training_time():
    """
    Calculate training time for GPT-2 XL with given constraints.
    """
    print("TRAINING TIME ANALYSIS")
    print("=" * 25)
    
    # Training parameters
    num_steps = 400_000
    batch_size = 1024
    context_length = 1024
    mfu = 0.50  # 50% Model FLOPs Utilization
    a100_peak_flops = 19.5e12  # 19.5 teraFLOPs per second
    
    # GPT-2 XL configuration
    vocab_size = 50257
    num_layers = 48
    d_model = 1600
    num_heads = 25
    d_ff = 6400
    
    print("Configuration:")
    print(f"- Training steps: {num_steps:,}")
    print(f"- Batch size: {batch_size:,}")
    print(f"- Context length: {context_length:,}")
    print(f"- MFU: {mfu:.0%}")
    print(f"- A100 peak FLOPs: {a100_peak_flops/1e12:.1f} TFLOPs/s")
    print()
    
    # Calculate forward pass FLOPs per token (from previous analysis)
    # From earlier calculation: GPT-2 XL forward pass = 4.51T FLOPs for batch_size=1, seq_len=1024
    forward_flops_total = 4_513_336_524_800  # From earlier calculation
    forward_flops_per_token = forward_flops_total / (1 * context_length)
    
    print(f"Forward pass FLOPs per token: {forward_flops_per_token/1e6:.1f}M")
    print(f"Forward pass FLOPs per batch: {forward_flops_per_token * batch_size * context_length/1e12:.2f}T")
    
    # Backward pass has 2x forward pass FLOPs
    backward_flops_per_token = 2 * forward_flops_per_token
    total_flops_per_token = forward_flops_per_token + backward_flops_per_token  # 3x forward
    
    print(f"Backward pass FLOPs per token: {backward_flops_per_token/1e6:.1f}M")
    print(f"Total FLOPs per token (forward + backward): {total_flops_per_token/1e6:.1f}M")
    print()
    
    # Total FLOPs for entire training
    total_tokens = num_steps * batch_size * context_length
    total_training_flops = total_tokens * total_flops_per_token
    
    print("Training computation:")
    print(f"Total tokens processed: {total_tokens:,} = {total_tokens/1e9:.2f}B tokens")
    print(f"Total FLOPs: {total_training_flops/1e15:.2f} petaFLOPs")
    print()
    
    # Calculate training time
    effective_throughput = a100_peak_flops * mfu  # FLOPs/s at 50% MFU
    training_time_seconds = total_training_flops / effective_throughput
    training_time_hours = training_time_seconds / 3600
    training_time_days = training_time_hours / 24
    
    print("Training time calculation:")
    print(f"Effective throughput: {effective_throughput/1e12:.2f} TFLOPs/s")
    print(f"Training time: {training_time_seconds/1e6:.2f}M seconds")
    print(f"Training time: {training_time_hours:,.0f} hours") 
    print(f"Training time: {training_time_days:.1f} days")
    print()
    
    print("JUSTIFICATION:")
    print("- Forward pass: 4.4B FLOPs per token (from earlier analysis)")
    print("- Backward pass: 2× forward = 8.8B FLOPs per token")
    print("- Total: 13.2B FLOPs per token")
    print(f"- Training processes {total_tokens/1e9:.0f}B tokens")
    print(f"- Total computation: {total_training_flops/1e18:.1f} exaFLOPs")
    print(f"- At 50% MFU on A100: {effective_throughput/1e12:.1f} TFLOPs/s effective")
    print(f"- Time = {total_training_flops/1e18:.1f} EFLOPs ÷ {effective_throughput/1e12:.1f} TFLOPs/s = {training_time_days:.1f} days")
    
    return training_time_days

training_days = analyze_training_time()

TRAINING TIME ANALYSIS
Configuration:
- Training steps: 400,000
- Batch size: 1,024
- Context length: 1,024
- MFU: 50%
- A100 peak FLOPs: 19.5 TFLOPs/s

Forward pass FLOPs per token: 4407.6M
Forward pass FLOPs per batch: 4621.66T
Backward pass FLOPs per token: 8815.1M
Total FLOPs per token (forward + backward): 13222.7M

Training computation:
Total tokens processed: 419,430,400,000 = 419.43B tokens
Total FLOPs: 5545987.92 petaFLOPs

Training time calculation:
Effective throughput: 9.75 TFLOPs/s
Training time: 568.82M seconds
Training time: 158,005 hours
Training time: 6583.6 days

JUSTIFICATION:
- Forward pass: 4.4B FLOPs per token (from earlier analysis)
- Backward pass: 2× forward = 8.8B FLOPs per token
- Total: 13.2B FLOPs per token
- Training processes 419B tokens
- Total computation: 5546.0 exaFLOPs
- At 50% MFU on A100: 9.8 TFLOPs/s effective
- Time = 5546.0 EFLOPs ÷ 9.8 TFLOPs/s = 6583.6 days


In [27]:
# ===============================================
# (c) AdamW FLOPs Analysis
# ===============================================

def analyze_adamw_flops():
    """
    Analyze FLOPs required for one step of AdamW optimization.
    """
    print("ADAMW FLOPS ANALYSIS")
    print("=" * 25)
    print("AdamW performs the following operations per parameter:")
    print()
    
    print("For each parameter p with gradient g:")
    print("1. Weight decay: p ← p * (1 - lr * weight_decay)")
    print("   - 1 multiplication per parameter")
    print()
    
    print("2. Update first moment: exp_avg ← β₁ * exp_avg + (1-β₁) * g")
    print("   - 3 operations per parameter (2 multiplications + 1 addition)")
    print()
    
    print("3. Update second moment: exp_avg_sq ← β₂ * exp_avg_sq + (1-β₂) * g²")
    print("   - 4 operations per parameter (3 multiplications + 1 addition)")
    print()
    
    print("4. Bias correction:")
    print("   - bias_correction1 = 1 - β₁^step")
    print("   - bias_correction2 = 1 - β₂^step")
    print("   - These are scalars computed once per step (negligible)")
    print()
    
    print("5. Parameter update: p ← p - (lr / bias_correction1) * exp_avg / (√(exp_avg_sq / bias_correction2) + ε)")
    print("   - Square root: 1 operation per parameter")
    print("   - Division by bias_correction2: 1 operation per parameter")
    print("   - Addition of epsilon: 1 operation per parameter")
    print("   - Division by denominator: 1 operation per parameter")
    print("   - Scale by learning rate: 1 operation per parameter")
    print("   - Parameter update: 1 operation per parameter")
    print("   - Total: 6 operations per parameter")
    print()
    
    print("TOTAL PER PARAMETER:")
    print("- Weight decay: 1 FLOP")
    print("- First moment: 3 FLOPs")
    print("- Second moment: 4 FLOPs") 
    print("- Parameter update: 6 FLOPs")
    print("- Total per parameter: 14 FLOPs")
    print()
    
    print("ALGEBRAIC EXPRESSION:")
    print("If P is the total number of parameters:")
    print("AdamW FLOPs = 14 * P")
    print()
    print("For transformer with parameters P = 2*V*H + L*(16*H² + 2*H) + H:")
    print("AdamW FLOPs = 14 * [2*V*H + L*(16*H² + 2*H) + H]")

analyze_adamw_flops()

ADAMW FLOPS ANALYSIS
AdamW performs the following operations per parameter:

For each parameter p with gradient g:
1. Weight decay: p ← p * (1 - lr * weight_decay)
   - 1 multiplication per parameter

2. Update first moment: exp_avg ← β₁ * exp_avg + (1-β₁) * g
   - 3 operations per parameter (2 multiplications + 1 addition)

3. Update second moment: exp_avg_sq ← β₂ * exp_avg_sq + (1-β₂) * g²
   - 4 operations per parameter (3 multiplications + 1 addition)

4. Bias correction:
   - bias_correction1 = 1 - β₁^step
   - bias_correction2 = 1 - β₂^step
   - These are scalars computed once per step (negligible)

5. Parameter update: p ← p - (lr / bias_correction1) * exp_avg / (√(exp_avg_sq / bias_correction2) + ε)
   - Square root: 1 operation per parameter
   - Division by bias_correction2: 1 operation per parameter
   - Addition of epsilon: 1 operation per parameter
   - Division by denominator: 1 operation per parameter
   - Scale by learning rate: 1 operation per parameter
   - Parameter 

In [28]:
# ===============================================
# (b) GPT-2 XL Memory Analysis and Maximum Batch Size  
# ===============================================

def analyze_gpt2_xl_memory():
    """
    Instantiate memory analysis for GPT-2 XL and find maximum batch size for 80GB memory.
    """
    # GPT-2 XL configuration
    vocab_size = 50257
    context_length = 1024
    num_layers = 48
    d_model = 1600
    num_heads = 25
    d_ff = 4 * d_model  # 6400
    
    print("GPT-2 XL MEMORY ANALYSIS")
    print("=" * 30)
    print(f"Configuration:")
    print(f"- vocab_size = {vocab_size}")
    print(f"- context_length = {context_length}")
    print(f"- num_layers = {num_layers}")
    print(f"- d_model = {d_model}")
    print(f"- num_heads = {num_heads}")
    print(f"- d_ff = {d_ff}")
    print()
    
    # Calculate batch-size independent terms
    total_params = 2 * vocab_size * d_model + num_layers * (16 * d_model**2 + 2 * d_model) + d_model
    param_grad_opt_memory = 16 * total_params * 4  # Parameters + gradients + optimizer state (in bytes)
    
    # Calculate batch-size dependent coefficient  
    activation_per_batch = context_length * (2 * d_model + num_layers * (24 * d_model + 2 * num_heads * context_length) + 2 * vocab_size)
    activation_coeff = 4 * activation_per_batch  # Convert to bytes
    
    print("Memory breakdown for GPT-2 XL:")
    print(f"1. Parameters + Gradients + Optimizer State (batch-independent):")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Memory: {param_grad_opt_memory:,} bytes = {param_grad_opt_memory/1e9:.2f} GB")
    print()
    
    print(f"2. Activations (batch-dependent):")
    print(f"   Coefficient: {activation_coeff:,} bytes per batch item")
    print(f"   Memory = {activation_coeff} * batch_size bytes")
    print(f"   Memory = {activation_coeff/1e9:.6f} * batch_size GB")
    print()
    
    print("FINAL EXPRESSION:")
    a = activation_coeff / 1e9  # Convert to GB
    b = param_grad_opt_memory / 1e9  # Convert to GB
    print(f"Total Memory (GB) = {a:.6f} * batch_size + {b:.2f}")
    print(f"Total Memory (GB) = {a:.3f} * batch_size + {b:.1f}")
    
    # Find maximum batch size for 80GB
    memory_limit_gb = 80
    max_batch_size = (memory_limit_gb - b) / a
    max_batch_size_int = int(max_batch_size)
    
    print()
    print(f"For 80GB memory limit:")
    print(f"80 = {a:.6f} * batch_size + {b:.2f}")
    print(f"batch_size = (80 - {b:.2f}) / {a:.6f}")
    print(f"batch_size = {max_batch_size:.1f}")
    print(f"Maximum batch size: {max_batch_size_int}")
    
    # Verify the calculation
    total_memory_at_max = a * max_batch_size_int + b
    print(f"\nVerification: Memory at batch_size = {max_batch_size_int}:")
    print(f"Memory = {a:.3f} * {max_batch_size_int} + {b:.1f} = {total_memory_at_max:.1f} GB")
    
    return a, b, max_batch_size_int

a, b, max_batch_size = analyze_gpt2_xl_memory()

GPT-2 XL MEMORY ANALYSIS
Configuration:
- vocab_size = 50257
- context_length = 1024
- num_layers = 48
- d_model = 1600
- num_heads = 25
- d_ff = 6400

Memory breakdown for GPT-2 XL:
1. Parameters + Gradients + Optimizer State (batch-independent):
   Total parameters: 2,127,057,600
   Memory: 136,131,686,400 bytes = 136.13 GB

2. Activations (batch-dependent):
   Coefficient: 18,040,889,344 bytes per batch item
   Memory = 18040889344 * batch_size bytes
   Memory = 18.040889 * batch_size GB

FINAL EXPRESSION:
Total Memory (GB) = 18.040889 * batch_size + 136.13
Total Memory (GB) = 18.041 * batch_size + 136.1

For 80GB memory limit:
80 = 18.040889 * batch_size + 136.13
batch_size = (80 - 136.13) / 18.040889
batch_size = -3.1
Maximum batch size: -3

Verification: Memory at batch_size = -3:
Memory = 18.041 * -3 + 136.1 = 82.0 GB


In [29]:
# ===============================================
# (a) General Memory Analysis for AdamW Training
# ===============================================

def analyze_memory_general():
    """
    Provide algebraic expressions for AdamW memory usage.
    """
    print("ALGEBRAIC EXPRESSIONS FOR ADAMW MEMORY USAGE")
    print("=" * 50)
    print("Assumptions:")
    print("- float32 tensors (4 bytes per element)")
    print("- d_ff = 4 * d_model") 
    print("- Variables: batch_size (B), vocab_size (V), context_length (T),")
    print("             num_layers (L), d_model (H), num_heads (A)")
    print()
    
    print("1. PARAMETERS MEMORY:")
    print("   - Token embeddings: V * H")
    print("   - Per layer: 4*H² + 3*H*(4*H) + 2*H = 4*H² + 12*H² + 2*H = 16*H² + 2*H")
    print("   - All layers: L * (16*H² + 2*H)")
    print("   - Final layer norm: H")
    print("   - LM head: V * H")
    print("   - Total parameters: 2*V*H + L*(16*H² + 2*H) + H")
    print("   - Parameters memory: 4 * [2*V*H + L*(16*H² + 2*H) + H] bytes")
    print()
    
    print("2. ACTIVATIONS MEMORY:")
    print("   - Token embeddings: B * T * H")
    print("   - Per layer breakdown:")
    print("     * RMSNorm (2x): 2 * B * T * H") 
    print("     * Attention: 5*B*T*H + 2*B*A*T²")
    print("     * FFN: 4*B*T*(4*H) + B*T*H = 17*B*T*H")
    print("     * Per layer total: 2*B*T*H + 5*B*T*H + 2*B*A*T² + 17*B*T*H = 24*B*T*H + 2*B*A*T²")
    print("   - All layers: L * (24*B*T*H + 2*B*A*T²)")
    print("   - Final layer norm: B * T * H")
    print("   - Output embedding: B * T * V")
    print("   - Cross-entropy: B * T * V")
    print("   - Total activations: B*T*H + L*(24*B*T*H + 2*B*A*T²) + B*T*H + 2*B*T*V")
    print("   - Simplified: B*T*(2*H + L*(24*H + 2*A*T) + 2*V)")
    print("   - Activations memory: 4 * B*T*(2*H + L*(24*H + 2*A*T) + 2*V) bytes")
    print()
    
    print("3. GRADIENTS MEMORY:")
    print("   - Same as parameters: 4 * [2*V*H + L*(16*H² + 2*H) + H] bytes")
    print()
    
    print("4. OPTIMIZER STATE MEMORY (AdamW):")
    print("   - Two states per parameter (exp_avg + exp_avg_sq)")
    print("   - Optimizer state memory: 8 * [2*V*H + L*(16*H² + 2*H) + H] bytes")
    print()
    
    print("5. TOTAL MEMORY:")
    print("   - Parameters: 4 * [2*V*H + L*(16*H² + 2*H) + H]")
    print("   - Activations: 4 * B*T*(2*H + L*(24*H + 2*A*T) + 2*V)")
    print("   - Gradients: 4 * [2*V*H + L*(16*H² + 2*H) + H]")
    print("   - Optimizer: 8 * [2*V*H + L*(16*H² + 2*H) + H]")
    print("   - TOTAL = 16*[2*V*H + L*(16*H² + 2*H) + H] + 4*B*T*(2*H + L*(24*H + 2*A*T) + 2*V) bytes")

analyze_memory_general()

ALGEBRAIC EXPRESSIONS FOR ADAMW MEMORY USAGE
Assumptions:
- float32 tensors (4 bytes per element)
- d_ff = 4 * d_model
- Variables: batch_size (B), vocab_size (V), context_length (T),
             num_layers (L), d_model (H), num_heads (A)

1. PARAMETERS MEMORY:
   - Token embeddings: V * H
   - Per layer: 4*H² + 3*H*(4*H) + 2*H = 4*H² + 12*H² + 2*H = 16*H² + 2*H
   - All layers: L * (16*H² + 2*H)
   - Final layer norm: H
   - LM head: V * H
   - Total parameters: 2*V*H + L*(16*H² + 2*H) + H
   - Parameters memory: 4 * [2*V*H + L*(16*H² + 2*H) + H] bytes

2. ACTIVATIONS MEMORY:
   - Token embeddings: B * T * H
   - Per layer breakdown:
     * RMSNorm (2x): 2 * B * T * H
     * Attention: 5*B*T*H + 2*B*A*T²
     * FFN: 4*B*T*(4*H) + B*T*H = 17*B*T*H
     * Per layer total: 2*B*T*H + 5*B*T*H + 2*B*A*T² + 17*B*T*H = 24*B*T*H + 2*B*A*T²
   - All layers: L * (24*B*T*H + 2*B*A*T²)
   - Final layer norm: B * T * H
   - Output embedding: B * T * V
   - Cross-entropy: B * T * V
   - Total activ

In [30]:
# ===============================================
# AdamW Memory and Compute Analysis
# ===============================================

def compute_memory_usage(batch_size: int, vocab_size: int, context_length: int, 
                        num_layers: int, d_model: int, num_heads: int, d_ff: int = None):
    """
    Compute peak memory usage for AdamW training with float32 tensors.
    
    Args:
        batch_size: Training batch size
        vocab_size: Vocabulary size
        context_length: Sequence length
        num_layers: Number of transformer layers
        d_model: Model dimension
        num_heads: Number of attention heads
        d_ff: Feed-forward dimension (defaults to 4 * d_model)
    
    Returns:
        Dictionary with memory usage breakdown in bytes
    """
    if d_ff is None:
        d_ff = 4 * d_model
    
    # Each float32 takes 4 bytes
    BYTES_PER_FLOAT32 = 4
    
    # ========================================
    # PARAMETERS MEMORY
    # ========================================
    
    # Token embeddings: vocab_size * d_model
    token_emb_params = vocab_size * d_model
    
    # Per-layer parameters
    # - Multi-head attention: 4 * d_model * d_model (Q, K, V, O projections)
    # - FFN: 3 * d_model * d_ff (W1, W2, W3 for SwiGLU)
    # - RMSNorm: 2 * d_model (attention + FFN normalization)
    per_layer_params = 4 * d_model * d_model + 3 * d_model * d_ff + 2 * d_model
    all_layers_params = num_layers * per_layer_params
    
    # Final layer norm: d_model
    final_ln_params = d_model
    
    # LM head (output projection): vocab_size * d_model
    lm_head_params = vocab_size * d_model
    
    total_params = token_emb_params + all_layers_params + final_ln_params + lm_head_params
    parameters_memory = total_params * BYTES_PER_FLOAT32
    
    # ========================================
    # ACTIVATIONS MEMORY (per forward pass)
    # ========================================
    
    # Token embeddings output: batch_size * context_length * d_model
    token_emb_activations = batch_size * context_length * d_model
    
    # Per-layer activations
    per_layer_activations = 0
    
    # RMSNorm inputs and outputs (2 per layer): 2 * batch_size * context_length * d_model
    rms_norm_activations = 2 * batch_size * context_length * d_model
    
    # Multi-head self-attention activations:
    # - QKV projections: 3 * batch_size * context_length * d_model
    # - Q^T K matrix: batch_size * num_heads * context_length * context_length
    # - Softmax output: batch_size * num_heads * context_length * context_length
    # - Attention output: batch_size * context_length * d_model
    # - Output projection: batch_size * context_length * d_model
    attention_activations = (3 * batch_size * context_length * d_model + 
                           2 * batch_size * num_heads * context_length * context_length +
                           2 * batch_size * context_length * d_model)
    
    # Feed-forward activations:
    # - W1 output: batch_size * context_length * d_ff
    # - SiLU output: batch_size * context_length * d_ff  
    # - W3 output: batch_size * context_length * d_ff
    # - Element-wise product: batch_size * context_length * d_ff
    # - W2 output: batch_size * context_length * d_model
    ffn_activations = 4 * batch_size * context_length * d_ff + batch_size * context_length * d_model
    
    per_layer_activations = rms_norm_activations + attention_activations + ffn_activations
    all_layers_activations = num_layers * per_layer_activations
    
    # Final layer norm: batch_size * context_length * d_model
    final_ln_activations = batch_size * context_length * d_model
    
    # Output embedding: batch_size * context_length * vocab_size
    output_emb_activations = batch_size * context_length * vocab_size
    
    # Cross-entropy computation (logits, softmax, loss): batch_size * context_length * vocab_size
    cross_entropy_activations = batch_size * context_length * vocab_size
    
    total_activations = (token_emb_activations + all_layers_activations + 
                        final_ln_activations + output_emb_activations + cross_entropy_activations)
    activations_memory = total_activations * BYTES_PER_FLOAT32
    
    # ========================================
    # GRADIENTS MEMORY
    # ========================================
    
    # Gradients have same size as parameters
    gradients_memory = parameters_memory
    
    # ========================================
    # OPTIMIZER STATE MEMORY (AdamW)
    # ========================================
    
    # AdamW maintains two states per parameter:
    # - exp_avg (first moment): same size as parameters
    # - exp_avg_sq (second moment): same size as parameters
    optimizer_state_memory = 2 * parameters_memory
    
    # ========================================
    # TOTAL MEMORY
    # ========================================
    
    total_memory = parameters_memory + activations_memory + gradients_memory + optimizer_state_memory
    
    return {
        'parameters': parameters_memory,
        'activations': activations_memory, 
        'gradients': gradients_memory,
        'optimizer_state': optimizer_state_memory,
        'total': total_memory,
        'total_params': total_params
    }

def format_bytes(bytes_count: int) -> str:
    """Format bytes in human-readable units."""
    if bytes_count >= 1e9:
        return f"{bytes_count/1e9:.2f} GB"
    elif bytes_count >= 1e6:
        return f"{bytes_count/1e6:.2f} MB"
    elif bytes_count >= 1e3:
        return f"{bytes_count/1e3:.2f} KB"
    else:
        return f"{bytes_count} bytes"

def print_memory_breakdown(memory_dict: dict, title: str = "Memory Breakdown"):
    """Print formatted memory usage breakdown."""
    print(f"\n{title}:")
    print("=" * len(title + ":"))
    
    for key, value in memory_dict.items():
        if key != 'total_params':
            formatted_value = format_bytes(value)
            print(f"{key.replace('_', ' ').title()}: {formatted_value} ({value:,} bytes)")
    print()

In [31]:
def trainable_parameter_count_transformer_block(d_model: int, d_ff: int):
    return 4 * d_model * d_model + 3 * d_model * d_ff + 2 * d_model

def trainable_parameter_count_transformer_lm(vocab_size: int, context_length: int, num_layers: int,
                                             d_model: int, num_heads: int, d_ff: int):
    count = 0
    count += 2 * vocab_size * d_model  # Token embedding + Output linear layer
    count += d_model  # RMSNorm for output
    count += trainable_parameter_count_transformer_block(d_model, d_ff) * num_layers
    return count



gpt2_xl_param_count = trainable_parameter_count_transformer_lm(vocab_size=50257, context_length=1024,
                                                               num_layers=48, d_model=1600, num_heads=25, d_ff=6400)
    
print('GPT-2 XL parameter count: ' + str(gpt2_xl_param_count))
print('GPT-2 XL parameter size with float32: ' + str(gpt2_xl_param_count * 4))


GPT-2 XL parameter count: 2127057600
GPT-2 XL parameter size with float32: 8508230400


In [32]:
# ===============================================
# FLOP Calculation Functions
# ===============================================

def compute_attention_flops(batch_size: int, seq_len: int, d_model: int, num_heads: int):
    """
    Compute FLOPs for multi-head self-attention.
    
    Args:
        batch_size: Batch size
        seq_len: Sequence length
        d_model: Model dimension
        num_heads: Number of attention heads
    
    Returns:
        Dictionary with FLOP counts for each attention operation
    """
    head_dim = d_model // num_heads
    
    # QKV projections: 3 linear layers of [batch, seq_len, d_model] x [d_model, d_model]
    qkv_flops = 3 * 2 * batch_size * seq_len * d_model * d_model
    
    # Attention scores: Q x K^T for each head
    # Q, K shape per head: [batch, heads, seq_len, head_dim]
    attention_scores_flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim
    
    # Attention output: AttentionWeights x V for each head
    attention_output_flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim
    
    # Output projection: [batch, seq_len, d_model] x [d_model, d_model]
    output_proj_flops = 2 * batch_size * seq_len * d_model * d_model
    
    return {
        'qkv_projections': qkv_flops,
        'attention_scores': attention_scores_flops,
        'attention_output': attention_output_flops,
        'output_projection': output_proj_flops,
        'total_attention': qkv_flops + attention_scores_flops + attention_output_flops + output_proj_flops
    }

def compute_ffn_flops(batch_size: int, seq_len: int, d_model: int, d_ff: int):
    """
    Compute FLOPs for SwiGLU feed-forward network.
    
    Args:
        batch_size: Batch size
        seq_len: Sequence length
        d_model: Model dimension
        d_ff: Feed-forward dimension
    
    Returns:
        Dictionary with FLOP counts for each FFN operation
    """
    # W1 projection: [batch, seq_len, d_model] x [d_model, d_ff]
    w1_flops = 2 * batch_size * seq_len * d_model * d_ff
    
    # W3 projection: [batch, seq_len, d_model] x [d_model, d_ff]
    w3_flops = 2 * batch_size * seq_len * d_model * d_ff
    
    # W2 projection: [batch, seq_len, d_ff] x [d_ff, d_model]
    w2_flops = 2 * batch_size * seq_len * d_ff * d_model
    
    return {
        'w1_projection': w1_flops,
        'w3_projection': w3_flops, 
        'w2_projection': w2_flops,
        'total_ffn': w1_flops + w3_flops + w2_flops
    }

In [33]:
def compute_transformer_layer_flops(batch_size: int, seq_len: int, d_model: int, 
                                   num_heads: int, d_ff: int):
    """
    Compute FLOPs for a single transformer layer.
    
    Args:
        batch_size: Batch size
        seq_len: Sequence length
        d_model: Model dimension
        num_heads: Number of attention heads
        d_ff: Feed-forward dimension
    
    Returns:
        Dictionary with FLOP counts for the transformer layer
    """
    attention_flops = compute_attention_flops(batch_size, seq_len, d_model, num_heads)
    ffn_flops = compute_ffn_flops(batch_size, seq_len, d_model, d_ff)
    
    # RMSNorm operations are relatively cheap (just element-wise ops), so we focus on matrix multiplies
    
    return {
        'attention': attention_flops['total_attention'],
        'ffn': ffn_flops['total_ffn'],
        'total_layer': attention_flops['total_attention'] + ffn_flops['total_ffn'],
        'attention_breakdown': attention_flops,
        'ffn_breakdown': ffn_flops
    }

def compute_transformer_lm_flops(batch_size: int, seq_len: int, vocab_size: int,
                                d_model: int, num_layers: int, num_heads: int, 
                                d_ff: int):
    """
    Compute FLOPs for the complete transformer language model.
    
    Args:
        batch_size: Batch size
        seq_len: Sequence length
        vocab_size: Vocabulary size
        d_model: Model dimension
        num_layers: Number of transformer layers
        num_heads: Number of attention heads
        d_ff: Feed-forward dimension
    
    Returns:
        Dictionary with complete FLOP breakdown
    """
    # Token embedding is just indexing (no matrix multiply)
    embedding_flops = 0
    
    # Single layer FLOPs
    layer_flops = compute_transformer_layer_flops(batch_size, seq_len, d_model, num_heads, d_ff)
    
    # All layers
    all_layers_flops = num_layers * layer_flops['total_layer']
    
    # Language model head: [batch, seq_len, d_model] x [d_model, vocab_size]
    lm_head_flops = 2 * batch_size * seq_len * d_model * vocab_size
    
    # Total FLOPs
    total_flops = embedding_flops + all_layers_flops + lm_head_flops
    
    return {
        'embedding': embedding_flops,
        'per_layer': layer_flops['total_layer'],
        'all_layers': all_layers_flops,
        'lm_head': lm_head_flops,
        'total': total_flops,
        'layer_breakdown': layer_flops
    }

In [34]:
def format_flops(flops: int) -> str:
    """
    Format FLOP count in human-readable format.
    
    Args:
        flops: Number of FLOPs
    
    Returns:
        Formatted string with appropriate units
    """
    if flops >= 1e12:
        return f"{flops/1e12:.1f}T"
    elif flops >= 1e9:
        return f"{flops/1e9:.1f}B"
    elif flops >= 1e6:
        return f"{flops/1e6:.1f}M"
    elif flops >= 1e3:
        return f"{flops/1e3:.1f}K"
    else:
        return str(flops)

def print_flop_breakdown(flop_dict, title: str = "FLOP Breakdown"):
    """
    Print a nicely formatted FLOP breakdown.
    
    Args:
        flop_dict: Dictionary with FLOP counts
        title: Title for the breakdown
    """
    print(f"\n{title}:")
    print("=" * len(title + ":"))
    
    for key, value in flop_dict.items():
        if isinstance(value, dict):
            # Skip nested breakdown dictionaries in the main view
            if 'breakdown' in key:
                continue
            print(f"{key.replace('_', ' ').title()}:")
            for sub_key, sub_value in value.items():
                if isinstance(sub_value, (int, float)):
                    print(f"  {sub_key.replace('_', ' ').title()}: {format_flops(sub_value)} ({sub_value:,})")
        else:
            print(f"{key.replace('_', ' ').title()}: {format_flops(value)} ({value:,})")
    print()

In [35]:
def print_detailed_breakdown(flop_dict, title: str = "Detailed FLOP Breakdown"):
    """
    Print a detailed FLOP breakdown including all nested components.
    
    Args:
        flop_dict: Dictionary with FLOP counts
        title: Title for the breakdown
    """
    print(f"\n{title}:")
    print("=" * len(title + ":"))
    
    # Main totals first
    for key, value in flop_dict.items():
        if not isinstance(value, dict) and 'breakdown' not in key:
            print(f"{key.replace('_', ' ').title()}: {format_flops(value)} ({value:,})")
    
    # Then detailed breakdowns
    if 'layer_breakdown' in flop_dict:
        layer_breakdown = flop_dict['layer_breakdown']
        print(f"\nPer-Layer Breakdown:")
        print("-" * 20)
        
        if 'attention_breakdown' in layer_breakdown:
            print("  Attention Components:")
            for k, v in layer_breakdown['attention_breakdown'].items():
                if isinstance(v, (int, float)):
                    print(f"    {k.replace('_', ' ').title()}: {format_flops(v)} ({v:,})")
        
        if 'ffn_breakdown' in layer_breakdown:
            print("  FFN Components:")
            for k, v in layer_breakdown['ffn_breakdown'].items():
                if isinstance(v, (int, float)):
                    print(f"    {k.replace('_', ' ').title()}: {format_flops(v)} ({v:,})")
    
    print()

In [36]:
# ===============================================
# Example: GPT-2 XL FLOP Analysis
# ===============================================

# GPT-2 XL Configuration
gpt2_xl_config = {
    'batch_size': 1,
    'seq_len': 1024,
    'vocab_size': 50257,
    'd_model': 1600,
    'num_layers': 48,
    'num_heads': 25,
    'd_ff': 6400
}

# Compute FLOPs for GPT-2 XL
gpt2_xl_flops = compute_transformer_lm_flops(**gpt2_xl_config)

# Print breakdown
print_flop_breakdown(gpt2_xl_flops, "GPT-2 XL FLOP Breakdown")

print(f"Total FLOPs: {format_flops(gpt2_xl_flops['total'])}")
print(f"FLOPs per token: {format_flops(gpt2_xl_flops['total'] // gpt2_xl_config['seq_len'])}")

# Show detailed component breakdown
print_detailed_breakdown(gpt2_xl_flops, "GPT-2 XL Detailed Breakdown")


GPT-2 XL FLOP Breakdown:
Embedding: 0 (0)
Per Layer: 90.6B (90,596,966,400)
All Layers: 4.3T (4,348,654,387,200)
Lm Head: 164.7B (164,682,137,600)
Total: 4.5T (4,513,336,524,800)

Total FLOPs: 4.5T
FLOPs per token: 4.4B

GPT-2 XL Detailed Breakdown:
Embedding: 0 (0)
Per Layer: 90.6B (90,596,966,400)
All Layers: 4.3T (4,348,654,387,200)
Lm Head: 164.7B (164,682,137,600)
Total: 4.5T (4,513,336,524,800)

Per-Layer Breakdown:
--------------------
  Attention Components:
    Qkv Projections: 15.7B (15,728,640,000)
    Attention Scores: 3.4B (3,355,443,200)
    Attention Output: 3.4B (3,355,443,200)
    Output Projection: 5.2B (5,242,880,000)
    Total Attention: 27.7B (27,682,406,400)
  FFN Components:
    W1 Projection: 21.0B (20,971,520,000)
    W3 Projection: 21.0B (20,971,520,000)
    W2 Projection: 21.0B (20,971,520,000)
    Total Ffn: 62.9B (62,914,560,000)

