# Jupyter Notebook: Custom Attention Optimization

Description:
------------
In this notebook, we will:
1. Load a pre-built language model (LLM).
2. Create a copy of the model architecture but replace its attention mechanism with a simplified one that only attends to the last 5 tokens (instead of all previous tokens).
3. Implement a process to compare the outputs of both models and compute a KL-divergence loss.
4. Optimize the custom model's parameters by minimizing the KL-divergence between the two models’ distributions.
5. Demonstrate how to evaluate and compare both models on sample data.

In [67]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [20]:
########################################
# Cell 1: Environment Setup
########################################

# If you do not have 'transformers' installed, uncomment the pip install line below.
# !pip install torch transformers

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)

print("Environment setup complete!")

Environment setup complete!


In [79]:
########################################
# Cell 2: Configuration & Hyperparameters
########################################

# We define some hyperparameters and configurations that we will use throughout the notebook.

MODEL_NAME = "gpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2
LR = 3e-5
MAX_SEQ_LENGTH = 32
NUM_EPOCHS = 100
SHOW_SAMPLE_OUTPUTS = True # Whether to show text generation samples for comparison
GRAD_CLIP = 1.0

optimizer = optim.AdamW(custom_model.parameters(), lr=LR, weight_decay=0.01)

print(f"Using device: {DEVICE}")
print("Configurations and hyperparameters set!")

Using device: cuda
Configurations and hyperparameters set!


In [3]:
########################################
# Cell 3: Loading the Pre-built LLM
########################################

# We load a Hugging Face GPT-2 model as our "pre-built" LLM.
# This is the reference model, which we will attempt to replicate with a custom attention mechanism.

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
reference_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
reference_model.to(DEVICE)
reference_model.eval()  # We will not train this reference model; we only use it for KL-divergence comparisons.

print(f"Reference model '{MODEL_NAME}' loaded successfully!")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Reference model 'gpt2' loaded successfully!


In [74]:
########################################
# Cell 4: Custom Attention Mechanism (Fixed)
########################################

class LastNTokensAttention(nn.Module):
    """
    A simplified attention mechanism that only attends to the last N tokens (specified by window_size).
    Handles GPT-2's expected arguments while maintaining custom functionality.
    """
    def __init__(self, d_model, num_heads, window_size=10):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.window_size = window_size


        # Projection layers
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        batch_size, seq_length, _ = hidden_states.size()

        # Project Q, K, V
        q = self.query_proj(hidden_states)
        k = self.key_proj(hidden_states)
        v = self.value_proj(hidden_states)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Create sliding window attention mask
        full_mask = torch.full((batch_size, self.num_heads, seq_length, seq_length),
                             float('-inf'), device=hidden_states.device)
        for i in range(seq_length):
            start_idx = max(0, i - self.window_size)
            full_mask[:, :, i, start_idx:i+1] = 0

        # Combine with provided attention mask
        if attention_mask is not None:
            causal_mask = attention_mask.to(hidden_states.dtype)
            causal_mask = causal_mask.repeat(1, self.num_heads, 1, 1)
            full_mask = full_mask + causal_mask

        # Compute attention scores
        k_transposed = k.transpose(-2, -1)
        scaling_factor = self.head_dim ** 0.5
        attention_scores = torch.matmul(q, k_transposed) / scaling_factor
        attention_scores += full_mask

        # Get attention probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_probs, v)

        # Combine heads and project
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        return self.out_proj(context)

