In [10]:
def trainable_parameter_count_transformer_block(d_model: int, d_ff: int):
    """
    Count the trainable parameters in one transformer block.

    Args:
        d_model: Model dimension
        d_ff: Feed-forward dimension

    Returns:
        Sum of four d_model x d_model matrices, three d_model x d_ff
        matrices, and two vectors of length d_model.
    """
    # Four square matrices (presumably the Q, K, V and output projections).
    square_matrix_params = 4 * d_model * d_model
    # Three d_model x d_ff matrices (presumably the W1/W2/W3 feed-forward weights).
    ffn_matrix_params = 3 * d_model * d_ff
    # Two length-d_model vectors (presumably the two per-block RMSNorm gains).
    vector_params = 2 * d_model
    return square_matrix_params + ffn_matrix_params + vector_params

def trainable_parameter_count_transformer_lm(
    vocab_size: int,
    context_length: int,
    num_layers: int,
    d_model: int,
    num_heads: int,
    d_ff: int,
):
    """
    Count the trainable parameters of the full transformer language model.

    Args:
        vocab_size: Vocabulary size
        context_length: Accepted for interface completeness; not used in the count
        num_layers: Number of transformer blocks
        d_model: Model dimension
        num_heads: Accepted for interface completeness; not used in the count
        d_ff: Feed-forward dimension

    Returns:
        Total parameter count: token embedding + output linear layer,
        the output RMSNorm, and num_layers transformer blocks.
    """
    # Token embedding and output linear layer are each vocab_size x d_model.
    embedding_and_head = 2 * vocab_size * d_model
    # RMSNorm applied before the output layer has one gain per model dimension.
    output_norm = d_model
    per_block = trainable_parameter_count_transformer_block(d_model, d_ff)
    return embedding_and_head + output_norm + per_block * num_layers



# Parameter count for the GPT-2 XL sized configuration.
gpt2_xl_param_count = trainable_parameter_count_transformer_lm(
    vocab_size=50257,
    context_length=1024,
    num_layers=48,
    d_model=1600,
    num_heads=25,
    d_ff=6400,
)

print(f'GPT-2 XL parameter count: {gpt2_xl_param_count}')
# Each float32 parameter occupies 4 bytes, so this figure is in bytes.
print(f'GPT-2 XL parameter size with float32: {gpt2_xl_param_count * 4}')


GPT-2 XL parameter count: 2127057600
GPT-2 XL parameter size with float32: 8508230400


In [11]:
# ===============================================
# FLOP Calculation Functions
# ===============================================

def compute_attention_flops(batch_size: int, seq_len: int, d_model: int, num_heads: int):
    """
    Compute matrix-multiply FLOPs for multi-head self-attention.

    Only the four matrix multiplications are counted (2 FLOPs per
    multiply-accumulate); softmax, scaling and masking are not included.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        d_model: Model dimension
        num_heads: Number of attention heads

    Returns:
        Dictionary with FLOP counts for each attention operation, in the
        order: 'qkv_projections', 'attention_scores', 'attention_output',
        'output_projection', 'total_attention'.

    Raises:
        ValueError: If d_model is not divisible by num_heads. Without this
            check the floor division below would silently drop the
            remainder and undercount the per-head matmuls.
        ZeroDivisionError: If num_heads is 0.
    """
    if d_model % num_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )
    head_dim = d_model // num_heads

    # QKV projections: 3 linear layers of [batch, seq_len, d_model] x [d_model, d_model]
    qkv_flops = 3 * 2 * batch_size * seq_len * d_model * d_model

    # Attention scores: Q x K^T for each head
    # Q, K shape per head: [batch, heads, seq_len, head_dim]
    attention_scores_flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim

    # Attention output: AttentionWeights x V for each head (same cost as the scores)
    attention_output_flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim

    # Output projection: [batch, seq_len, d_model] x [d_model, d_model]
    output_proj_flops = 2 * batch_size * seq_len * d_model * d_model

    return {
        'qkv_projections': qkv_flops,
        'attention_scores': attention_scores_flops,
        'attention_output': attention_output_flops,
        'output_projection': output_proj_flops,
        'total_attention': qkv_flops + attention_scores_flops + attention_output_flops + output_proj_flops
    }

