# Lecture 16: Efficient Large Language Models

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/efficientml-course/efficientml_course/16_efficient_llms/demo.ipynb)

KV cache, speculative decoding, and LLM serving optimizations.


In [None]:
!pip install torch -q
import torch

# KV Cache Memory Calculator
def kv_cache_memory_gb(model_params, layers, heads, d_head, seq_len, dtype_bytes=2,
                       kv_heads=None, batch_size=1):
    """Return the KV cache size in gigabytes (1 GB = 1e9 bytes).

    KV cache size =
        2 * layers * batch_size * seq_len * kv_heads * d_head * dtype_bytes
    (2 for K and V)

    Args:
        model_params: Unused by the calculation; kept so existing positional
            callers keep working.
        layers: Number of transformer layers.
        heads: Number of attention heads per layer.
        d_head: Dimension of each attention head.
        seq_len: Number of token positions held in the cache.
        dtype_bytes: Bytes per cached element (default 2, i.e. FP16).
        kv_heads: Number of key/value heads per layer. ``None`` (default)
            means one K/V head per attention head, which reproduces the
            original ``heads * d_head`` formula. Pass a smaller value for
            grouped-query or multi-query attention.
        batch_size: Number of sequences cached simultaneously (default 1).

    Returns:
        The cache size in GB as a float.
    """
    if kv_heads is None:
        # Multi-head attention: every attention head stores its own K and V.
        kv_heads = heads

    # Width of the K (or V) vector stored per token, per layer.
    kv_width = kv_heads * d_head
    kv_size = 2 * layers * batch_size * seq_len * kv_width * dtype_bytes
    return kv_size / 1e9

# LLaMA model configs (attention shape only). The table below passes 'heads'
# straight through, so every attention head is counted as storing its own K/V.
configs = {
    'LLaMA-7B': {'layers': 32, 'heads': 32, 'd_head': 128},
    'LLaMA-13B': {'layers': 40, 'heads': 40, 'd_head': 128},
    'LLaMA-70B': {'layers': 80, 'heads': 64, 'd_head': 128},
}

# Context lengths (in tokens) shown as table columns; header and rows are both
# derived from this list so they cannot drift apart.
context_lengths = [2048, 8192, 32768]

print("KV Cache Memory (FP16)")
print("=" * 55)
header = [f"{'Model':<12}"] + [f"{str(n // 1024) + 'K ctx':>8}" for n in context_lengths]
print(" | ".join(header))
print("-" * 55)

for name, cfg in configs.items():
    row = [f"{name:<12}"]
    for seq_len in context_lengths:
        # First argument (model_params) is not used by the calculator.
        mem = kv_cache_memory_gb(0, cfg['layers'], cfg['heads'], cfg['d_head'], seq_len)
        row.append(f"{mem:>7.1f}G")
    print(" | ".join(row))

print("\n🎯 KV cache can exceed model weights for long contexts!")


In [None]:
# Speculative Decoding Demo
import random

def speculative_decode_demo(total_tokens=100, k=5, accept_rate=0.7,
                            draft_time=0.01, main_time=0.1):
    """Simulate speculative decoding and print its speedup over standard decoding.

    Standard decoding is modelled as one main-model pass per token. Each
    speculative round costs ``k`` draft-model passes plus one main-model
    verification pass and yields ``int(k * accept_rate)`` tokens (at least 1).
    This is a deterministic simplification: every round accepts the same
    truncated number of tokens rather than sampling acceptance per token.

    Args:
        total_tokens: Number of tokens to generate (must be positive).
        k: Number of tokens drafted per round (must be at least 1).
        accept_rate: Fraction of drafted tokens accepted, in [0, 1].
        draft_time: Seconds per draft-model token (default 10ms).
        main_time: Seconds per main-model forward pass (default 100ms).

    Returns:
        None. The results are printed.

    Raises:
        ValueError: If an argument is outside the ranges above; these would
            otherwise cause a division by zero or a meaningless timing.
    """
    if total_tokens <= 0:
        raise ValueError("total_tokens must be positive")
    if k < 1:
        raise ValueError("k must be at least 1")
    if not 0 <= accept_rate <= 1:
        raise ValueError("accept_rate must be between 0 and 1")
    if draft_time < 0 or main_time <= 0:
        raise ValueError("draft_time must be >= 0 and main_time must be > 0")

    # Standard decoding: one main-model pass per generated token.
    standard_time = total_tokens * main_time

    # Tokens gained per speculative round; the floor of 1 guarantees progress
    # even when the truncated acceptance count is 0.
    tokens_per_round = max(int(k * accept_rate), 1)

    spec_time = 0
    generated = 0
    while generated < total_tokens:
        # Draft k tokens
        spec_time += k * draft_time
        # Verify with main model (1 forward pass)
        spec_time += main_time
        generated += tokens_per_round

    speedup = standard_time / spec_time

    print("Speculative Decoding Simulation")
    print(f"  Standard: {standard_time:.1f}s for {total_tokens} tokens")
    print(f"  Speculative: {spec_time:.1f}s for {total_tokens} tokens")
    print(f"  Speedup: {speedup:.1f}x")
    print("\n🎯 Same quality, much faster generation!")

speculative_decode_demo()
