# Jupyter Notebook: Custom Attention Optimization

Description:
------------
In this notebook, we will:
1. Load a pre-built language model (LLM).
2. Create a copy of the model architecture but replace its attention mechanism with a simplified one that only attends to the last 5 tokens (instead of all previous tokens).
3. Implement a process to compare the outputs of both models and compute a KL-divergence loss.
4. Optimize the custom model's parameters by minimizing the KL-divergence between the two models’ distributions.
5. Demonstrate how to evaluate and compare both models on sample data.

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

# Conclusion

We have demonstrated:
1) Loading a reference GPT-2 model from Hugging Face.
2) Creating a custom GPT-2-like model with a simplified "last-5-tokens" attention mechanism.
3) Setting up a dataset and training loop that optimizes the custom model to match the reference distribution via KL-divergence.
4) Showed a simple comparison of generated text from both models.

This notebook is purely for demonstration and educational purposes, and many improvements could be made:
- More elaborate data loading
- Proper scheduling, regularization
- Additional GPT-2 intricacies (like caching attention states, etc.)
- More advanced generation strategies (beam search, top-k, top-p, etc.)

But this entire workflow shows how one could begin to experiment with custom attention
mechanisms and align them to a known distribution via KL divergence.

In [None]:
############## # Code Block 1: Imports & Config ##############
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "gpt2"
SEQ_LEN = 128
BATCH_SIZE = 4
NUM_HEADS = 4
COMPRESS_RATIO = 0.25
WINDOW_SIZE = 64

