In [None]:
from cs336_basics.model import Transformer

In [None]:
# Hyperparameters shared by every GPT-2 size; each CONFIGS entry is merged on top.
# NOTE(review): d_ff=6400 equals 4 * d_model only for XL (d_model=1600); the
# large/medium/small entries below reuse it unchanged -- confirm that is intended.
GPT2 = dict(vocab_size=50257, context_length=1024, d_ff=6400)

# Per-size overrides, listed in the order they are analyzed.
CONFIGS = {
    "XL": dict(num_layers=48, d_model=1600, num_heads=25),
    # XL architecture with a 16x longer context window; XL's fields are merged in below.
    "XL-CONT": dict(context_length=16384),
    "large": dict(num_layers=36, d_model=1280, num_heads=20),
    "medium": dict(num_layers=24, d_model=1024, num_heads=16),
    "small": dict(num_layers=12, d_model=768, num_heads=12),
}
CONFIGS["XL-CONT"] |= CONFIGS["XL"]

In [None]:
def _count_params(vocab_size, d_ff, num_layers, d_model):
    """Return the closed-form trainable-parameter count of the Transformer LM.

    Assumed layout (checked against the real model by ``analyze_model`` when
    ``verify`` is true): separate token embedding and LM head, bias-free
    linear layers, two norms per block plus one final norm each holding a single
    ``d_model``-sized weight, and an FFN built from three ``d_model x d_ff``
    matrices.
    """
    emb_params = vocab_size * d_model
    norm_params = d_model
    attn_params = 4 * d_model * d_model  # Q, K, V and output projections
    ffn_params = 3 * d_model * d_ff
    layer_params = 2 * norm_params + attn_params + ffn_params
    lm_head_params = d_model * vocab_size
    return emb_params + num_layers * layer_params + norm_params + lm_head_params


def _forward_flops(vocab_size, context_length, d_ff, num_layers, d_model):
    """Return matmul FLOPs for one forward pass over a single ``context_length`` sequence.

    An (m x k) @ (k x n) matmul is counted as 2*m*k*n FLOPs (a multiply and an
    add per multiply-accumulate). Norms, softmax, RoPE and the FFN nonlinearity
    are not counted.
    """
    qkv = 2 * context_length * d_model * 3 * d_model
    # Q @ K^T plus scores @ V, summed over heads (assumes num_heads * head_dim == d_model).
    scaled_dp_attn = 2 * 2 * context_length * context_length * d_model
    out = 2 * context_length * d_model * d_model
    attn = qkv + scaled_dp_attn + out  # 8 * T * d_model^2 + 4 * T^2 * d_model
    ffn = 2 * 3 * context_length * d_model * d_ff  # three d_model x d_ff matmuls
    layer = attn + ffn
    all_layers = num_layers * layer
    # The LM head is a (T x d_model) @ (d_model x vocab) matmul, so it needs the
    # same factor of 2 as every matmul above (previously omitted).
    lm_head = 2 * context_length * d_model * vocab_size
    return {
        "qkv_flops": qkv,
        "scaled_dp_attn_flops": scaled_dp_attn,
        "out_flops": out,
        "attn_flops": attn,
        "ffn_flops": ffn,
        "layer_flops": layer,
        "all_layers_flops": all_layers,
        "lm_head_flops": lm_head,
        "transformer_flops": all_layers + lm_head,
    }


def analyze_model(vocab_size, context_length, d_ff, num_layers, d_model, num_heads, verify=True):
    """Print parameter and forward-pass FLOP accounting for one Transformer configuration.

    Args:
        vocab_size, context_length, d_ff, num_layers, d_model, num_heads:
            Model hyperparameters; forwarded to ``Transformer`` when verifying.
        verify: If True (default), instantiate ``Transformer`` and assert that its
            parameter count equals the closed-form count. This allocates the full
            model; pass False to skip it for a quick estimate.

    Returns:
        dict with ``num_params`` plus the per-component FLOP counts that are printed.

    Raises:
        AssertionError: if ``verify`` is true and the closed-form count disagrees
            with the instantiated model.
    """
    num_params = _count_params(vocab_size, d_ff, num_layers, d_model)

    if verify:
        model = Transformer(rope_theta=10000, vocab_size=vocab_size, context_length=context_length, d_ff=d_ff,
                            num_layers=num_layers, d_model=d_model, num_heads=num_heads)
        actual_params = sum(p.numel() for p in model.parameters())
        assert num_params == actual_params, (
            f"closed-form count {num_params:,} != model count {actual_params:,}")

    print(f"Number of parameters = {num_params:,}, memory = {num_params * 4:,} bytes at 4 bytes/param "
          "(not counting RoPE buffers and other overhead)")
    print(f"Formula estimate of FLOPs: 2 * context_length * num_params = {2 * context_length * num_params:,}")

    flops = _forward_flops(vocab_size, context_length, d_ff, num_layers, d_model)
    layer_flops = flops["layer_flops"]
    transformer_flops = flops["transformer_flops"]

    # Shares of a single layer.
    attn_perc = flops["attn_flops"] / layer_flops * 100
    qkv_perc = flops["qkv_flops"] / layer_flops * 100
    sdpa_perc = flops["scaled_dp_attn_flops"] / layer_flops * 100
    out_perc = flops["out_flops"] / layer_flops * 100
    ffn_perc = flops["ffn_flops"] / layer_flops * 100

    # Shares of the whole forward pass.
    all_layers_perc = flops["all_layers_flops"] / transformer_flops * 100
    lm_head_perc = flops["lm_head_flops"] / transformer_flops * 100

    print("\nManual estimate of FLOPs")
    print(f"Transformer: {transformer_flops:,} ({num_layers:,} layers: "
          f"{all_layers_perc:.1f}%, lm_head: {lm_head_perc:.1f}%)")
    print(f"  Layer: {layer_flops:,}")
    print(f"    Attention: {attn_perc:.1f}% (qkv: {qkv_perc:.1f}%, "
          f"scaled_dp_attn: {sdpa_perc:.1f}%, out: {out_perc:.1f}%)")
    print(f"    FFN: {ffn_perc:.1f}%")

    return {"num_params": num_params, **flops}

In [None]:
# Analyze every size: each per-size override dict is layered over the shared GPT-2 defaults.
for config_name, overrides in CONFIGS.items():
    banner = f"\n======== {config_name} ========\n"
    print(banner)
    merged = {**GPT2, **overrides}  # overrides win on shared keys (e.g. context_length)
    analyze_model(**merged)