class CustomGPT2Block(GPT2Block):
    def __init__(self, config):
        super().__init__(config)
        self.attn = self._create_custom_attention(config)

    def _create_custom_attention(self, config):
        # Initialize custom attention with original weights
        custom_attn = LastNTokensAttention(config.n_embd, config.n_head)
        original_attn = self.attn  # Get original attention layer

        # Split QKV weights from original model
        q_w, k_w, v_w = original_attn.c_attn.weight.chunk(3, dim=1)
        q_b, k_b, v_b = original_attn.c_attn.bias.chunk(3, dim=0)

        # Copy weights to custom projections
        custom_attn.query_proj.weight.data.copy_(q_w)
        custom_attn.query_proj.bias.data.copy_(q_b)
        custom_attn.key_proj.weight.data.copy_(k_w)
        custom_attn.key_proj.bias.data.copy_(k_b)
        custom_attn.value_proj.weight.data.copy_(v_w)
        custom_attn.value_proj.bias.data.copy_(v_b)

        # Copy output projection
        custom_attn.out_proj.load_state_dict(original_attn.c_proj.state_dict())

        return custom_attn

    def forward(self, hidden_states, layer_past=None, attention_mask=None,
                head_mask=None, use_cache=False, output_attentions=False, **kwargs):
        # LayerNorm before attention
        attn_input = self.ln_1(hidden_states)

        # Custom attention (ignore layer_past and caching)
        attn_output = self.attn(attn_input, attention_mask=attention_mask)

        # Residual connection
        hidden_states = hidden_states + attn_output

        # Feed-forward network
        mlp_input = self.ln_2(hidden_states)
        mlp_output = self.mlp(mlp_input)
        hidden_states = hidden_states + mlp_output

        # Return empty present key/value states to match expected format
        return (hidden_states, (None, None) if use_cache else None, None)

class CustomGPT2LM(GPT2LMHeadModel):
    def __init__(self, config, reference_model):
        super().__init__(config)
        # Disable caching in model config
        self.config.use_cache = False

        # Share embeddings and output layer
        self.transformer.wte = reference_model.transformer.wte
        self.transformer.wpe = reference_model.transformer.wpe
        self.lm_head = reference_model.lm_head

        # Replace blocks with custom attention versions
        self.transformer.h = nn.ModuleList([
            CustomGPT2Block(config) for _ in range(config.n_layer)
        ])

        # Initialize with reference model weights
        self.load_state_dict(reference_model.state_dict(), strict=False)

# Initialize models
print("Initializing models with shared embeddings...")
reference_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
custom_config = GPT2Config.from_pretrained(MODEL_NAME)
custom_model = CustomGPT2LM(custom_config, reference_model).to(DEVICE)

# Freeze non-attention parameters
for name, param in custom_model.named_parameters():
    if "attn" in name or "mlp" in name:  # Unfreeze attention and MLP
        param.requires_grad_(True)

Initializing models with shared embeddings...


In [48]:
########################################
# Cell 5: Dataset and Dataloader (Mock/Example)
########################################

# We will create a small, synthetic dataset to demonstrate the training loop that
# optimizes the custom model to match the reference model's output distribution
# via KL-divergence.

class SyntheticTextDataset(Dataset):
    def __init__(self, tokenizer, texts, max_length=32):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = texts
        self.encodings = []

        for txt in texts:
            enc = tokenizer.encode_plus(
                txt,
                max_length=self.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            )
            self.encodings.append(enc)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.encodings[idx]
        input_ids = enc["input_ids"].squeeze(0)    # shape: (max_length,)
        attention_mask = enc["attention_mask"].squeeze(0)  # shape: (max_length,)
        return input_ids, attention_mask