############## # Code Block 2: Sparse Attention Components ##############
class CompressedGlobalAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, compress_ratio):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.compress_ratio = compress_ratio

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)
        self.compression = nn.Linear(embed_dim, 1)
        self.expansion = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        keep_num = max(1, int(T * self.compress_ratio))

        # Token compression
        importance = self.compression(x).squeeze(-1)
        _, keep_idx = torch.topk(importance, k=keep_num, dim=-1)
        x_compressed = torch.gather(x, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))

        # Projections
        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.Wk(x_compressed).view(B, keep_num, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.Wv(x_compressed).view(B, keep_num, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Attention
        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Masking
        if attention_mask is not None:
            compressed_mask = torch.gather(attention_mask, 1, keep_idx)
            attn_scores = attn_scores.masked_fill(compressed_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        output = output.permute(0, 2, 1, 3).contiguous().view(B, T, D)

        output = self.expansion(output)  # Ensure output has correct embedding dimension
        output = output[:, :x.size(1), :]

        return output

class LocalWindowAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)

    def create_window_mask(self, seq_len, device):
        mask = torch.zeros(seq_len, seq_len, device=device)
        for i in range(seq_len):
            start = max(0, i - self.window_size//2)
            end = min(seq_len, i + self.window_size//2 + 1)
            mask[i, start:end] = 1
        return mask.unsqueeze(0).unsqueeze(1)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        window_mask = self.create_window_mask(T, x.device)

        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.Wk(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.Wv(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn_scores = attn_scores.masked_fill(window_mask == 0, -1e10)

        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output.permute(0, 2, 1, 3).contiguous().view(B, T, D)

class HierarchicalSparseAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size, compress_ratio):
        super().__init__()
        self.num_heads = num_heads  # 🔹 Store num_heads
        self.local_attn = LocalWindowAttention(embed_dim, num_heads, window_size)
        self.global_attn = CompressedGlobalAttention(embed_dim, num_heads, compress_ratio)
        self.gate = nn.Sequential(
            nn.Linear(embed_dim, num_heads * 2),  # Ensure output is [batch, seq_len, num_heads * 2]
            nn.Softmax(dim=-1)
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        # Get outputs from local and global attention modules.
        local_out = self.local_attn(x, attention_mask)   # Expected shape: (B, T, D)
        global_out = self.global_attn(x, attention_mask)   # Expected shape: (B, T, D)

        B, T, D = x.size()
        head_dim = D // self.num_heads  # Ensure D is divisible by num_heads

        # Compute gating weights.
        # self.gate should output a tensor of shape (B, T, num_heads*2)
        gate_out = self.gate(x)  # Shape: (B, T, num_heads*2)
        # Reshape to (B, T, num_heads, 2) where last dim holds [local_gate, global_gate]
        gates = gate_out.view(B, T, self.num_heads, 2)
        # Unbind the last dimension into two tensors
        local_gate = gates[..., 0]  # Shape: (B, T, num_heads)
        global_gate = gates[..., 1] # Shape: (B, T, num_heads)

        # Reshape attention outputs to split heads: (B, T, num_heads, head_dim)
        local_out_heads = local_out.view(B, T, self.num_heads, head_dim)
        global_out_heads = global_out.view(B, T, self.num_heads, head_dim)

        # Ensure the gate tensors have an extra dimension for broadcasting: (B, T, num_heads, 1)
        local_gate = local_gate.unsqueeze(-1)
        global_gate = global_gate.unsqueeze(-1)

        # Element-wise multiply each head output by its corresponding gate weight
        combined = local_out_heads * local_gate + global_out_heads * global_gate
        # Reshape back to (B, T, D)
        combined = combined.view(B, T, D)
        return self.out_proj(combined)

############## # Code Block 3: Custom GPT-2 Model ##############
class SparseGPT2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([nn.ModuleDict({
            'attn': HierarchicalSparseAttention(
                config.hidden_size,
                config.num_attention_heads,
                WINDOW_SIZE,
                COMPRESS_RATIO
            ),
            'ln_1': nn.LayerNorm(config.hidden_size),
            'mlp': nn.Sequential(
                nn.Linear(config.hidden_size, 4 * config.hidden_size),
                nn.GELU(),
                nn.Linear(4 * config.hidden_size, config.hidden_size),
            ),
            'ln_2': nn.LayerNorm(config.hidden_size)
        }) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.size()
        pos_ids = torch.arange(T, device=DEVICE).unsqueeze(0)

        x = self.drop(self.wte(input_ids) + self.wpe(pos_ids))

        for block in self.h:
            attn_out = block['attn'](block['ln_1'](x), attention_mask)
            x = x + attn_out
            x = x + block['mlp'](block['ln_2'](x))

        x = self.ln_f(x)
        return self.lm_head(x)

############## # Code Block 4: Training Setup ##############
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Reference model
ref_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
ref_model.eval()

# Custom model
cust_config = GPT2Config.from_pretrained(MODEL_NAME)
cust_model = SparseGPT2(cust_config).to(DEVICE)
optimizer = torch.optim.AdamW(cust_model.parameters(), lr=1e-4)

############## # Code Block 5: Training Loop ##############
def train_step(batch):
    inputs = batch.to(DEVICE)
    attention_mask = (inputs != tokenizer.pad_token_id).float()

    with torch.no_grad():
        ref_logits = ref_model(inputs, attention_mask=attention_mask)[0]

    cust_logits = cust_model(inputs, attention_mask=attention_mask)[0]

    loss = F.kl_div(
        F.log_softmax(cust_logits, dim=-1),
        F.softmax(ref_logits, dim=-1).detach(),
        reduction='batchmean'
    )

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(cust_model.parameters(), 1.0)
    optimizer.step()
    return loss.item()

def train_epoch(loader):
    cust_model.train()
    total_loss = 0
    for batch in tqdm(loader, desc="Training"):
        loss = train_step(batch)
        total_loss += loss
    return total_loss / len(loader)

############## # Code Block 6: Generation & Evaluation ##############
def generate(model, prompt, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    for _ in range(max_length):
        logits = model(input_ids)[0]  # Extract logits if needed
        next_token = torch.argmax(logits[:, -1], dim=-1)  # next_token shape might be 0-d or 1-d
        next_token = next_token.view(-1, 1)  # Force shape to (B, 1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return tokenizer.decode(input_ids[0])

############## # Code Block 7: Dataset Preparation ##############
from datasets import load_dataset

class WikiDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.samples = []
        for text in texts:
            # Tokenize each text separately, without adding special tokens
            token_ids = tokenizer.encode(text, add_special_tokens=False)
            # Split token_ids into chunks of length seq_len
            for i in range(0, len(token_ids), seq_len):
                chunk = token_ids[i:i+seq_len]
                # Only add full chunks to avoid very short sequences
                if len(chunk) == seq_len:
                    self.samples.append(torch.tensor(chunk))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# Load a small subset (e.g., 1% of the train split) of WikiText data
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:80%]")
texts = dataset["text"]
wiki_dataset = WikiDataset(texts, tokenizer, SEQ_LEN)

# Create a DataLoader for training
train_loader = DataLoader(wiki_dataset, batch_size=BATCH_SIZE, shuffle=True)


############## # Code Block 8: Training Execution ##############
NUM_EPOCHS = 10  # Adjust number of epochs as needed

for epoch in range(NUM_EPOCHS):
    avg_loss = train_epoch(train_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Average Loss: {avg_loss:.4f}")

# After training, you can generate text using the updated custom model.
prompt = "Artificial intelligence"
print("Reference:", generate(ref_model, prompt))
print("Custom:", generate(cust_model, prompt, max_length=100))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Training: 100%|██████████| 2036/2036 [08:41<00:00,  3.91it/s]


Epoch 1/10 - Average Loss: 15.2755


Training: 100%|██████████| 2036/2036 [08:41<00:00,  3.90it/s]


Epoch 2/10 - Average Loss: 14.9442


Training: 100%|██████████| 2036/2036 [08:41<00:00,  3.90it/s]


Epoch 3/10 - Average Loss: 14.8008


Training: 100%|██████████| 2036/2036 [08:41<00:00,  3.90it/s]


Epoch 4/10 - Average Loss: 14.6991


Training: 100%|██████████| 2036/2036 [08:41<00:00,  3.91it/s]


Epoch 5/10 - Average Loss: 14.5985


Training: 100%|██████████| 2036/2036 [08:40<00:00,  3.91it/s]


Epoch 6/10 - Average Loss: 14.5011


Training: 100%|██████████| 2036/2036 [08:40<00:00,  3.91it/s]


Epoch 7/10 - Average Loss: 14.4091


Training: 100%|██████████| 2036/2036 [08:40<00:00,  3.91it/s]


Epoch 8/10 - Average Loss: 14.3247


Training: 100%|██████████| 2036/2036 [08:40<00:00,  3.91it/s]


Epoch 9/10 - Average Loss: 14.2459


Training: 100%|██████████| 2036/2036 [08:41<00:00,  3.90it/s]


Epoch 10/10 - Average Loss: 14.1772
Reference: Artificial intelligence is a new field of research that has been in the works for a while now. It is a field that has been in the works for a while now. It is a field that has been in the works for a while now. It is a
Custom: Artificial intelligence!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


#TEST

In [None]:
!pip install datasets
!pip install native-sparse-attention-pytorch

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.

In [None]:
############## # Code Block 1: Imports & Config ##############
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
from native_sparse_attention_pytorch import SparseAttention


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "gpt2"
SEQ_LEN = 128
BATCH_SIZE = 4
NUM_HEADS = 4
COMPRESS_RATIO = 0.25
WINDOW_SIZE = 64
NUM_EPOCHS = 1

# Native sparse attention configuration
SPARSE_CONFIG = {
    "dim": None,  # Will be set in the model
    "dim_head": 64,  # Dimension per head
    "heads": NUM_HEADS,
    "sliding_window_size": 2,  # Local attention window
    "compress_block_size": 4,  # Size of blocks to compress
    "selection_block_size": 4,  # Size of blocks to select from
    "num_selected_blocks": 2,  # Number of blocks to select
}


############## # Code Block 2: Sparse Attention Components ##############
class CompressedGlobalAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, compress_ratio):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.compress_ratio = compress_ratio

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)
        self.compression = nn.Linear(embed_dim, 1)
        self.expansion = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        keep_num = max(1, int(T * self.compress_ratio))

        # Token compression
        importance = self.compression(x).squeeze(-1)
        _, keep_idx = torch.topk(importance, k=keep_num, dim=-1)
        x_compressed = torch.gather(x, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))

        # Projections
        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = (
            self.Wk(x_compressed)
            .view(B, keep_num, self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )
        V = (
            self.Wv(x_compressed)
            .view(B, keep_num, self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )

        # Attention
        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Masking
        if attention_mask is not None:
            compressed_mask = torch.gather(attention_mask, 1, keep_idx)
            attn_scores = attn_scores.masked_fill(
                compressed_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        output = output.permute(0, 2, 1, 3).contiguous().view(B, T, D)

        output = self.expansion(output)  # Ensure output has correct embedding dimension
        output = output[:, : x.size(1), :]

        return output


class LocalWindowAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)

    def create_window_mask(self, seq_len, device):
        mask = torch.zeros(seq_len, seq_len, device=device)
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 1
        return mask.unsqueeze(0).unsqueeze(1)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        window_mask = self.create_window_mask(T, x.device)

        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.Wk(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.Wv(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn_scores = attn_scores.masked_fill(window_mask == 0, -1e10)

        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output.permute(0, 2, 1, 3).contiguous().view(B, T, D)


class HierarchicalSparseAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size, compress_ratio):
        super().__init__()
        self.num_heads = num_heads  # 🔹 Store num_heads
        self.local_attn = LocalWindowAttention(embed_dim, num_heads, window_size)
        self.global_attn = CompressedGlobalAttention(
            embed_dim, num_heads, compress_ratio
        )
        self.gate = nn.Sequential(
            nn.Linear(
                embed_dim, num_heads * 2
            ),  # Ensure output is [batch, seq_len, num_heads * 2]
            nn.Softmax(dim=-1),
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        # Get outputs from local and global attention modules.
        local_out = self.local_attn(x, attention_mask)  # Expected shape: (B, T, D)
        global_out = self.global_attn(x, attention_mask)  # Expected shape: (B, T, D)

        B, T, D = x.size()
        head_dim = D // self.num_heads  # Ensure D is divisible by num_heads

        # Compute gating weights.
        # self.gate should output a tensor of shape (B, T, num_heads*2)
        gate_out = self.gate(x)  # Shape: (B, T, num_heads*2)
        # Reshape to (B, T, num_heads, 2) where last dim holds [local_gate, global_gate]
        gates = gate_out.view(B, T, self.num_heads, 2)
        # Unbind the last dimension into two tensors
        local_gate = gates[..., 0]  # Shape: (B, T, num_heads)
        global_gate = gates[..., 1]  # Shape: (B, T, num_heads)

        # Reshape attention outputs to split heads: (B, T, num_heads, head_dim)
        local_out_heads = local_out.view(B, T, self.num_heads, head_dim)
        global_out_heads = global_out.view(B, T, self.num_heads, head_dim)

        # Ensure the gate tensors have an extra dimension for broadcasting: (B, T, num_heads, 1)
        local_gate = local_gate.unsqueeze(-1)
        global_gate = global_gate.unsqueeze(-1)

        # Element-wise multiply each head output by its corresponding gate weight
        combined = local_out_heads * local_gate + global_out_heads * global_gate
        # Reshape back to (B, T, D)
        combined = combined.view(B, T, D)
        return self.out_proj(combined)


############## # Code Block 3: Custom GPT-2 Model ##############
class SparseGPT2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.drop = nn.Dropout(config.embd_pdrop)

        # Create native sparse attention layer with correct parameters
        sparse_config = SPARSE_CONFIG.copy()
        sparse_config["dim"] = config.hidden_size
        self.sparse_attn = SparseAttention(**sparse_config)

        self.h = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attn": self.sparse_attn,
                        "ln_1": nn.LayerNorm(config.hidden_size),
                        "mlp": nn.Sequential(
                            nn.Linear(config.hidden_size, 4 * config.hidden_size),
                            nn.GELU(),
                            nn.Linear(4 * config.hidden_size, config.hidden_size),
                        ),
                        "ln_2": nn.LayerNorm(config.hidden_size),
                    }
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.size()
        pos_ids = torch.arange(T, device=DEVICE).unsqueeze(0)

        x = self.drop(self.wte(input_ids) + self.wpe(pos_ids))

        for block in self.h:
            # Apply layer norm before attention
            normed_x = block["ln_1"](x)

            # Apply sparse attention and handle tuple output
            attn_output = block["attn"](normed_x)
            # If attn_output is a tuple, take the first element (the main output)
            if isinstance(attn_output, tuple):
                attn_out = attn_output[0]
            else:
                attn_out = attn_output

            # Apply mask after attention if provided
            if attention_mask is not None:
                attn_out = attn_out * attention_mask.unsqueeze(-1)

            x = x + attn_out
            x = x + block["mlp"](block["ln_2"](x))

        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits  # Return only the logits, not a tuple


############## # Code Block 4: Training Setup ##############
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Reference model
ref_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
ref_model.eval()

# Custom model with native sparse attention
cust_config = GPT2Config.from_pretrained(MODEL_NAME)
cust_model = SparseGPT2(cust_config).to(DEVICE)

# Initialize with pretrained weights
pretrained_state_dict = ref_model.state_dict()
cust_model.load_state_dict(pretrained_state_dict, strict=False)

# Use a lower learning rate for fine-tuning
optimizer = torch.optim.AdamW(cust_model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)


############## # Code Block 5: Training Loop ##############
def train_step(batch):
    inputs = batch.to(DEVICE)
    attention_mask = (inputs != tokenizer.pad_token_id).float()

    with torch.no_grad():
        ref_outputs = ref_model(inputs, attention_mask=attention_mask)
        if isinstance(ref_outputs, tuple):
            ref_logits = ref_outputs[0]
        else:
            ref_logits = ref_outputs.logits  # Extract logits from the output object

    cust_outputs = cust_model(inputs, attention_mask=attention_mask)
    if isinstance(cust_outputs, tuple):
        cust_logits = cust_outputs[0]
    else:
        cust_logits = cust_outputs

    # Use KL divergence loss with temperature
    temperature = 1.0
    loss = F.kl_div(
        F.log_softmax(cust_logits / temperature, dim=-1),
        F.softmax(ref_logits / temperature, dim=-1).detach(),
        reduction="batchmean",
    ) * (temperature**2)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(cust_model.parameters(), 1.0)
    optimizer.step()
    return loss.item()


def train_epoch(loader):
    cust_model.train()
    total_loss = 0
    for batch in tqdm(loader, desc="Training"):
        loss = train_step(batch)
        total_loss += loss
    return total_loss / len(loader)


############## # Code Block 6: Generation & Evaluation ##############
def generate(model, prompt, max_length=50, temperature=0.7, top_k=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    for _ in range(max_length):
        # Get model predictions
        with torch.no_grad():
            outputs = model(input_ids)
            # Ensure we have the logits tensor
            if isinstance(outputs, tuple):
                logits = outputs[0]
            else:
                logits = outputs

            next_token_logits = logits[:, -1, :] / temperature

            # Apply top-k filtering
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)

            # Apply softmax to get probabilities
            probs = F.softmax(top_k_logits, dim=-1)

            # Sample from the filtered distribution
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices.gather(1, next_token_idx)

            # Append the next token to the sequence
            input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


############## # Code Block 7: Dataset Preparation ##############
from datasets import load_dataset


class WikiDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.samples = []
        for text in texts:
            # Tokenize each text separately, without adding special tokens
            token_ids = tokenizer.encode(text, add_special_tokens=False)
            # Split token_ids into chunks of length seq_len
            for i in range(0, len(token_ids), seq_len):
                chunk = token_ids[i : i + seq_len]
                # Only add full chunks to avoid very short sequences
                if len(chunk) == seq_len:
                    self.samples.append(torch.tensor(chunk))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Load a small subset (e.g., 1% of the train split) of WikiText data
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:80%]")
texts = dataset["text"]
wiki_dataset = WikiDataset(texts, tokenizer, SEQ_LEN)

# Create a DataLoader for training
train_loader = DataLoader(wiki_dataset, batch_size=BATCH_SIZE, shuffle=True)


############## # Code Block 8: Training Execution ##############

for epoch in range(NUM_EPOCHS):
    avg_loss = train_epoch(train_loader)
    scheduler.step()
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Average Loss: {avg_loss:.4f}")

    # Generate sample text after each epoch
    if (epoch + 1) % 2 == 0:  # Generate every 2 epochs
        print("\nGenerating sample text:")
        prompt = "Artificial intelligence"
        print("Reference:", generate(ref_model, prompt, temperature=0.7, top_k=50))
        print("Custom:", generate(cust_model, prompt, temperature=0.7, top_k=50))
        print("\n")

# Final generation comparison
prompt = "Artificial intelligence"
print("\nFinal generation comparison:")
print("Reference:", generate(ref_model, prompt, temperature=0.7, top_k=50))
print("Custom:", generate(cust_model, prompt, temperature=0.7, top_k=50))


Training: 100%|██████████| 2036/2036 [07:55<00:00,  4.28it/s]


Epoch 1/1 - Average Loss: 332.1119

Final generation comparison:


TypeError: tuple indices must be integers or slices, not tuple