def compute_ffn_flops(batch_size: int, seq_len: int, d_model: int, d_ff: int):
    """
    Compute matrix-multiply FLOPs for the SwiGLU feed-forward network.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        d_model: Model dimension
        d_ff: Feed-forward dimension

    Returns:
        Dictionary with FLOP counts for the W1, W3 and W2 projections
        (in that order) followed by their sum under 'total_ffn'.
    """
    # Every projection is a [batch, seq_len, in] x [in, out] matmul costing
    # 2 * batch * seq_len * in * out FLOPs; share the common leading factor.
    two_times_tokens = 2 * batch_size * seq_len

    breakdown = {}
    # W1 and W3 both map d_model -> d_ff.
    breakdown['w1_projection'] = two_times_tokens * d_model * d_ff
    breakdown['w3_projection'] = two_times_tokens * d_model * d_ff
    # W2 maps d_ff -> d_model.
    breakdown['w2_projection'] = two_times_tokens * d_ff * d_model
    breakdown['total_ffn'] = (
        breakdown['w1_projection']
        + breakdown['w3_projection']
        + breakdown['w2_projection']
    )
    return breakdown

In [12]:
def compute_transformer_layer_flops(
    batch_size: int,
    seq_len: int,
    d_model: int,
    num_heads: int,
    d_ff: int,
):
    """
    Compute FLOPs for a single transformer layer.

    Only the attention and feed-forward matrix multiplies are counted;
    RMSNorm is element-wise and comparatively cheap, so it is left out.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        d_model: Model dimension
        num_heads: Number of attention heads
        d_ff: Feed-forward dimension

    Returns:
        Dictionary with the attention total, the FFN total, their sum
        ('total_layer'), and the two per-component breakdown dictionaries.
    """
    attn_parts = compute_attention_flops(batch_size, seq_len, d_model, num_heads)
    ffn_parts = compute_ffn_flops(batch_size, seq_len, d_model, d_ff)

    attn_total = attn_parts['total_attention']
    ffn_total = ffn_parts['total_ffn']

    summary = {
        'attention': attn_total,
        'ffn': ffn_total,
        'total_layer': attn_total + ffn_total,
    }
    # Keep the detailed dictionaries last so summary figures are iterated first.
    summary['attention_breakdown'] = attn_parts
    summary['ffn_breakdown'] = ffn_parts
    return summary

def compute_transformer_lm_flops(
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    d_model: int,
    num_layers: int,
    num_heads: int,
    d_ff: int,
):
    """
    Compute FLOPs for the complete transformer language model.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        vocab_size: Vocabulary size
        d_model: Model dimension
        num_layers: Number of transformer layers
        num_heads: Number of attention heads
        d_ff: Feed-forward dimension

    Returns:
        Dictionary with the embedding, per-layer, all-layers, LM-head and
        total FLOP counts, plus the single-layer breakdown under
        'layer_breakdown'.
    """
    # One layer's cost; every layer has the same shape, so scale by num_layers.
    single_layer = compute_transformer_layer_flops(batch_size, seq_len, d_model, num_heads, d_ff)
    per_layer_total = single_layer['total_layer']
    stacked_layers_total = num_layers * per_layer_total

    # Token embedding is a table lookup, not a matrix multiply.
    lookup_flops = 0

    # Language model head: [batch, seq_len, d_model] x [d_model, vocab_size]
    head_flops = 2 * batch_size * seq_len * d_model * vocab_size

    return {
        'embedding': lookup_flops,
        'per_layer': per_layer_total,
        'all_layers': stacked_layers_total,
        'lm_head': head_flops,
        'total': lookup_flops + stacked_layers_total + head_flops,
        'layer_breakdown': single_layer
    }

In [13]:
def format_flops(flops: int) -> str:
    """
    Format FLOP count in human-readable format.

    Values are scaled to the largest unit they reach (T = 1e12, B = 1e9,
    M = 1e6, K = 1e3) and shown with one decimal place. A value that would
    round up to "1000.0" in its unit is promoted to the next larger unit,
    so 999_999_999_999 is rendered as "1.0T" rather than "1000.0B".

    Args:
        flops: Number of FLOPs

    Returns:
        Formatted string with appropriate units; values below 1e3 are
        returned via str() with no unit.
    """
    # Ordered largest to smallest so the first match is the best-fitting unit.
    units = ((1e12, "T"), (1e9, "B"), (1e6, "M"), (1e3, "K"))

    for position, (scale, suffix) in enumerate(units):
        if flops >= scale:
            text = f"{flops / scale:.1f}"
            # Just below the next threshold, one-decimal rounding yields
            # "1000.0"; express that in the next larger unit instead.
            if text == "1000.0" and position > 0:
                bigger_scale, bigger_suffix = units[position - 1]
                return f"{flops / bigger_scale:.1f}{bigger_suffix}"
            return f"{text}{suffix}"

    return str(flops)

