# 💬 Lecture 16: Efficient LLMs - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/16_efficient_llms/demo.ipynb)

## What You'll Learn
- KV cache and its memory implications
- Speculative decoding for faster inference
- Continuous batching
- LLM quantization (GPTQ, AWQ)

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import time

print('Ready for Efficient LLMs!')

## Part 1: The LLM Inference Challenge

In [None]:
def llm_inference_analysis():
    """
    Print a summary of the two LLM inference phases and a throughput table.

    The throughput numbers are constants hard-coded in this function; nothing
    is measured here. For every model the "Slowdown" column is the prefill
    rate divided by the decode rate.

    Returns:
        None. Everything is written to stdout.
    """
    print('üìä LLM INFERENCE CHARACTERISTICS')
    print('=' * 60)

    print('\nüîπ Prefill Phase (Process prompt):')
    print('   - Process all prompt tokens in parallel')
    print('   - Compute-bound (matrix multiplications)')
    print('   - Good GPU utilization')

    print('\nüîπ Decode Phase (Generate tokens):')
    print('   - Generate ONE token at a time')
    print('   - Memory-bound (load weights for single token)')
    print('   - Poor GPU utilization (~1-5%)')

    # Throughput table: parameters in billions, rates in tokens per second
    models = {
        'GPT-2': {'params_b': 1.5, 'prefill_tok_s': 5000, 'decode_tok_s': 50},
        'LLaMA-7B': {'params_b': 7, 'prefill_tok_s': 2000, 'decode_tok_s': 30},
        'LLaMA-70B': {'params_b': 70, 'prefill_tok_s': 500, 'decode_tok_s': 10},
    }

    print('\nüìä INFERENCE SPEED (A100, FP16)')
    header = f'{"Model":<15} {"Params":<10} {"Prefill":<15} {"Decode":<15} {"Slowdown":<10}'
    print(header)
    # Derive the rule from the header so the two can never drift apart
    print('-' * len(header))
    for name, info in models.items():
        slowdown = info['prefill_tok_s'] / info['decode_tok_s']
        # Attach each unit to its number BEFORE padding; padding the number
        # first leaves the unit floating several spaces away from its value
        # and pushes every following column out of line with the header.
        params = f'{info["params_b"]:.1f}B'
        prefill = f'{info["prefill_tok_s"]} tok/s'
        decode = f'{info["decode_tok_s"]} tok/s'
        ratio = f'{slowdown:.0f}x'
        print(f'{name:<15} {params:<10} {prefill:<15} {decode:<15} {ratio:<10}')

    print('\n‚ö†Ô∏è Decode is 50-100x slower than prefill!')

llm_inference_analysis()

## Part 2: KV Cache

In [None]:
def kv_cache_memory(model_params_b, n_layers, n_heads, d_head, seq_len, dtype_bytes=2,
                    batch_size=1, n_kv_heads=None):
    """
    Calculate KV cache memory in GB (1 GB = 1e9 bytes).

    KV cache stores key and value tensors for all past tokens
    to avoid recomputation during autoregressive generation.

    Memory = 2 (K+V) × n_layers × seq_len × kv_heads × d_head × dtype_bytes × batch_size

    Args:
        model_params_b: Not used in the calculation; kept so existing
            positional call sites keep working.
        n_layers: Number of transformer layers.
        n_heads: Number of attention heads. Used as the number of cached
            heads unless ``n_kv_heads`` is given.
        d_head: Dimension of each head.
        seq_len: Number of cached tokens per sequence.
        dtype_bytes: Bytes per cached element (default 2).
        batch_size: Number of sequences cached at once (default 1).
        n_kv_heads: Number of key/value heads when it differs from
            ``n_heads``; ``None`` means "same as ``n_heads``".

    Returns:
        Cache size in GB as a float.
    """
    kv_heads = n_heads if n_kv_heads is None else n_kv_heads
    # batch_size is multiplied in last so the default of 1 reproduces the
    # original single-sequence result exactly.
    cache_mem = 2 * n_layers * seq_len * kv_heads * d_head * dtype_bytes * batch_size
    return cache_mem / 1e9  # GB

# LLaMA architecture
# NOTE(review): the table below passes 'heads' as the number of cached heads,
# i.e. it assumes plain multi-head attention. Models that use grouped-query
# attention cache fewer heads, so their real KV cache would be smaller --
# presumably relevant for the 70B entry; confirm against the model card.
llama_configs = {
    'LLaMA-7B': {'layers': 32, 'heads': 32, 'd_head': 128},
    'LLaMA-13B': {'layers': 40, 'heads': 40, 'd_head': 128},
    'LLaMA-70B': {'layers': 80, 'heads': 64, 'd_head': 128},
}

