In [1]:
from transformers import AutoTokenizer

# Llama 3.2 tokenizer (needs a Hugging Face login to download)
model_id = "meta-llama/Llama-3.2-3B"
tokenizer = AutoTokenizer.from_pretrained(model_id)

text = "i was also speaking in longer blocks of words and that makes it difficult to correct anything to one hundred percent correct."

# Token IDs for the text alone: add_special_tokens=False keeps BOS/EOS and
# other special tokens out of the count.
tokens = tokenizer.encode(text, add_special_tokens=False)
token_count = len(tokens)
print(f"Token count: {token_count}")

# To inspect the individual token strings, evaluate:
#     tokenizer.convert_ids_to_tokens(tokens)

  from .autonotebook import tqdm as notebook_tqdm


Token count: 23


In [3]:
def calculate_kv_memory_usage(
    num_beams=2000,
    max_tokens=50,
    precision="fp16",
    num_layers=28,
    num_kv_heads=8,
    head_dim=128,
):
    """
    Estimate KV-cache VRAM for beam search (defaults describe Llama 3.2 3B).

    Architecture defaults (Llama 3.2 3B):
    - Layers: 28
    - KV Heads: 8 (Grouped Query Attention)
    - Head Dim: 128

    Args:
        num_beams: Number of sequences held in the cache simultaneously.
        max_tokens: Number of cached positions per sequence.
        precision: Cache storage format: "fp16", "bf16", "fp32" or "int8"
            (case-insensitive).
        num_layers: Transformer layers; override to size a different model.
        num_kv_heads: Key/value heads per layer.
        head_dim: Dimension of each head.

    Returns:
        The cache size as total_bytes / 1024**3, i.e. binary gigabytes (GiB).
        The printed label says "GB" for brevity.

    Raises:
        ValueError: If `precision` is unknown or any size argument is negative.
    """
    bytes_per_value = {"fp16": 2, "bf16": 2, "fp32": 4, "int8": 1}

    precision_key = str(precision).lower()
    if precision_key not in bytes_per_value:
        raise ValueError(
            f"Unknown precision {precision!r}; expected one of {sorted(bytes_per_value)}"
        )

    sizes = {
        "num_beams": num_beams,
        "max_tokens": max_tokens,
        "num_layers": num_layers,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
    }
    for arg_name, value in sizes.items():
        if value < 0:
            raise ValueError(f"{arg_name} must be non-negative, got {value}")

    # Each layer stores two tensors (Key and Value), each num_kv_heads x head_dim
    # values per token: Size = (K + V) * Layers * Heads * Dim * Bytes
    size_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value[precision_key]

    total_bytes = size_per_token * num_beams * max_tokens
    total_gb = total_bytes / (1024**3)

    print("--- KV Cache Memory Cost ---")
    print(f"Config: {num_beams} Beams x {max_tokens} Tokens ({precision})")
    print(f"Memory: {total_gb:.2f} GB")

    return total_gb

# Example usage: 2000 beams x 40 tokens in fp16 -> ~8.54 GB of KV cache
usage = calculate_kv_memory_usage(num_beams=2000, max_tokens=40, precision="fp16")

# Check if it fits on your GPU (e.g. 24GB):
# Model Weights ~6GB + Cache ~8.5GB (this call; ~10.7GB at max_tokens=50)
# = ~14.5GB Total, before activations and allocator overhead.

--- KV Cache Memory Cost ---
Config: 2000 Beams x 40 Tokens (fp16)
Memory: 8.54 GB


In [None]:
import torch
import time

# --- Hardware Check ---
# Raise instead of calling exit(): in Jupyter, exit() shuts down the kernel,
# whereas an exception only stops this cell (and halts "Run All").
if not torch.cuda.is_available():
    raise RuntimeError("CUDA not detected. This script requires a GPU.")

device = torch.device("cuda:0")
props = torch.cuda.get_device_properties(device)
print(f"Running on: {props.name}")
print(f"Total VRAM: {props.total_memory / 1024**3:.2f} GB")

# --- Configuration (Matches Llama 3.2 3B) ---
NUM_BEAMS = 2000       # The massive beam count
MAX_LEN = 50           # Max sequence length
DTYPE = torch.float16  # Standard precision
SEED = 0               # Makes the random beam reshuffles reproducible

# Assumed (not measured here) cost of one Llama 3B decoding step; used only to
# put the reorder overhead in context.
ASSUMED_STEP_COMPUTE_MS = 50.0
# Average reorder time below this is reported as negligible.
NEGLIGIBLE_REORDER_MS = 5.0

# Llama 3.2 3B Specs
NUM_LAYERS = 28
NUM_KV_HEADS = 8       # Grouped Query Attention (GQA)
HEAD_DIM = 128         # Standard Llama head dim

torch.manual_seed(SEED)

# --- 1. Allocate the "Static KV Cache" ---
print(f"\nAllocating Static KV Cache for {NUM_BEAMS} beams x {MAX_LEN} tokens...")

# Shape: [Layers, 2 (K+V), Beams, Heads, Max_Len, Head_Dim]
# A real implementation might keep one tensor per layer. Here a single
# contiguous allocation is used as the largest single request the allocator
# must satisfy.
cache_shape = (NUM_LAYERS, 2, NUM_BEAMS, NUM_KV_HEADS, MAX_LEN, HEAD_DIM)

try:
    kv_cache = torch.zeros(cache_shape, dtype=DTYPE, device=device)
except torch.cuda.OutOfMemoryError:
    print("❌ OOM! The cache is too big for your GPU.")
    raise