# Create some mock data. In real usage, you'd use real text data.
sample_texts = [
    "Hello world, how are you?",
    "The cat sat on the mat.",
    "Artificial intelligence is fascinating.",
    "Short text.",
    "Another example here.",
    "Testing custom attention in LLMs."
]
tokenizer.pad_token = tokenizer.eos_token
dataset = SyntheticTextDataset(tokenizer, sample_texts, max_length=MAX_SEQ_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print("Dataset and dataloader ready for demonstration!")

Dataset and dataloader ready for demonstration!


In [68]:
# Replace synthetic data with real text corpus
from datasets import load_dataset

real_dataset = load_dataset("wikitext", "wikitext-103-v1")
train_texts = [txt for txt in real_dataset["train"]["text"] if len(txt) > 100][:1000]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/722k [00:00<?, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/156M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/156M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/655k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [58]:
########################################
# Cell 6: KL-Divergence Loss Setup
########################################

# We define a function that calculates the KL-divergence between the output distributions
# of the reference model and the custom model.
# Typically, for language modeling, we get the logits from each model and then apply
# a cross-entropy or KL divergence measure on the probability distributions.

# def kl_divergence_loss(logits_custom, logits_reference, attention_mask):
#     """
#     Compute KL-divergence between custom model's logits and reference model's logits.
#     logits_custom: (batch_size, seq_length, vocab_size)
#     logits_reference: (batch_size, seq_length, vocab_size)
#     attention_mask: (batch_size, seq_length) with 1 for real tokens, 0 for padding.
#     """
#     # Convert logits to log probabilities
#     log_probs_custom = nn.functional.log_softmax(logits_custom, dim=-1)   # (B, L, V)
#     log_probs_ref = nn.functional.log_softmax(logits_reference, dim=-1)   # (B, L, V)

#     # Convert reference logits to probabilities
#     probs_ref = nn.functional.softmax(logits_reference, dim=-1)           # (B, L, V)

#     # KL(ref || custom) = sum over V [ p_ref(v) * (log p_ref(v) - log p_custom(v)) ]
#     # We'll do it token-wise and then average across tokens in the batch where attention_mask=1
#     kl = probs_ref * (log_probs_ref - log_probs_custom)
#     kl = kl.sum(dim=-1)  # sum across vocab

#     # Now we only consider positions where attention_mask = 1
#     # shape: (batch_size, seq_length)
#     kl = kl * attention_mask

#     # Average across non-masked tokens
#     non_padding_tokens = attention_mask.sum()
#     if non_padding_tokens > 0:
#         kl_mean = kl.sum() / non_padding_tokens
#     else:
#         kl_mean = kl.mean()  # fallback if no tokens

#     return kl_mean

def kl_divergence_loss(logits_custom, logits_ref, mask):
    """Numerically stable KL divergence with masking"""
    assert logits_custom.shape == logits_ref.shape, \
        f"Shape mismatch: {logits_custom.shape} vs {logits_ref.shape}"

    log_probs_custom = F.log_softmax(logits_custom, dim=-1)
    probs_ref = F.softmax(logits_ref.detach(), dim=-1)  # Detach reference model

    # Calculate per-token KL
    kl = (probs_ref * (probs_ref.log() - log_probs_custom)).sum(-1)

    # Apply padding mask and average
    active_tokens = mask.sum()
    return (kl * mask).sum() / active_tokens

In [75]:
########################################
# Cell 7: Training Loop for Custom Model
########################################

# We will train the custom model for a few epochs to minimize KL-divergence from the reference model.
# This is a demonstration, so we keep it simple.

# Verify dimension matching
test_input = torch.randint(0, custom_config.vocab_size, (BATCH_SIZE, MAX_SEQ_LENGTH)).to(DEVICE)
with torch.no_grad():
    ref_logits = reference_model(test_input).logits
    custom_logits = custom_model(test_input).logits
assert ref_logits.shape == custom_logits.shape, \
    f"Shape mismatch: {custom_logits.shape} vs {ref_logits.shape}"

optimizer = optim.AdamW(custom_model.parameters(), lr=LR)

custom_model.train()

for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for step, (input_ids, attention_mask) in enumerate(dataloader):
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        # 1) Get logits from reference model
        with torch.no_grad():
            outputs_ref = reference_model(input_ids=input_ids, attention_mask=attention_mask)
            logits_ref = outputs_ref.logits  # shape: (batch_size, seq_length, vocab_size)

        # 2) Get logits from custom model
        outputs_custom = custom_model(input_ids=input_ids, attention_mask=attention_mask)
        logits_custom = outputs_custom.logits  # shape: (batch_size, seq_length, vocab_size)

        # 3) Compute KL divergence
        loss = kl_divergence_loss(logits_custom, logits_ref, attention_mask)

        # 4) Backprop and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Avg KL-div Loss: {avg_loss:.4f}")

Epoch 1 | Loss: 1.6116
Epoch 2 | Loss: 1.4442
Epoch 3 | Loss: 1.2661
Epoch 4 | Loss: 1.1446
Epoch 5 | Loss: 1.0045
Epoch 6 | Loss: 0.8343
Epoch 7 | Loss: 0.8361
Epoch 8 | Loss: 0.5799
Epoch 9 | Loss: 0.7221
Epoch 10 | Loss: 0.5836
Epoch 11 | Loss: 0.5864
Epoch 12 | Loss: 0.4689
Epoch 13 | Loss: 0.4585
Epoch 14 | Loss: 0.4321
Epoch 15 | Loss: 0.3520
Epoch 16 | Loss: 0.3692
Epoch 17 | Loss: 0.3412
Epoch 18 | Loss: 0.3057
Epoch 19 | Loss: 0.2865
Epoch 20 | Loss: 0.2843
Epoch 21 | Loss: 0.2935
Epoch 22 | Loss: 0.2570
Epoch 23 | Loss: 0.2425
Epoch 24 | Loss: 0.2479
Epoch 25 | Loss: 0.2256
Epoch 26 | Loss: 0.2145
Epoch 27 | Loss: 0.2236
Epoch 28 | Loss: 0.2088
Epoch 29 | Loss: 0.1948
Epoch 30 | Loss: 0.1930
Epoch 31 | Loss: 0.1997
Epoch 32 | Loss: 0.1838
Epoch 33 | Loss: 0.1871
Epoch 34 | Loss: 0.1796
Epoch 35 | Loss: 0.1790
Epoch 36 | Loss: 0.1746
Epoch 37 | Loss: 0.1681
Epoch 38 | Loss: 0.1757
Epoch 39 | Loss: 0.1723
Epoch 40 | Loss: 0.1558
Epoch 41 | Loss: 0.1699
Epoch 42 | Loss: 0.1589
E

In [81]:
########################################
# Cell 8: Comparison/Testing
########################################

# We can now compare the outputs of the reference model vs. the custom model on some sample prompts.
# For demonstration, we’ll do a simple generation from each.

def generate_text(model, prompt, temperature=0.7, top_k=50, max_length=50):
    """Improved generation with sampling and length normalization"""
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    outputs = model.generate(
        inputs.input_ids,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2,
        no_repeat_ngram_size=2,
        use_cache=False  # Explicitly disable caching
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

if SHOW_SAMPLE_OUTPUTS:
    sample_prompts = [
        "Hello, my name is",
        "The meaning of life is",
        "In a shocking turn of events,",
        "The future of artificial intelligence"
    ]

    for prompt in sample_prompts:
        # Generate with different parameters for comparison
        ref_text = generate_text(reference_model, prompt, temperature=0.9)
        custom_text = generate_text(custom_model, prompt, temperature=0.8)

        print(f"\nPrompt: {prompt}")
        print(f"Reference: {ref_text}")
        print(f"Custom:    {custom_text}")
        print("-" * 80)


Prompt: Hello, my name is
Reference: Hello, my name is Robertsson. I am the founder of Inoscular Robotics and also an engineer for Intelligent Silicon Valley (ISV) accelerator program at Stanford University's Artificial Intelligence Research Laboratory in Palo Alto California: https://wwwlandofel
Custom:    Hello, my name is everywhere weirding and the opposite of his subordinates to stop thinking about 3Dressed like mine for me?
Beat. I have you doing so much easier than ever since its eyes as an old motherland in a bit
--------------------------------------------------------------------------------

Prompt: The meaning of life is
Reference: The meaning of life is matter's consciousness. True, you can stop the thought and so on but only insofar as doing it in any way changes your potentiality to do things other people would also like not think about
ARTISTS: No!
Custom:    The meaning of life is gaining access to boost into the couch cushy" ILLPA will a newbies for what, welcome as a

# Conclusion

We have demonstrated:
1) Loading a reference GPT-2 model from Hugging Face.
2) Creating a custom GPT-2-like model with a simplified "last-5-tokens" attention mechanism.
3) Setting up a dataset and training loop that optimizes the custom model to match the reference distribution via KL-divergence.
4) Showed a simple comparison of generated text from both models.

This notebook is purely for demonstration and educational purposes, and many improvements could be made:
- More elaborate data loading
- Proper scheduling, regularization
- Additional GPT-2 intricacies (like caching attention states, etc.)
- More advanced generation strategies (beam search, top-k, top-p, etc.)

But this entire workflow shows how one could begin to experiment with custom attention
mechanisms and align them to a known distribution via KL divergence.