print('üìä KV CACHE MEMORY ANALYSIS')
print('=' * 70)

seq_lengths = [1024, 4096, 16384, 32768, 131072]

# Header: a 12-character model column followed by one 10-character column
# per sequence length.
print(f'{"Model":<12}', end='')
for seq in seq_lengths:
    print(f'{seq:>10}', end='')
print('  (tokens)')
print('-' * 70)

for name, config in llama_configs.items():
    print(f'{name:<12}', end='')
    for seq in seq_lengths:
        mem = kv_cache_memory(0, config['layers'], config['heads'],
                              config['d_head'], seq)
        # 8 characters for the number + 'GB' = 10, matching the header column
        # width (the previous 9 + 2 made every row drift right of its header).
        print(f'{mem:>8.1f}GB', end='')
    print()

print('\n‚ö†Ô∏è KV cache can exceed model size at long contexts!')

In [None]:
# Plot how the KV cache grows with context length and with batch size
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ctx_ax, batch_ax = axes

# Left panel: cache size for a single sequence across context lengths
seq_range = np.linspace(1024, 131072, 100)
for name, config in llama_configs.items():
    mems = []
    for s in seq_range:
        mems.append(kv_cache_memory(0, config['layers'], config['heads'],
                                    config['d_head'], int(s)))
    ctx_ax.plot(seq_range/1000, mems, label=name, linewidth=2)

# Right panel: cache size at a fixed 4K context across batch sizes
batch_sizes = [1, 2, 4, 8, 16, 32]
seq_len = 4096
for name, config in llama_configs.items():
    # The per-sequence cost does not depend on the batch size, so compute it once
    per_sequence = kv_cache_memory(0, config['layers'], config['heads'],
                                   config['d_head'], seq_len)
    mems = [per_sequence * bs for bs in batch_sizes]
    batch_ax.plot(batch_sizes, mems, 'o-', label=name, linewidth=2)

# Decoration shared by both panels; the reference line is added after the
# model curves so the legend keeps the models first.
panels = (
    (ctx_ax, 'Sequence Length (K tokens)', 'KV Cache Memory vs Context Length'),
    (batch_ax, 'Batch Size', 'KV Cache Memory vs Batch Size (4K context)'),
)
for panel, x_label, panel_title in panels:
    panel.axhline(y=80, color='red', linestyle='--', label='A100 80GB')
    panel.set_xlabel(x_label)
    panel.set_ylabel('KV Cache Memory (GB)')
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Part 3: Speculative Decoding

In [None]:
def speculative_decoding_demo(target_time=100, draft_time=5, K=4, accept_rate=0.8):
    """
    Demonstrate speculative decoding concept.

    Key idea:
    1. Use small draft model to generate K tokens quickly
    2. Verify all K tokens with large model in ONE forward pass
    3. Accept verified tokens, reject and regenerate if needed

    The walkthrough text is a fixed example; only the "SPEEDUP ANALYSIS"
    section depends on the arguments.

    Args:
        target_time: Time of one target-model forward pass, in ms.
        draft_time: Time for the draft model to produce one token, in ms.
        K: Speculation depth (tokens drafted per round); must be >= 1.
        accept_rate: Fraction of drafted tokens accepted, in [0, 1].

    Returns:
        None. Everything is written to stdout.

    Raises:
        ValueError: If ``K`` is below 1 or ``accept_rate`` is outside [0, 1].
    """
    # Reject inputs that would otherwise surface as a ZeroDivisionError below
    if K < 1:
        raise ValueError(f'K must be at least 1, got {K}')
    if not 0 <= accept_rate <= 1:
        raise ValueError(f'accept_rate must be within [0, 1], got {accept_rate}')

    print('üìä SPECULATIVE DECODING')
    print('=' * 60)

    print('\nüîπ Standard Decoding (7B model):')
    print('   Token 1: [Forward 7B] ‚Üí "The"')
    print('   Token 2: [Forward 7B] ‚Üí "quick"')
    print('   Token 3: [Forward 7B] ‚Üí "brown"')
    print('   Token 4: [Forward 7B] ‚Üí "fox"')
    print('   Total: 4 forward passes')

    print('\nüîπ Speculative Decoding:')
    print('   Draft (68M): [Quick] ‚Üí "The quick brown fox" (4 tokens)')
    print('   Target (7B): [Verify] ‚Üí Accept "The quick brown" (3 accepted)')
    print('   Total: 1 draft + 1 target forward pass = 3 tokens!')

    # Standard: K tokens × target_time
    standard_time = K * target_time

    # Speculative: K × draft_time + 1 × target_time per round
    spec_time = K * draft_time + target_time
    # Expected accepted = K × accept_rate, plus one token from the target pass
    expected_tokens = K * accept_rate + 1

    print(f'\nüìä SPEEDUP ANALYSIS')
    print(f'Standard: {K} tokens in {standard_time}ms ({standard_time/K:.0f}ms/token)')
    print(f'Speculative: {expected_tokens:.1f} tokens in {spec_time}ms ({spec_time/expected_tokens:.0f}ms/token)')
    print(f'Speedup: {(standard_time/K) / (spec_time/expected_tokens):.1f}x')