def print_flop_breakdown(flop_dict, title: str = "FLOP Breakdown"):
    """
    Print a summary view of a FLOP dictionary.

    Scalar entries are printed as "Label: <short> (<exact>)". Nested
    dictionaries whose key contains 'breakdown' are omitted; any other
    nested dictionary is printed as an indented group of its numeric items.

    Args:
        flop_dict: Dictionary with FLOP counts
        title: Title for the breakdown
    """
    def as_label(raw_key):
        # 'total_layer' -> 'Total Layer'
        return raw_key.replace('_', ' ').title()

    print(f"\n{title}:")
    print("=" * len(title + ":"))

    for name, entry in flop_dict.items():
        # Scalars: one line with both the compact and the exact figure.
        if not isinstance(entry, dict):
            print(f"{as_label(name)}: {format_flops(entry)} ({entry:,})")
            continue

        # Detailed breakdown dictionaries are left out of this summary view.
        if 'breakdown' in name:
            continue

        print(f"{as_label(name)}:")
        for inner_name, inner_value in entry.items():
            if not isinstance(inner_value, (int, float)):
                continue
            print(f"  {as_label(inner_name)}: {format_flops(inner_value)} ({inner_value:,})")

    print()

In [14]:
def print_detailed_breakdown(flop_dict, title: str = "Detailed FLOP Breakdown"):
    """
    Print the top-level totals followed by the per-layer component breakdown.

    Args:
        flop_dict: Dictionary with FLOP counts; if it has a
            'layer_breakdown' entry, that entry's 'attention_breakdown'
            and 'ffn_breakdown' dictionaries (when present) are listed
        title: Title for the breakdown
    """
    print(f"\n{title}:")
    print("=" * len(title + ":"))

    # Main totals first: every scalar entry that is not itself a breakdown.
    for name, amount in flop_dict.items():
        if isinstance(amount, dict) or 'breakdown' in name:
            continue
        print(f"{name.replace('_', ' ').title()}: {format_flops(amount)} ({amount:,})")

    # Then the per-layer components, attention before FFN.
    if 'layer_breakdown' in flop_dict:
        per_layer = flop_dict['layer_breakdown']
        print("\nPer-Layer Breakdown:")
        print("-" * 20)

        sections = (
            ('attention_breakdown', "  Attention Components:"),
            ('ffn_breakdown', "  FFN Components:"),
        )
        for section_key, section_heading in sections:
            if section_key not in per_layer:
                continue
            print(section_heading)
            for part, cost in per_layer[section_key].items():
                if not isinstance(cost, (int, float)):
                    continue
                print(f"    {part.replace('_', ' ').title()}: {format_flops(cost)} ({cost:,})")

    print()

In [15]:
# ===============================================
# Example: GPT-2 XL FLOP Analysis
# ===============================================

# GPT-2 XL Configuration
gpt2_xl_config = {
    'batch_size': 1,
    'seq_len': 1024,
    'vocab_size': 50257,
    'd_model': 1600,
    'num_layers': 48,
    'num_heads': 25,
    'd_ff': 6400
}

# Compute FLOPs for GPT-2 XL
gpt2_xl_flops = compute_transformer_lm_flops(**gpt2_xl_config)

# Print breakdown
print_flop_breakdown(gpt2_xl_flops, "GPT-2 XL FLOP Breakdown")

# The total covers batch_size * seq_len tokens, so the per-token figure must
# divide by both; dividing by seq_len alone is only right when batch_size is 1.
gpt2_xl_token_count = gpt2_xl_config['batch_size'] * gpt2_xl_config['seq_len']

print(f"Total FLOPs: {format_flops(gpt2_xl_flops['total'])}")
print(f"FLOPs per token: {format_flops(gpt2_xl_flops['total'] // gpt2_xl_token_count)}")

# Show detailed component breakdown
print_detailed_breakdown(gpt2_xl_flops, "GPT-2 XL Detailed Breakdown")


GPT-2 XL FLOP Breakdown:
========================
Embedding: 0 (0)
Per Layer: 90.6B (90,596,966,400)
All Layers: 4.3T (4,348,654,387,200)
Lm Head: 164.7B (164,682,137,600)
Total: 4.5T (4,513,336,524,800)

Total FLOPs: 4.5T
FLOPs per token: 4.4B

GPT-2 XL Detailed Breakdown:
============================
Embedding: 0 (0)
Per Layer: 90.6B (90,596,966,400)
All Layers: 4.3T (4,348,654,387,200)
Lm Head: 164.7B (164,682,137,600)
Total: 4.5T (4,513,336,524,800)

Per-Layer Breakdown:
--------------------
  Attention Components:
    Qkv Projections: 15.7B (15,728,640,000)
    Attention Scores: 3.4B (3,355,443,200)
    Attention Output: 3.4B (3,355,443,200)
    Output Projection: 5.2B (5,242,880,000)
    Total Attention: 27.7B (27,682,406,400)
  FFN Components:
    W1 Projection: 21.0B (20,971,520,000)
    W3 Projection: 21.0B (20,971,520,000)
    W2 Projection: 21.0B (20,971,520,000)
    Total Ffn: 62.9B (62,914,560,000)

