# MLA-Retrofit Demo Notebook

This notebook demonstrates how to use MLA-Retrofit to convert a model from GQA to MLA and test the results.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/manncodes/mla-retrofit/blob/main/examples/mla_retrofit_demo.ipynb)

## Setup

First, let's install the MLA-Retrofit package and its dependencies.

In [None]:
# Install MLA-Retrofit
!pip install git+https://github.com/manncodes/mla-retrofit.git

# Install additional dependencies
!pip install accelerate
!pip install "bitsandbytes>=0.40.0"  # For quantization; the quotes stop the shell from treating >= as a redirect

## Import Libraries

In [2]:
import pathlib

# Print the current working directory
cwd = pathlib.Path.cwd()
print(cwd)

c:\sandbox\repo\mla-retrofit


In [3]:
import torch
import os
import logging
from mla_retrofit import convert_to_mla
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("mla-demo")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

2025-03-23 17:01:11,188 - mla-demo - INFO - Using device: cuda


## Choose Model and Parameters

Let's set up the parameters for conversion. We'll use a small model for demonstration purposes.

In [19]:
# For this demo, we'll use a smaller model that can fit in Colab's memory
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Small model for demo
output_dir = "./tinyllama-mla"

# MLA parameters
num_kv_heads = 4  # Number of KV heads for MLA
head_dim = 64     # Head dimension for MLA
rope_mode = "extend"  # Mode for RoPE handling
absorb = True     # Whether to absorb projection matrices
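The `absorb` flag presumably refers to the matrix-absorption trick used in MLA (as described for DeepSeek-V2): attention scores are dot products, so the key up-projection can be folded into the query once, and attention can then run directly on the cached latent vectors. The pure-Python sketch below checks that identity on tiny hand-picked matrices. The sizes and the names `W_uk`, `latents`, and `q` are illustrative only, and whether `mla_retrofit` implements `absorb=True` exactly this way is an assumption.

```python
import math

def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def transpose(m):
    return [list(col) for col in zip(*m)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical toy sizes: latent dim 2, key dim 3 (not TinyLlama's real shapes).
W_uk = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]  # key up-projection, shape (3, 2)
latents = [[1.0, -1.0], [0.5, 2.0]]           # cached latent vectors, one per token
q = [0.2, -0.4, 1.0]                          # query vector for one head

# Without absorption: up-project every cached latent into a full key, then score it.
scores_explicit = [dot(q, matvec(W_uk, c)) for c in latents]

# With absorption: fold W_uk into the query once (q' = W_uk^T q), then score the latents directly.
q_absorbed = matvec(transpose(W_uk), q)
scores_absorbed = [dot(q_absorbed, c) for c in latents]

for a, b in zip(scores_explicit, scores_absorbed):
    assert math.isclose(a, b)
print(scores_absorbed)
```

With absorption, the up-projection cost is paid once per query rather than once per cached token, and full keys never need to be materialized.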

## Load and Examine Original Model

Before conversion, let's load the original model and examine its configuration.

In [11]:
# Load original model
logger.info(f"Loading original model: {model_name}")
original_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # Use half precision to save memory
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Print model configuration
print("\nOriginal Model Configuration:")
print(f"Hidden size: {original_model.config.hidden_size}")
print(f"Number of attention heads: {original_model.config.num_attention_heads}")
print(f"Number of KV heads: {getattr(original_model.config, 'num_key_value_heads', original_model.config.num_attention_heads)}")
print(f"Head dimension: {original_model.config.hidden_size // original_model.config.num_attention_heads}")

# Calculate memory requirements
seq_len = 1024
kv_heads = getattr(original_model.config, 'num_key_value_heads', original_model.config.num_attention_heads)
n_attn_layers = original_model.config.num_hidden_layers
head_dim_orig = original_model.config.hidden_size // original_model.config.num_attention_heads
kv_cache_size_original = 2 * 2 * seq_len * kv_heads * head_dim_orig * n_attn_layers  # 2 for K and V, 2 bytes per FP16 value
print(f"KV cache size: 2 * 2 * {seq_len}(seq_len) * {kv_heads}(kv_heads) * {head_dim_orig}(head_dim) * {n_attn_layers}(n_attn_layers) = {kv_cache_size_original / (1024 * 1024):.2f} MB")
print(f"\nKV cache size for {seq_len} tokens (FP16): {kv_cache_size_original / (1024 * 1024):.2f} MB")
print(f"per token KV cache size: {kv_cache_size_original / (1024 * 1024 * seq_len):.2f} MB")

2025-03-23 17:13:14,225 - mla-demo - INFO - Loading original model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
2025-03-23 17:13:14,555 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).



Original Model Configuration:
Hidden size: 2048
Number of attention heads: 32
Number of KV heads: 4
Head dimension: 64
KV cache size: 2 * 2 * 1024(seq_len) * 4(kv_heads) * 64(head_dim) * 22(n_attn_layers) = 22.00 MB

KV cache size for 1024 tokens (FP16): 22.00 MB
per token KV cache size: 0.02 MB
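The cell above multiplies out the standard KV-cache formula: two tensors (K and V) per layer, each holding `seq_len × num_kv_heads × head_dim` elements, at 2 bytes per FP16 element. A small helper makes the arithmetic reusable and reports the per-token cost at a readable scale. `kv_cache_bytes` is a hypothetical helper, not part of `mla_retrofit`.

```python
def kv_cache_bytes(seq_len, num_kv_heads, head_dim, num_layers, bytes_per_elem=2):
    """Bytes needed to cache keys and values for seq_len tokens across all layers."""
    per_layer = 2 * seq_len * num_kv_heads * head_dim * bytes_per_elem  # K and V
    return per_layer * num_layers

# TinyLlama-1.1B values printed above: 4 KV heads, head_dim 64, 22 layers, FP16.
total = kv_cache_bytes(seq_len=1024, num_kv_heads=4, head_dim=64, num_layers=22)
per_token = kv_cache_bytes(seq_len=1, num_kv_heads=4, head_dim=64, num_layers=22)

print(f"Total: {total / 2**20:.2f} MiB")          # Total: 22.00 MiB
print(f"Per token: {per_token / 2**10:.1f} KiB")  # Per token: 22.0 KiB
```

Reporting the per-token figure in KiB (22 KiB) is more informative than the `0.02 MB` that two-decimal MB formatting produces.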


## Generate Text with Original Model

Let's generate some text with the original model to compare it with the MLA version later.

In [12]:
# Set prompt
prompt = "Explain the advantages of Multi-head Latent Attention in language models in simple terms."

# Tokenize
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Generate with original model
logger.info("Generating text with original model...")
with torch.no_grad():
    outputs_original = original_model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
    )

# Decode and print
generated_text_original = tokenizer.decode(outputs_original[0], skip_special_tokens=True)
print(f"\nOriginal Model Output:\n{generated_text_original}")

# Free up memory
del original_model
torch.cuda.empty_cache()

2025-03-23 17:13:35,977 - mla-demo - INFO - Generating text with original model...



Original Model Output:
Explain the advantages of Multi-head Latent Attention in language models in simple terms. I hope this helps!
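The `generate` calls use `temperature=0.7`, `top_p=0.95`, and `do_sample=True`. Conceptually, temperature divides the logits before the softmax, and top-p (nucleus) sampling keeps the smallest set of most likely tokens whose cumulative probability reaches `top_p`, then samples from that renormalized set. The pure-Python sketch below illustrates the idea; it is a simplified model, not the exact code path inside `transformers`, and the function names are hypothetical.

```python
import math
import random

def temperature_top_p(logits, temperature=0.7, top_p=0.95):
    """Return {token_id: probability} after temperature scaling and top-p filtering."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most likely tokens until their cumulative probability reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize the surviving probabilities so they sum to 1.
    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}

def sample_token(dist, rng):
    """Draw one token id from a {token_id: probability} distribution."""
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]

dist = temperature_top_p([2.0, 1.0, 0.0, -5.0], temperature=1.0, top_p=0.9)
print(dist)  # only tokens 0 and 1 survive the top-p cut
print(sample_token(dist, random.Random(0)))
```

Lowering the temperature sharpens the distribution, so fewer tokens are needed to reach `top_p` and the output becomes more deterministic.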


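## Convert the Model to MLA

Next, convert the model with `convert_to_mla` and compare its configuration and KV-cache size with the original.

In [28]: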
# Convert model to MLA
%reload_ext autoreload
%autoreload 2

from mla_retrofit import convert_to_mla

logger.info(f"Converting {model_name} to MLA...")

model, tokenizer = convert_to_mla(
    model_name_or_path=model_name,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    rope_mode=rope_mode,
    absorb=absorb,
    flash_attn=False,  # Set to True if you have Flash Attention installed
    return_model=True,
)

# Print MLA model configuration
print("\nMLA Model Configuration:")
print(f"Hidden size: {model.config.hidden_size}")
print(f"Number of attention heads: {model.config.num_attention_heads}")
print(f"Number of KV heads: {model.config.num_key_value_heads}")
print(f"Head dimension: {model.config.head_dim}")

# Calculate memory requirements
kv_cache_size_mla = 2 * 2 * seq_len * num_kv_heads * head_dim * n_attn_layers  # 2 for K and V, 2 bytes per FP16 value, all layers
print(f"\nKV cache size for {seq_len} tokens (FP16): {kv_cache_size_mla / (1024 * 1024):.2f} MB")
print(f"Memory reduction: {(1 - kv_cache_size_mla / kv_cache_size_original) * 100:.2f}%")

2025-03-23 17:42:03,284 - mla-demo - INFO - Converting TinyLlama/TinyLlama-1.1B-Chat-v1.0 to MLA...
2025-03-23 17:42:03,593 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
2025-03-23 17:42:10,914 - mla_retrofit.convert - INFO - Original head dimension: 64
2025-03-23 17:42:10,916 - mla_retrofit.convert - INFO - Original KV heads: 4
2025-03-23 17:42:10,916 - mla_retrofit.convert - INFO - Target latent dimension: 256
2025-03-23 17:42:10,918 - mla_retrofit.convert - INFO - Target KV heads: 4
2025-03-23 17:42:10,920 - mla_retrofit.convert - INFO - Target head dimension: 64
2025-03-23 17:42:10,921 - mla_retrofit.convert - INFO - Module model.layers.0.self_attn does not have MLA structure (missing k_up_proj or v_up_proj)
2025-03-23 17:42:10,922 - mla_retrofit.convert - INFO - Model has MLA structure: False
2025-03-23 17


MLA Model Configuration:
Hidden size: 2048
Number of attention heads: 32
Number of KV heads: 4
Head dimension: 64

KV cache size for 1024 tokens (FP16): 22.00 MB
Memory reduction: 0.00%
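For a fair comparison, both cache sizes must include the same per-layer factor. In DeepSeek-V2-style MLA, each layer caches one compressed latent vector per token (optionally plus a small decoupled RoPE key) instead of separate keys and values for every KV head. The sketch below compares the two per-token footprints. It assumes the MLA cache stores exactly `latent_dim` values per token per layer, using the `Target latent dimension: 256` reported in the conversion log; `rope_dim` and both function names are hypothetical, and whether `mla_retrofit` lays out its cache this way is an assumption.

```python
def gqa_bytes_per_token(num_kv_heads, head_dim, num_layers, bytes_per_elem=2):
    """Standard/GQA cache: separate K and V for every KV head in every layer."""
    return 2 * num_kv_heads * head_dim * num_layers * bytes_per_elem

def mla_bytes_per_token(latent_dim, num_layers, rope_dim=0, bytes_per_elem=2):
    """MLA-style cache (assumed layout): one shared latent plus optional RoPE key per layer."""
    return (latent_dim + rope_dim) * num_layers * bytes_per_elem

gqa = gqa_bytes_per_token(num_kv_heads=4, head_dim=64, num_layers=22)
mla = mla_bytes_per_token(latent_dim=256, num_layers=22)  # rope_dim=0 is a simplification

print(f"GQA: {gqa} B/token, MLA latent: {mla} B/token, reduction: {1 - mla / gqa:.0%}")
```

This latent-cache saving only materializes if the converted attention actually caches latents; the conversion log above reports `Model has MLA structure: False`, so inspect the converted modules before relying on it.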


In [31]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): Llam

## Generate Text with MLA Model

Now let's generate text with the MLA-converted model and compare the results.

In [None]:
# Sanity-check one forward pass, then generate with the MLA model
logger.info("Generating text with MLA model...")
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    logger.info(f"Logits shape: {tuple(logits.shape)}, dtype: {logits.dtype}")

    # Check for invalid values in logits
    if torch.isnan(logits).any() or torch.isinf(logits).any():
        raise ValueError("Logits contain NaN or Inf values.")

    outputs_mla = model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
    )

# Decode and print
generated_text_mla = tokenizer.decode(outputs_mla[0], skip_special_tokens=True)
print(f"\nMLA Model Output:\n{generated_text_mla}")

## Save the Converted Model

Let's save the MLA-converted model for future use.

In [None]:
# Save the model
logger.info(f"Saving MLA model to {output_dir}")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

## Generate Text with Longer Context

One benefit of MLA is that each layer caches a compressed latent vector instead of full per-head keys and values, so the same memory can hold a longer context. Let's test the converted model with a longer prompt.

In [None]:
# Create a longer prompt (repeated text for demo purposes)
long_prompt = prompt + "\n\n" + "\n\n".join([f"Section {i+1}: " + prompt for i in range(10)])
print(f"Prompt length: {len(tokenizer.encode(long_prompt))} tokens")

# Tokenize long prompt
long_inputs = tokenizer(long_prompt, return_tensors="pt").to(device)

# Generate with MLA model on longer context
logger.info("Generating text with MLA model on longer context...")
with torch.no_grad():
    outputs_long = model.generate(
        **long_inputs,
        max_new_tokens=50,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
    )

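# Decode only the newly generated tokens; slicing the decoded text by len(long_prompt)
# is unreliable because decoding does not always reproduce the prompt character for character
new_tokens = outputs_long[0][long_inputs.input_ids.shape[1]:]
generated_text_long = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(f"\nMLA Model Output on Longer Context:\n{generated_text_long}")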
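The longer-context benefit follows directly from the per-token cache size: for a fixed KV-cache budget, the maximum context is the budget divided by the bytes cached per token. The sketch below uses TinyLlama's configuration (4 KV heads, head dim 64, 22 layers, FP16) and, as an assumption, a 256-value latent per layer for the MLA case. The 1 GiB budget and the function name are hypothetical.

```python
def max_context_tokens(budget_bytes, bytes_per_token):
    """Largest number of tokens whose KV cache fits in budget_bytes."""
    return budget_bytes // bytes_per_token

NUM_LAYERS, FP16_BYTES = 22, 2
gqa_per_token = 2 * 4 * 64 * NUM_LAYERS * FP16_BYTES  # K and V, 4 KV heads x 64 dims
mla_per_token = 256 * NUM_LAYERS * FP16_BYTES         # assumed 256-value latent per layer

budget = 1 * 2**30  # hypothetical 1 GiB KV-cache budget
print("GQA:", max_context_tokens(budget, gqa_per_token), "tokens")  # GQA: 47662 tokens
print("MLA:", max_context_tokens(budget, mla_per_token), "tokens")  # MLA: 95325 tokens
```

The budget covers only the KV cache; model weights and activations need memory on top of it.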

## Memory Profile and Benchmark

Now let's profile the memory usage and generation speed.

In [None]:
# Memory profiling
if torch.cuda.is_available():
    print("\nCUDA Memory Stats:")
    print(f"Allocated: {torch.cuda.memory_allocated() / (1024 * 1024):.2f} MB")
    print(f"Reserved: {torch.cuda.memory_reserved() / (1024 * 1024):.2f} MB")
    
    # Simple generation speed benchmark
    import time
    
    # Warm-up
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=10)
    
    # Benchmark
    num_runs = 5
    total_time = 0
    total_tokens = 0
    
    print(f"\nBenchmarking generation speed (average of {num_runs} runs):")
    for i in range(num_runs):
        torch.cuda.synchronize()
        start_time = time.perf_counter()  # high-resolution timer suited to short intervals

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=20)

        torch.cuda.synchronize()
        end_time = time.perf_counter()
        
        tokens_generated = outputs.shape[1] - inputs.input_ids.shape[1]
        run_time = end_time - start_time
        total_time += run_time
        total_tokens += tokens_generated
        
        print(f"Run {i+1}: Generated {tokens_generated} tokens in {run_time:.4f}s ({tokens_generated/run_time:.2f} tokens/s)")
    
    print(f"\nAverage: {total_tokens/total_time:.2f} tokens/s")
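The benchmark above can be factored into a framework-agnostic helper that runs warm-up iterations, times each run with `time.perf_counter()`, and reports both aggregate and median throughput. `benchmark_throughput` and `fake_generate` are hypothetical names; for GPU work, the callable should call `torch.cuda.synchronize()` before returning so the timing covers all queued kernels.

```python
import statistics
import time

def benchmark_throughput(run_once, num_runs=5, warmup=1):
    """Time run_once() num_runs times; run_once must return the number of tokens produced."""
    for _ in range(warmup):
        run_once()  # warm-up: keep one-time setup costs out of the measurements

    rates = []
    total_tokens, total_time = 0, 0.0
    for _ in range(num_runs):
        start = time.perf_counter()
        tokens = run_once()
        elapsed = time.perf_counter() - start
        total_tokens += tokens
        total_time += elapsed
        rates.append(tokens / elapsed)

    return {
        "mean_tokens_per_s": total_tokens / total_time,   # aggregate throughput
        "median_tokens_per_s": statistics.median(rates),  # robust to outlier runs
        "runs": rates,
    }

# Stand-in workload so the helper runs without a model: pretend each call yields 20 tokens.
def fake_generate():
    time.sleep(0.01)
    return 20

# In the notebook, run_once could instead be:
#   def run_once():
#       with torch.no_grad():
#           out = model.generate(**inputs, max_new_tokens=20)
#       torch.cuda.synchronize()
#       return out.shape[1] - inputs.input_ids.shape[1]

stats = benchmark_throughput(fake_generate, num_runs=3)
print(f"{stats['mean_tokens_per_s']:.1f} tokens/s (median {stats['median_tokens_per_s']:.1f})")
```

Reporting the median alongside the mean helps when one run is slowed by unrelated system activity.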

## Fine-tuning the MLA Model (Optional)

For even better performance, you may want to fine-tune the converted model. Here's a simplified example.

In [None]:
# NOTE: This cell is for demonstration only and won't run well in Colab without additional setup.
# To try fine-tuning, remove the surrounding triple quotes and run it on a GPU with sufficient memory.

'''
from datasets import load_dataset
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

# Load a small dataset for fine-tuning (using alpaca as an example)
dataset = load_dataset("tatsu-lab/alpaca", split="train[:100]")  # Just use 100 examples for demo

# Format the dataset
def format_prompt(example):
    return {
        "text": f"### Instruction: {example['instruction']}\n\n### Input: {example['input']}\n\n### Response: {example['output']}"
    }

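# Llama tokenizers may not define a pad token; fall back to EOS so padding works
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Apply formatting and tokenization
formatted_dataset = dataset.map(format_prompt)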
tokenized_dataset = formatted_dataset.map(
    lambda examples: tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
        padding="max_length",
    ),
    batched=True,
    remove_columns=["instruction", "input", "output", "text"],
)

# Data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Set up training arguments
training_args = TrainingArguments(
    output_dir="./tinyllama-mla-finetuned",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    max_steps=10,  # Just a few steps for the demo
    logging_steps=1,
    save_steps=5,
    save_total_limit=1,
    fp16=True,
    remove_unused_columns=False,
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

# Start training
trainer.train()

# Save fine-tuned model
model.save_pretrained("./tinyllama-mla-finetuned")
'''
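With `mlm=False`, `DataCollatorForLanguageModeling` prepares causal-LM labels by copying `input_ids` and replacing padding positions with `-100`, the index that PyTorch's cross-entropy loss ignores by default; the model shifts the labels internally. The pure-Python sketch below mirrors that step. One caveat worth checking: the collator masks by `pad_token_id`, so if the pad token is set to EOS, genuine EOS tokens are masked as well. The token ids and the helper name `causal_lm_labels` are made up for illustration.

```python
IGNORE_INDEX = -100  # label value skipped by torch.nn.CrossEntropyLoss by default

def causal_lm_labels(input_ids, pad_token_id):
    """Copy input_ids to labels, masking every position equal to pad_token_id."""
    return [[IGNORE_INDEX if tok == pad_token_id else tok for tok in seq] for seq in input_ids]

# Made-up ids: 1 = BOS, 2 = EOS, 0 = PAD.
batch = [[1, 15, 27, 2, 0, 0], [1, 8, 9, 10, 11, 2]]
labels = causal_lm_labels(batch, pad_token_id=0)
print(labels)  # [[1, 15, 27, 2, -100, -100], [1, 8, 9, 10, 11, 2]]

# If PAD reuses the EOS id, the real EOS position is masked too.
labels_pad_is_eos = causal_lm_labels([[1, 15, 27, 2, 2, 2]], pad_token_id=2)
print(labels_pad_is_eos)  # [[1, 15, 27, -100, -100, -100]]
```

If the model should learn to emit EOS, use a dedicated pad token or build labels from the attention mask instead of the pad id.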

## Conclusion

In this notebook, we've demonstrated how to:

1. Convert a model from standard attention or GQA to MLA
2. Compare the memory usage before and after conversion
3. Test the model on both short and longer context prompts
4. Benchmark generation speed
5. (Optionally) Fine-tune the converted model

MLA-Retrofit provides a simple way to enhance existing models by adding the benefits of Multi-head Latent Attention without requiring full retraining.