speculative_decoding_demo()

In [None]:
# Simulate speculative decoding
def simulate_speculative(n_tokens, accept_rate, K=4, target_time=100, draft_time=5):
    """
    Simulate speculative decoding throughput.

    Each round costs ``K * draft_time + target_time`` and yields an expected
    ``K * accept_rate + 1`` tokens (the accepted draft tokens plus one token
    from the target pass). The fractional expectation is kept as-is;
    truncating it to an int makes nearby acceptance rates indistinguishable.

    Args:
        n_tokens: Number of tokens to generate.
        accept_rate: Fraction of drafted tokens accepted, in [0, 1].
        K: Speculation depth (tokens drafted per round); must be >= 0.
        target_time: Time of one target-model forward pass, in ms.
        draft_time: Time for the draft model to produce one token, in ms.

    Returns:
        Tuple ``(total_time, generated)``: total time in ms and the number
        of tokens generated (``n_tokens``, or 0 when ``n_tokens`` <= 0).

    Raises:
        ValueError: If ``K`` is negative or ``accept_rate`` is outside
            [0, 1]; a negative rate previously made the loop run forever.
    """
    import math  # local import: the notebook's import cell does not load math

    if K < 0:
        raise ValueError(f'K must be non-negative, got {K}')
    if not 0 <= accept_rate <= 1:
        raise ValueError(f'accept_rate must be within [0, 1], got {accept_rate}')
    if n_tokens <= 0:
        return 0, 0

    round_time = K * draft_time + target_time
    tokens_per_round = K * accept_rate + 1  # always >= 1 after validation
    # The last round may overshoot; it is still paid for in full
    rounds = math.ceil(n_tokens / tokens_per_round)
    return rounds * round_time, n_tokens

# Compare methods: standard decoding vs speculative decoding at several
# acceptance rates
n_tokens = 100
accept_rates = [0.5, 0.6, 0.7, 0.8, 0.9]

standard_time = n_tokens * 100  # 100ms per token

print('üìä SPECULATIVE DECODING SPEEDUP')
print('=' * 50)
print(f'{"Accept Rate":<15} {"Time (ms)":<15} {"Speedup":<15}')
print('-' * 45)

for rate in accept_rates:
    spec_time, _ = simulate_speculative(n_tokens, rate)
    speedup = standard_time / spec_time
    # Attach the 'x' before padding; padding the bare number first printed
    # the suffix a dozen spaces away from its value.
    speedup_cell = f'{speedup:.1f}x'
    print(f'{rate:<15.0%} {spec_time:<15.0f} {speedup_cell:<15}')

# Plot speedup over standard decoding as the acceptance rate varies
fig, ax = plt.subplots(figsize=(10, 6))

rates = np.linspace(0.3, 0.95, 20)
speedups = []
for r in rates:
    elapsed = simulate_speculative(n_tokens, r)[0]
    speedups.append(standard_time / elapsed)

# Curve first, then the baseline, so the drawing order is unchanged
ax.plot(rates * 100, speedups, 'o-', color='#3b82f6', linewidth=2)
ax.axhline(y=1, color='red', linestyle='--', label='Standard decoding')

ax.set_title('üìä Speculative Decoding Speedup vs Acceptance Rate')
ax.set_xlabel('Acceptance Rate (%)')
ax.set_ylabel('Speedup')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Part 4: LLM Quantization (GPTQ, AWQ)

