# Jupyter Notebook: Custom Attention Optimization

Description:
------------
In this notebook, we will:
1. Load a pre-built language model (LLM).
2. Create a copy of the model architecture but replace its attention mechanism with a sparse one (Native Sparse Attention, from the native-sparse-attention-pytorch package) that attends to a small subset of previous tokens instead of all of them.
3. Implement a process to compare the outputs of both models and compute a KL-divergence loss.
4. Optimize the custom model's parameters by minimizing the KL-divergence between the two models’ distributions.
5. Demonstrate how to evaluate and compare both models on sample data.
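
As a small illustration of the quantity minimized in steps 3 and 4, the sketch below computes the KL divergence between two toy distributions in plain Python. The helper name `kl_divergence` and the example probabilities are made up for illustration; the notebook itself applies `F.kl_div` to model logits.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) in nats for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


reference = [0.5, 0.5]  # toy "reference model" distribution over two tokens
custom = [0.9, 0.1]     # toy "custom model" distribution over the same tokens

print(kl_divergence(reference, reference))  # identical distributions give zero
print(kl_divergence(reference, custom))     # positive when the distributions differ
print(kl_divergence(custom, reference))     # KL is not symmetric in its arguments
```

The direction matters: the training loss later in the notebook passes the custom model's log-probabilities as the input and the reference probabilities as the target, which corresponds to KL(reference || custom).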

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 491.2/491.2 kB 9.2 MB/s eta 0:00:00
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 116.3/116.3 kB 8.0 MB/s eta 0:00:00
Downloading fsspec-2024.12.0-py3-none-any.wh

In [None]:
!pip install native-sparse-attention-pytorch

Collecting native-sparse-attention-pytorch
  Downloading native_sparse_attention_pytorch-0.2.0-py3-none-any.whl.metadata (4.8 kB)
Collecting einx>=0.3.0 (from native-sparse-attention-pytorch)
  Downloading einx-0.3.0-py3-none-any.whl.metadata (6.9 kB)
Collecting jaxtyping (from native-sparse-attention-pytorch)
  Downloading jaxtyping-0.3.1-py3-none-any.whl.metadata (7.0 kB)
Collecting local-attention>=1.11.1 (from native-sparse-attention-pytorch)
  Downloading local_attention-1.11.1-py3-none-any.whl.metadata (907 bytes)
Collecting rotary-embedding-torch (from native-sparse-attention-pytorch)
  Downloading rotary_embedding_torch-0.8.6-py3-none-any.whl.metadata (675 bytes)
Collecting hyper-connections>=0.1.8 (from local-attention>=1.11.1->native-sparse-attention-pytorch)
  Downloading hyper_connections-0.1.15-py3-none-any.whl.metadata (5.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.5->native-sparse-attention-pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-n

In [None]:
############## # Code Block 1: Imports & Config ##############
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
from native_sparse_attention_pytorch import SparseAttention


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "gpt2"
SEQ_LEN = 128
BATCH_SIZE = 4
NUM_HEADS = 4
COMPRESS_RATIO = 0.25
WINDOW_SIZE = 64
NUM_EPOCHS = 5

# Native sparse attention configuration
SPARSE_CONFIG = {
    "dim": None,  # Will be set in the model
    "dim_head": 64,  # Dimension per head
    "heads": NUM_HEADS,
    "sliding_window_size": 2,  # Local attention window
    "compress_block_size": 4,  # Size of blocks to compress
    "selection_block_size": 4,  # Size of blocks to select from
    "num_selected_blocks": 2,  # Number of blocks to select
}

In [None]:
############## # Code Block 2: Sparse Attention Components ##############
class CompressedGlobalAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, compress_ratio):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.compress_ratio = compress_ratio

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)
        self.compression = nn.Linear(embed_dim, 1)
        self.expansion = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        keep_num = max(1, int(T * self.compress_ratio))

        # Token compression
        importance = self.compression(x).squeeze(-1)
        _, keep_idx = torch.topk(importance, k=keep_num, dim=-1)
        x_compressed = torch.gather(x, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))

        # Projections
        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = (
            self.Wk(x_compressed)
            .view(B, keep_num, self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )
        V = (
            self.Wv(x_compressed)
            .view(B, keep_num, self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )

        # Attention
        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Masking
        if attention_mask is not None:
            compressed_mask = torch.gather(attention_mask, 1, keep_idx)
            attn_scores = attn_scores.masked_fill(
                compressed_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        output = output.permute(0, 2, 1, 3).contiguous().view(B, T, D)

        output = self.expansion(output)  # Ensure output has correct embedding dimension
        output = output[:, : x.size(1), :]

        return output


class LocalWindowAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)

    def create_window_mask(self, seq_len, device):
        # Symmetric band mask: position i may attend to j when |i - j| <= window_size // 2.
        # Note: this window also covers future positions, so it is not causal.
        idx = torch.arange(seq_len, device=device)
        half = self.window_size // 2
        mask = ((idx[None, :] - idx[:, None]).abs() <= half).float()
        return mask.unsqueeze(0).unsqueeze(1)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        window_mask = self.create_window_mask(T, x.device)

        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.Wk(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.Wv(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn_scores = attn_scores.masked_fill(window_mask == 0, -1e10)

        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output.permute(0, 2, 1, 3).contiguous().view(B, T, D)


class HierarchicalSparseAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size, compress_ratio):
        super().__init__()
        self.num_heads = num_heads
        self.local_attn = LocalWindowAttention(embed_dim, num_heads, window_size)
        self.global_attn = CompressedGlobalAttention(
            embed_dim, num_heads, compress_ratio
        )
        self.gate = nn.Sequential(
            nn.Linear(embed_dim, num_heads * 2),  # Output: (B, T, num_heads * 2)
            # Note: this softmax normalizes over all num_heads * 2 values jointly,
            # not over each head's (local, global) pair.
            nn.Softmax(dim=-1),
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        # Get outputs from local and global attention modules.
        local_out = self.local_attn(x, attention_mask)  # Expected shape: (B, T, D)
        global_out = self.global_attn(x, attention_mask)  # Expected shape: (B, T, D)

        B, T, D = x.size()
        head_dim = D // self.num_heads  # Ensure D is divisible by num_heads

        # Compute gating weights.
        # self.gate should output a tensor of shape (B, T, num_heads*2)
        gate_out = self.gate(x)  # Shape: (B, T, num_heads*2)
        # Reshape to (B, T, num_heads, 2) where last dim holds [local_gate, global_gate]
        gates = gate_out.view(B, T, self.num_heads, 2)
        # Unbind the last dimension into two tensors
        local_gate = gates[..., 0]  # Shape: (B, T, num_heads)
        global_gate = gates[..., 1]  # Shape: (B, T, num_heads)

        # Reshape attention outputs to split heads: (B, T, num_heads, head_dim)
        local_out_heads = local_out.view(B, T, self.num_heads, head_dim)
        global_out_heads = global_out.view(B, T, self.num_heads, head_dim)

        # Ensure the gate tensors have an extra dimension for broadcasting: (B, T, num_heads, 1)
        local_gate = local_gate.unsqueeze(-1)
        global_gate = global_gate.unsqueeze(-1)

        # Element-wise multiply each head output by its corresponding gate weight
        combined = local_out_heads * local_gate + global_out_heads * global_gate
        # Reshape back to (B, T, D)
        combined = combined.view(B, T, D)
        return self.out_proj(combined)
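
The window mask in LocalWindowAttention is centered on each position, so it also exposes tokens to the right of the query. For autoregressive language modeling, a causal variant that only looks at the current token and the few tokens before it may be preferable. The sketch below is one way to build such a mask; `causal_window_mask` is a hypothetical helper, not part of the notebook or of any library.

```python
import torch


def causal_window_mask(seq_len, window_size, device=None):
    """Return a (1, 1, seq_len, seq_len) mask; entry [i, j] is 1 when i - window_size < j <= i."""
    idx = torch.arange(seq_len, device=device)
    offset = idx[:, None] - idx[None, :]           # offset[i, j] = i - j
    mask = (offset >= 0) & (offset < window_size)  # no future tokens, at most window_size keys
    return mask.float().unsqueeze(0).unsqueeze(1)


mask = causal_window_mask(seq_len=5, window_size=2)
print(mask[0, 0])  # each row has ones only at its own position and the one before it
```

The returned shape matches what `create_window_mask` produces, so under that assumption it could be swapped in without touching the rest of the forward pass.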

In [None]:
############## # Code Block 3: Custom GPT-2 Model ##############
class SparseGPT2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.drop = nn.Dropout(config.embd_pdrop)

        # Create the native sparse attention layer. Note that this single instance is
        # reused by every block below, so all layers share the same attention weights.
        sparse_config = SPARSE_CONFIG.copy()
        sparse_config["dim"] = config.hidden_size
        sparse_config["compress_block_sliding_stride"] = 2
        self.sparse_attn = SparseAttention(**sparse_config)

        self.h = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attn": self.sparse_attn,
                        "ln_1": nn.LayerNorm(config.hidden_size),
                        "mlp": nn.Sequential(
                            nn.Linear(config.hidden_size, 4 * config.hidden_size),
                            nn.GELU(),
                            nn.Linear(4 * config.hidden_size, config.hidden_size),
                        ),
                        "ln_2": nn.LayerNorm(config.hidden_size),
                    }
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.size()
        # Build position ids on the same device as the inputs
        pos_ids = torch.arange(T, device=input_ids.device).unsqueeze(0)

        x = self.drop(self.wte(input_ids) + self.wpe(pos_ids))

        for block in self.h:
            # Pre-norm: apply layer norm before attention
            attn_out = block["attn"](block["ln_1"](x))

            # Zero out the attention output at padded positions if a mask is provided
            if attention_mask is not None:
                attn_out = attn_out * attention_mask.unsqueeze(-1)

            x = x + attn_out
            x = x + block["mlp"](block["ln_2"](x))

        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits  # Return only the logits, not a tuple
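
Because every block's "attn" entry in SparseGPT2 refers to the same self.sparse_attn instance, the attention parameters are tied across layers. The toy sketch below shows the effect on the parameter count, using nn.Linear as a stand-in for the attention module (an assumption made only to keep the example small).

```python
import torch.nn as nn


def count_params(module):
    # nn.Module.parameters() yields each distinct parameter tensor once
    return sum(p.numel() for p in module.parameters())


shared_layer = nn.Linear(8, 8)  # 8 * 8 weights + 8 biases = 72 parameters
shared = nn.ModuleList([nn.ModuleDict({"attn": shared_layer}) for _ in range(3)])
separate = nn.ModuleList([nn.ModuleDict({"attn": nn.Linear(8, 8)}) for _ in range(3)])

print(count_params(shared))    # 72: one set of weights reused three times
print(count_params(separate))  # 216: three independent sets of weights
print(shared[0]["attn"] is shared[2]["attn"])  # True
```

If independent attention weights per layer are wanted, constructing a new module inside the list comprehension (as in `separate`) would be one way to get them.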

In [None]:
############## # Code Block 4: Training Setup ##############
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Reference model
ref_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
ref_model.eval()

# Custom model with native sparse attention
cust_config = GPT2Config.from_pretrained(MODEL_NAME)
cust_model = SparseGPT2(cust_config).to(DEVICE)

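# Initialize with pretrained weights.
# Note: with strict=False, only parameters whose names match are copied; missing and
# unexpected keys are silently ignored. GPT2LMHeadModel prefixes its transformer keys
# with "transformer.", so here only lm_head.weight matches and the remaining
# SparseGPT2 parameters keep their random initialization.
pretrained_state_dict = ref_model.state_dict()
cust_model.load_state_dict(pretrained_state_dict, strict=False)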

# Use a lower learning rate for fine-tuning
optimizer = torch.optim.AdamW(cust_model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
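
The key names in GPT2LMHeadModel's state dict differ from those of SparseGPT2: the former carry a "transformer." prefix and store MLP weights in Conv1D layers named c_fc and c_proj, whose weight matrices are transposed relative to nn.Linear. A possible remapping is sketched below on a toy state dict. `remap_gpt2_state_dict` is a hypothetical helper, and the toy shapes are shrunk for illustration.

```python
import torch


def remap_gpt2_state_dict(state_dict):
    """Hypothetical helper: rename GPT2LMHeadModel keys to the SparseGPT2 layout."""
    remapped = {}
    for name, tensor in state_dict.items():
        if name.startswith("transformer."):
            name = name[len("transformer."):]
        if ".attn." in name:
            continue  # dense attention weights have no counterpart in SparseAttention
        for old, new in ((".mlp.c_fc.", ".mlp.0."), (".mlp.c_proj.", ".mlp.2.")):
            if old in name:
                name = name.replace(old, new)
                if name.endswith(".weight"):
                    tensor = tensor.t()  # Conv1D stores (in, out); nn.Linear expects (out, in)
        remapped[name] = tensor
    return remapped


# Toy stand-in for ref_model.state_dict()
toy = {
    "transformer.wte.weight": torch.zeros(10, 4),
    "transformer.h.0.ln_1.weight": torch.ones(4),
    "transformer.h.0.attn.c_attn.weight": torch.zeros(4, 12),
    "transformer.h.0.mlp.c_fc.weight": torch.zeros(4, 16),
    "transformer.h.0.mlp.c_fc.bias": torch.zeros(16),
    "lm_head.weight": torch.zeros(10, 4),
}
remapped = remap_gpt2_state_dict(toy)
print(sorted(remapped))
print(remapped["h.0.mlp.0.weight"].shape)
```

Passing `remap_gpt2_state_dict(ref_model.state_dict())` to `cust_model.load_state_dict(..., strict=False)` should then copy the embeddings, layer norms, MLPs, and LM head, while the sparse attention stays randomly initialized. One remaining difference is that GPT-2 uses a tanh-approximated GELU, whereas `nn.GELU()` defaults to the exact form.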

In [None]:
############## # Code Block 5: Training Loop ##############
def train_step(batch):
    inputs = batch.to(DEVICE)
    attention_mask = (inputs != tokenizer.pad_token_id).float()

    with torch.no_grad():
        ref_outputs = ref_model(inputs, attention_mask=attention_mask)
        if isinstance(ref_outputs, tuple):
            ref_logits = ref_outputs[0]
        else:
            ref_logits = ref_outputs.logits  # Extract logits from the output object

    cust_outputs = cust_model(inputs, attention_mask=attention_mask)
    if isinstance(cust_outputs, tuple):
        cust_logits = cust_outputs[0]
    else:
        cust_logits = cust_outputs

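    # KL(reference || custom) with temperature scaling.
    # reduction="batchmean" divides the summed KL by the batch size only, so the
    # loss is a sum over all sequence positions, not a per-token average.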
    temperature = 1.0
    loss = F.kl_div(
        F.log_softmax(cust_logits / temperature, dim=-1),
        F.softmax(ref_logits / temperature, dim=-1).detach(),
        reduction="batchmean",
    ) * (temperature**2)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(cust_model.parameters(), 1.0)
    optimizer.step()
    return loss.item()


def train_epoch(loader):
    cust_model.train()
    total_loss = 0
    for batch in tqdm(loader, desc="Training"):
        loss = train_step(batch)
        total_loss += loss
    return total_loss / len(loader)
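
With logits of shape (batch, sequence, vocabulary), `reduction="batchmean"` in `F.kl_div` divides only by the first dimension, so the loss in train_step scales with the sequence length. The sketch below checks this on random toy tensors; the shapes and variable names are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 3, 5  # toy batch size, sequence length, vocabulary size
ref_logits = torch.randn(B, T, V)
cust_logits = torch.randn(B, T, V)

log_q = F.log_softmax(cust_logits, dim=-1)  # custom model log-probabilities
p = F.softmax(ref_logits, dim=-1)           # reference model probabilities

# As in train_step: sum over positions and vocabulary, divided by the batch size B
per_sequence = F.kl_div(log_q, p, reduction="batchmean")

# Flattening positions into the batch dimension divides by B * T instead
per_token = F.kl_div(log_q.view(-1, V), p.view(-1, V), reduction="batchmean")

print(per_sequence.item(), per_token.item() * T)  # the two values agree
```

Since every training sample here has exactly SEQ_LEN tokens, dividing an epoch's average loss by SEQ_LEN gives a per-token KL that is easier to compare across sequence lengths.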

In [None]:
############## # Code Block 6: Generation & Evaluation ##############
def generate(model, tokenizer, prompt, max_length=50, temperature=0.7, top_k=50, device="cuda" if torch.cuda.is_available() else "cpu"):
    input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    input_ids = input_ids[:1]  # Ensure we only have one batch dimension

    model.to(device)
    model.eval()

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs

            # Get logits for the last token
            if logits.dim() == 3:  # Standard shape [batch, seq, vocab]
                next_token_logits = logits[:, -1, :] / temperature
            elif logits.dim() == 2:  # If somehow we got [batch*seq, vocab]
                # We only care about the last token's logits
                next_token_logits = logits[-1:, :] / temperature
            else:
                raise ValueError(f"Unexpected logits shape: {logits.shape}")

            # Apply top-k filtering
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)

            # Convert to probabilities
            probs = F.softmax(top_k_logits, dim=-1)

            # Sample next token index from top-k logits
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices.gather(1, next_token_idx)

            # Ensure next_token has shape [1, 1]
            next_token = next_token[-1:, :]  # Take only the last row if needed

            # Concatenate to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
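
The top-k step inside generate can be checked in isolation. The sketch below pulls it into a small function and confirms on a toy vocabulary that only the k highest-scoring tokens are ever sampled. `sample_top_k` is a hypothetical helper that mirrors the logic above rather than a function from the notebook.

```python
import torch
import torch.nn.functional as F


def sample_top_k(next_token_logits, top_k, temperature=1.0):
    """Sample one token id per row from the top_k highest-scoring entries."""
    scaled = next_token_logits / temperature
    top_logits, top_indices = torch.topk(scaled, k=top_k, dim=-1)
    probs = F.softmax(top_logits, dim=-1)             # renormalize over the kept tokens
    choice = torch.multinomial(probs, num_samples=1)  # position within the top_k list
    return top_indices.gather(1, choice)              # map back to vocabulary ids


torch.manual_seed(0)
logits = torch.tensor([[2.0, 0.5, -1.0, 3.0, 0.0]])  # toy vocabulary of five tokens
samples = {sample_top_k(logits, top_k=2, temperature=0.7).item() for _ in range(100)}
print(samples)  # only ids 0 and 3 (the two largest logits) can appear
```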

In [None]:
############## # Code Block 7: Dataset Preparation ##############
from datasets import load_dataset


class WikiDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.samples = []
        for text in texts:
            # Tokenize each text separately, without adding special tokens
            token_ids = tokenizer.encode(text, add_special_tokens=False)
            # Split token_ids into chunks of length seq_len
            for i in range(0, len(token_ids), seq_len):
                chunk = token_ids[i : i + seq_len]
                # Only add full chunks to avoid very short sequences
                if len(chunk) == seq_len:
                    self.samples.append(torch.tensor(chunk))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Load the first 80% of the WikiText-2 (raw) train split
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:80%]")
texts = dataset["text"]
wiki_dataset = WikiDataset(texts, tokenizer, SEQ_LEN)

# Create a DataLoader for training
train_loader = DataLoader(wiki_dataset, batch_size=BATCH_SIZE, shuffle=True)
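
The chunking rule used by WikiDataset (keep only full chunks of seq_len tokens, drop the remainder) can be illustrated without a tokenizer. `full_chunks` below is a hypothetical stand-alone helper operating on plain lists of integers.

```python
def full_chunks(token_ids, seq_len):
    """Split token_ids into consecutive chunks of exactly seq_len tokens, dropping the remainder."""
    return [
        token_ids[i : i + seq_len]
        for i in range(0, len(token_ids) - seq_len + 1, seq_len)
    ]


print(full_chunks(list(range(10)), seq_len=4))  # → [[0, 1, 2, 3], [4, 5, 6, 7]]
print(full_chunks(list(range(3)), seq_len=4))   # → []
```

Because each text is tokenized separately, any text shorter than SEQ_LEN tokens contributes no samples at all. Concatenating texts before chunking would be one way to keep more of the data.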

In [None]:
############## # Code Block 8: Training Execution ##############

for epoch in range(NUM_EPOCHS):
    avg_loss = train_epoch(train_loader)
    scheduler.step()
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Average Loss: {avg_loss:.4f}")

    # Generate sample text after each epoch
    if (epoch + 1) % 2 == 0:  # Generate every 2 epochs
        print("\nGenerating sample text:")
        prompt = "Artificial intelligence"
        print("Reference:", generate(ref_model, tokenizer, prompt, temperature=0.7, top_k=50))
        print("Custom:", generate(cust_model, tokenizer, prompt, temperature=0.7, top_k=50))
        print("\n")

# Final generation comparison
prompt = "Artificial intelligence"
print("\nFinal generation comparison:")
print("Reference:", generate(ref_model, tokenizer, prompt, temperature=0.7, top_k=50))
print("Custom:", generate(cust_model, tokenizer, prompt, temperature=0.7, top_k=50))


Training: 100%|██████████| 2036/2036 [08:34<00:00,  3.96it/s]


Epoch 1/5 - Average Loss: 331.6650


Training: 100%|██████████| 2036/2036 [08:37<00:00,  3.94it/s]


Epoch 2/5 - Average Loss: 242.4077

Generating sample text:
Reference: Artificial intelligence has always been a topic of debate. Many of these debates have been about how to make AI intelligible. The problem is that AI cannot understand abstract concepts. The only way to have a very good understanding of abstract concepts is to have basic knowledge of
Custom: Artificial intelligence "I " to write a "the, " is thought too " or " is " a " " " I " " " " " "" " " " " " "We " " " " "the " " " " " " "




Training: 100%|██████████| 2036/2036 [08:36<00:00,  3.94it/s]


Epoch 3/5 - Average Loss: 207.6536


Training: 100%|██████████| 2036/2036 [08:35<00:00,  3.95it/s]


Epoch 4/5 - Average Loss: 186.9719

Generating sample text:
Reference: Artificial intelligence may not necessarily be a good thing for the American public.

As we've reported elsewhere, there are many positive benefits to AI for the public sector.

For example, many people are happy about the way AI works.

Many
Custom: Artificial intelligence to say a new standard decision to keep their "to-pamarimimimimz, S.Nxdbhhhar's , and (in , S.G. , and (M , and M.ANN




Training: 100%|██████████| 2036/2036 [08:36<00:00,  3.94it/s]


Epoch 5/5 - Average Loss: 175.6799

Final generation comparison:
Reference: Artificial intelligence is the key to the future, but it may also be the key to the future with its ability to detect, investigate and manage complex patterns of action.

Some scientists, for example, have proposed that artificial intelligence could be the next big thing
Custom: Artificial intelligence,, will not pay those that they were the most common people could be understood with other other other two-tetetetetetetetetetetetetetetetetetetetetetetetetetetete
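
The samples above give a qualitative comparison. A quantitative one could use next-token perplexity computed from each model's logits on the same batch. The sketch below is one possible implementation; the function name `perplexity` and the toy tensors are assumptions for illustration.

```python
import math

import torch
import torch.nn.functional as F


def perplexity(logits, input_ids):
    """Perplexity of next-token predictions; logits is (B, T, V), input_ids is (B, T)."""
    shift_logits = logits[:, :-1, :]  # prediction at position t ...
    shift_labels = input_ids[:, 1:]   # ... is scored against the token at t + 1
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
    return math.exp(loss.item())


# Sanity check: uniform logits over V tokens give a perplexity of about V
V = 7
input_ids = torch.randint(0, V, (2, 6))
print(perplexity(torch.zeros(2, 6, V), input_ids))
```

For the models in this notebook, one would pass `ref_model(batch).logits` and `cust_model(batch)` for a held-out `batch` of token ids (a hypothetical variable here) and compare the two numbers.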


# Conclusion

We have demonstrated:
1) Loading a reference GPT-2 model from Hugging Face.
2) Creating a custom GPT-2-like model that uses a Native Sparse Attention (NSA) mechanism.
3) Setting up a dataset and training loop that optimizes the custom model to match the reference distribution via KL-divergence.
4) Comparing generated text from both models.

This notebook is purely for demonstration and educational purposes, and many improvements could be made:
- More elaborate data loading
- More careful learning-rate scheduling and regularization
- Additional GPT-2 intricacies (like caching attention states, etc.)
- More advanced generation strategies (beam search, top-p, etc.)

After five epochs the custom model's samples are still far less coherent than the reference model's.
Even so, the workflow shows how one could begin to experiment with custom attention
mechanisms and align them to a reference distribution via KL divergence.