# Derive the size from the tensor itself so it stays correct if DTYPE changes.
size_gb = kv_cache.numel() * kv_cache.element_size() / (1024**3)
print("✅ Cache Allocated Successfully!")
print(f"   Size: {size_gb:.2f} GB")
print(f"   Shape: {cache_shape}")

# --- 2. The Simulation Loop ---
print(f"\nStarting Beam Search Simulation ({MAX_LEN} steps)...")

# perf_counter is monotonic and high-resolution, unlike time.time().
total_reorder_time = 0.0
start_benchmark = time.perf_counter()

# Dummy "new token" K/V for one decoding step, reused every step.
# Shape per step: [Layers, 2, Beams, Heads, 1, Dim]
new_token_update = torch.randn(
    (NUM_LAYERS, 2, NUM_BEAMS, NUM_KV_HEADS, 1, HEAD_DIM),
    dtype=DTYPE,
    device=device,
)

for step in range(MAX_LEN):
    # A. SIMULATE COMPUTE: write this step's K/V into slot `step`.
    # This is a real device copy, but it sits outside the timed region.
    kv_cache[:, :, :, :, step:step + 1, :] = new_token_update

    # B. SIMULATE BEAM SELECTION: beam i continues from beam parent_indices[i],
    # so some beams die and others are duplicated.
    parent_indices = torch.randint(0, NUM_BEAMS, (NUM_BEAMS,), device=device)

    # C. REORDER CACHE (the timed part): gather history [0..step] along the
    # beam dimension (dim 2) and write it back in place.
    torch.cuda.synchronize()
    reorder_start = time.perf_counter()

    current_history = kv_cache[:, :, :, :, :step + 1, :]  # a view, no copy
    # index_select materializes a temporary as large as the history
    # (about size_gb at the last step), so peak usage is roughly 2x the cache.
    reordered_history = torch.index_select(current_history, 2, parent_indices)
    kv_cache[:, :, :, :, :step + 1, :] = reordered_history

    torch.cuda.synchronize()
    step_reorder_time = time.perf_counter() - reorder_start
    total_reorder_time += step_reorder_time

    # Drop the temporary now. Otherwise it stays alive while the next step's
    # larger temporary is allocated, pushing peak usage toward ~3x the cache.
    del current_history, reordered_history

    if step % 10 == 0:
        print(f"Step {step}/{MAX_LEN} - Reorder Time: {step_reorder_time * 1000:.2f} ms")

total_benchmark = time.perf_counter() - start_benchmark

# --- 3. Results ---
avg_reorder = (total_reorder_time / MAX_LEN) * 1000
print(f"\n--- {props.name} Simulation Results ---")
print(f"Total Cache Size:      {size_gb:.2f} GB")
print(f"Avg Cache Reorder Time: {avg_reorder:.2f} ms per step")
print(f"Total Time ({MAX_LEN} steps):  {total_benchmark:.2f} s")

# Context for latency (the compute time is an assumption, see config above)
print("\n--- Latency Context ---")
print(f"Typical Llama 3B Compute Time (per step): ~{ASSUMED_STEP_COMPUTE_MS:.2f} ms")
print(f"Your Cache Overhead: {avg_reorder:.2f} ms")
print(f"Overhead Percentage: {(avg_reorder / ASSUMED_STEP_COMPUTE_MS) * 100:.1f}%")

if avg_reorder < NEGLIGIBLE_REORDER_MS:
    print("\n✅ VERDICT: EXTREMELY FAST. The cache reordering is negligible.")
    print(f"   You can run {NUM_BEAMS} beams without complex optimization.")
else:
    print("\n⚠️ VERDICT: NOTICEABLE. You might want to look into PagedAttention.")

# Release the cache so later cells (e.g. loading the model) get the VRAM back.
del kv_cache, new_token_update
torch.cuda.empty_cache()

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Replace with your local path or model ID
model_id = "meta-llama/Llama-3.2-3B"

print(f"Loading {model_id}...")

# Load model - mimicking your likely production setup.
# Note: behavior can change if you use attn_implementation="eager" vs "sdpa" vs "flash_attention_2".
# `dtype` replaces the deprecated `torch_dtype` argument, which triggers a
# "`torch_dtype` is deprecated! Use `dtype` instead!" warning.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Prepare dummy input on the same device as the model
input_text = "Testing the cache type"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

# Single forward pass with caching enabled; gradients are not needed.
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)

# INSPECTION: what object does the model hand back as its KV cache?
cache = outputs.past_key_values

print("\n--- RESULTS ---")
print(f"Type of past_key_values: {type(cache)}")

if hasattr(cache, "reorder_cache"):
    print("✅ contains 'reorder_cache' method (It is a Dynamic/Sink/Static Cache object)")
else:
    print("❌ NO 'reorder_cache' method (It is likely a Tuple)")

# For the legacy tuple format, show its per-layer structure
if isinstance(cache, tuple):
    print(f"Tuple structure: {len(cache)} layers")
    if len(cache) > 0:
        print(f"Layer 0 type: {type(cache[0])}")

  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!


Loading meta-llama/Llama-3.2-3B...


Skipping import of cpp extensions due to incompatible torch version 2.8.0+cu128 for torchao version 0.14.1             Please see https://github.com/pytorch/ao/issues/2919 for more info
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.12it/s]



--- RESULTS ---
Type of past_key_values: <class 'transformers.cache_utils.DynamicCache'>
✅ contains 'reorder_cache' method (It is a Dynamic/Sink/Static Cache object)