In [None]:
def llm_quantization_comparison(model_params_b=70):
    """
    Compare LLM quantization methods.

    Prints one table with bits per weight, perplexity delta and relative
    speed for each method (constants hard-coded below), then the weight
    memory of a model with ``model_params_b`` billion parameters at each
    bit width.

    Args:
        model_params_b: Model size in billions of parameters (default 70).

    Returns:
        None. Everything is written to stdout.
    """
    methods = {
        'FP16': {'bits': 16, 'perplexity_delta': 0, 'speed': 1.0},
        'INT8 (LLM.int8)': {'bits': 8, 'perplexity_delta': 0.1, 'speed': 1.2},
        'GPTQ 4-bit': {'bits': 4, 'perplexity_delta': 0.3, 'speed': 2.5},
        'AWQ 4-bit': {'bits': 4, 'perplexity_delta': 0.2, 'speed': 2.8},
        'GGML Q4_K_M': {'bits': 4.5, 'perplexity_delta': 0.25, 'speed': 3.0},
    }

    print('üìä LLM QUANTIZATION METHODS')
    print('=' * 70)
    print(f'{"Method":<20} {"Bits":<10} {"PPL Delta":<15} {"Speed vs FP16":<15}')
    print('-' * 70)

    for name, info in methods.items():
        print(f'{name:<20} {info["bits"]:<10.1f} {info["perplexity_delta"]:>+12.2f}   {info["speed"]:>12.1f}x')

    # Weight memory: billions of parameters × bits per weight / 8 gives GB
    print(f'\nüìä LLaMA-{model_params_b:g}B MEMORY')
    print('-' * 40)

    for name, info in methods.items():
        mem = model_params_b * info['bits'] / 8
        print(f'{name:<20}: {mem:.0f} GB')

llm_quantization_comparison()

In [None]:
# Plot weight memory per precision and quality loss per bit width
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
mem_ax, ppl_ax = axes

# Left panel: grouped bars, one group per precision
methods = ['FP16', 'INT8', '4-bit GPTQ', '4-bit AWQ', '3-bit']
llama_70b_mem = [140, 70, 35, 35, 26.25]  # GB
llama_7b_mem = [14, 7, 3.5, 3.5, 2.625]   # GB

x = np.arange(len(methods))
width = 0.35

# The two bars of a group sit half a bar-width either side of the tick
bars1 = mem_ax.bar(x - width/2, llama_70b_mem, width, label='LLaMA-70B', color='#ef4444')
bars2 = mem_ax.bar(x + width/2, llama_7b_mem, width, label='LLaMA-7B', color='#3b82f6')

# Horizontal reference lines at two GPU memory capacities
gpu_limits = (
    (24, 'green', 'RTX 4090 (24GB)'),
    (80, 'orange', 'A100 (80GB)'),
)
for capacity_gb, line_color, gpu_label in gpu_limits:
    mem_ax.axhline(y=capacity_gb, color=line_color, linestyle='--', label=gpu_label)

mem_ax.set_title('Model Memory by Precision')
mem_ax.set_ylabel('Memory (GB)')
mem_ax.set_xticks(x)
mem_ax.set_xticklabels(methods, rotation=15)
mem_ax.legend()
mem_ax.grid(True, alpha=0.3, axis='y')

# Right panel: perplexity increase as bits per weight shrink
bits = [16, 8, 4, 3, 2]
ppl_delta = [0, 0.1, 0.3, 1.0, 3.0]  # Perplexity increase

ppl_ax.plot(bits, ppl_delta, 'o-', color='#ef4444', linewidth=2, markersize=10)
ppl_ax.set_title('Quality Degradation vs Quantization')
ppl_ax.set_xlabel('Bits per Weight')
ppl_ax.set_ylabel('Perplexity Increase')
ppl_ax.grid(True, alpha=0.3)
# Flip the x axis so precision decreases from left to right
ppl_ax.invert_xaxis()

# Label every point with a coarse quality bucket
for b, p in zip(bits, ppl_delta):
    if p < 0.5:
        quality = 'Good'
    elif p < 1.5:
        quality = 'OK'
    else:
        quality = 'Poor'
    ppl_ax.annotate(quality, (b, p), xytext=(10, 5), textcoords='offset points')

plt.tight_layout()
plt.show()

In [None]:
print('üéØ KEY TAKEAWAYS')
print('=' * 60)
print('\n1. LLM decode is memory-bound (50-100x slower than prefill)')
print('\n2. KV cache memory grows linearly with sequence length')
print('\n3. Speculative decoding: 2-3x faster with draft model')
print('\n4. 4-bit quantization: 4x less memory, minimal quality loss')
print('\n5. Combine techniques: quant + spec decode + paged attention')
print('\n6. vLLM, TensorRT-LLM implement all optimizations')
print('\n' + '=' * 60)
print('\nüìö Next: Efficient Diffusion Models!')