In [None]:
# replaced GPT2 attention blocks with "Performer" attention blocks (uses random feature approximation instead of the standard quadratic attention)
# training parameters optimized are weights within the PerformerAttention module: query, key, value projection weights/biases + output weights/biases
  # selectively enabling gradient updates only for parameters in the attention layers

# optimization process
# run both models on the same input
# compute KL divergence between their outputs
# backpropagate through the custom model to update only the attention parameters
# repeat until convergence

In [1]:
!pip install datasets transformers

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
import math
import numpy as np

from transformers import (
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Config
)

from transformers.models.gpt2.modeling_gpt2 import GPT2Block

In [15]:
class PerformerAttention(nn.Module):
    """
    A minimal Performer-style Multi-Head Attention module using random feature maps.
    - d_model: total embedding dimension
    - num_heads: number of attention heads
    - n_features: number of random features (sometimes denoted as 'r' or 'm')
    """
    def __init__(self, d_model, num_heads=8, n_features=256):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.n_features = n_features

        # Q, K, V projections
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

        # We create random weights for the feature maps.
        # shape: (num_heads, head_dim, n_features)
        # Typically Gaussian for the 'favor+' trick.
        self.register_buffer(
            "proj_matrix",
            torch.randn(self.num_heads, self.head_dim, self.n_features)
        )
        # Optionally a random bias
        self.register_buffer(
            "proj_bias",
            2 * torch.pi * torch.rand(self.num_heads, self.n_features)
        )

        # Constants for numerical stability
        self.EPS = 1e-8
        self.EPS_NORM = 1e-5
        self.MAX_CLIP = 1e3
        self.FEATURE_SCALE = 1.0 / math.sqrt(2 * self.n_features)

    def forward(self, x, attention_mask=None):
        """
        x: (batch_size, seq_len, d_model)
        attention_mask: optional; shape (batch_size, seq_len) or broadcastable to (B, 1, L)
        1 for valid tokens, 0 for masked.
        Returns: (batch_size, seq_len, d_model)
        """
        # Replace extreme values right at the start
        # x = torch.nan_to_num(x, nan=0.0, posinf=self.MAX_CLIP, neginf=-self.MAX_CLIP)
        # x = torch.clamp(x, -self.MAX_CLIP, self.MAX_CLIP)

        B, L, D = x.shape

        # 1. Project to Q, K, V
        q = self.query_proj(x)  # (B, L, d_model)
        k = self.key_proj(x)
        v = self.value_proj(x)

        # Replace extreme values after projection
        # q = torch.nan_to_num(q, nan=0.0, posinf=self.MAX_CLIP, neginf=-self.MAX_CLIP)
        # k = torch.nan_to_num(k, nan=0.0, posinf=self.MAX_CLIP, neginf=-self.MAX_CLIP)
        # v = torch.nan_to_num(v, nan=0.0, posinf=self.MAX_CLIP, neginf=-self.MAX_CLIP)

        # 2. Reshape into multiple heads
        # (B, L, num_heads, head_dim) => then transpose to (B, num_heads, L, head_dim)
        q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # # Scale Q and K if norms are too large (using more stable approach)
        # q_norm = torch.norm(q, dim=-1, keepdim=True).clamp_min(self.EPS_NORM)
        # k_norm = torch.norm(k, dim=-1, keepdim=True).clamp_min(self.EPS_NORM)

        # max_norm = 5.0
        # q_norm_factor = torch.where(q_norm > max_norm, max_norm / q_norm, torch.ones_like(q_norm))
        # k_norm_factor = torch.where(k_norm > max_norm, max_norm / k_norm, torch.ones_like(k_norm))

        # q = q * q_norm_factor
        # k = k * k_norm_factor

        q = q / math.sqrt(self.head_dim)
        k = k / math.sqrt(self.head_dim)

        # 3. Convert Q, K to random feature space
        q_features = self.random_feature_map(q)  # (B, H, L, n_features)
        k_features = self.random_feature_map(k)  # (B, H, L, n_features)

        # Handle attention mask
        if attention_mask is not None:
            # Simplify to just apply the mask to k_features
            # attention_mask = attention_mask.unsqueeze(1).unsqueeze(-1)  # (B, 1, L, 1)
            # k_features = k_features * attention_mask

            # Handle attention mask based on its dimensionality
            if attention_mask.dim() == 2:  # (B, L)
                # Convert to 3D mask
                attention_mask = attention_mask.unsqueeze(1)  # (B, 1, L)
            elif attention_mask.dim() == 4:  # (B, 1, L, L)
                # Convert to (B, 1, L) by checking if there are any valid keys for each query
                attention_mask = (attention_mask.sum(dim=-1) > 0).float()  # (B, 1, L)

            # Apply mask to k_features and v (expanded for broadcasting)
            mask_4d = attention_mask.unsqueeze(-1)  # (B, 1, L, 1)
            k_features = k_features * mask_4d

        # 4. Compute linear attention
        # Key step: k_features.sum(dim=2) computes the normalization factor
        kv = torch.einsum("bhlf,bhld->bhfd", k_features, v)
        k_sum = k_features.sum(dim=2)  # (B, H, n_features)

        # 5. Apply single small epsilon in the denominator
        k_sum = k_sum + 1e-6

        # 6. Compute attention output
        numerator = torch.einsum("bhlf,bhfd->bhld", q_features, kv)
        denominator = torch.einsum("bhlf,bhf->bhl", q_features, k_sum).unsqueeze(-1)
        out = numerator / denominator  # (B, H, L, head_dim)

        # Final safety check
        # out = torch.nan_to_num(out, nan=0.0, posinf=self.MAX_CLIP, neginf=-self.MAX_CLIP)
        # out = torch.clamp(out, -self.MAX_CLIP, self.MAX_CLIP)

        # 8. Recombine heads
        out = out.transpose(1, 2).contiguous().view(B, L, self.num_heads * self.head_dim)
        out = self.out_proj(out)

        # Final check - replace any remaining NaNs or Infs
        # out = torch.nan_to_num(out, nan=0.0, posinf=self.MAX_CLIP, neginf=-self.MAX_CLIP)

        return out

    def random_feature_map(self, x):
      """
      FAVOR+ implementation that properly approximates softmax attention
      x: (B, H, L, head_dim)
      """
      # 1. Normalize x for numerical stability
      x_norm = x / math.sqrt(self.head_dim)

      # 2. Project normalized vectors (using orthogonal random features if possible)
      proj = torch.einsum("bhld,hdf->bhlf", x_norm, self.proj_matrix)

      # 3. Add random bias
      proj = proj + self.proj_bias.unsqueeze(0).unsqueeze(2)

      # 4. Apply non-linearity: exp(x) / sqrt(m) for positive features
      # This properly approximates softmax attention
      exp_proj = torch.exp(proj)

      # 5. Create concatenated features
      # The 1/sqrt(m) scaling is crucial for approximation accuracy
      out = exp_proj / math.sqrt(self.n_features)

      return out

In [16]:
# ------------------------------
# Custom GPT-2 Block with our CustomLearnableAttention
# ------------------------------
class CustomGPT2Block(GPT2Block):
    """
    Subclass the original GPT2Block to replace the attention layer with our custom one.
    We also copy the original Q, K, V, and output projection weights so that the custom
    attention begins with a similar behavior to the baseline.
    """
    def __init__(self, config):
        super().__init__(config)
        self.attn = self._create_custom_attention(config) # replace default attention w/ custom attention

    def _create_custom_attention(self, config):
        original_attn = self.attn

        custom_attn = PerformerAttention(
          d_model=config.n_embd,
          num_heads=config.n_head,
          n_features=256
        )

        # Instead of copying weights directly, initialize them with a custom scheme
        with torch.no_grad():
            # Scale the random projections properly for good approximation
            nn.init.orthogonal_(custom_attn.proj_matrix)

            # Get original weights (but apply a scale factor)
            old_w = original_attn.c_attn.weight.T
            old_b = original_attn.c_attn.bias

            # Split Q, K, V
            q_w, k_w, v_w = old_w.chunk(3, dim=0)
            q_b, k_b, v_b = old_b.chunk(3, dim=0)

            # Apply a scaling factor to account for the different mechanism
            scale = 1.0 / math.sqrt(config.n_head)

            # Copy with scaling
            custom_attn.query_proj.weight.copy_(q_w * scale)
            custom_attn.query_proj.bias.copy_(q_b * scale)
            custom_attn.key_proj.weight.copy_(k_w * scale)
            custom_attn.key_proj.bias.copy_(k_b * scale)

            # Value projections can be copied directly
            custom_attn.value_proj.weight.copy_(v_w)
            custom_attn.value_proj.bias.copy_(v_b)

            # Output projection can be copied directly
            custom_attn.out_proj.load_state_dict(original_attn.c_proj.state_dict())

        return custom_attn

        # # We can copy part of the old c_attn weights (which was a big linear that had Q,K,V).
        # with torch.no_grad():
        #     old_w = original_attn.c_attn.weight # 768, 2304
        #     old_b = original_attn.c_attn.bias # 2304

        #     old_w = old_w.T

        #     # old_w has shape (3 * d_model, d_model), corresponding to Q, K, V stacked.
        #     q_w, k_w, v_w = old_w.chunk(3, dim=0)
        #     q_b, k_b, v_b = old_b.chunk(3, dim=0)

        #     # Copy into the new Q, K, V
        #     custom_attn.query_proj.weight.copy_(q_w)
        #     custom_attn.query_proj.bias.copy_(q_b)
        #     custom_attn.key_proj.weight.copy_(k_w)
        #     custom_attn.key_proj.bias.copy_(k_b)
        #     custom_attn.value_proj.weight.copy_(v_w)
        #     custom_attn.value_proj.bias.copy_(v_b)

        #     # c_proj is the final linear after attention; we can copy that to our out_proj
        #     custom_attn.out_proj.load_state_dict(original_attn.c_proj.state_dict())

        # return custom_attn

    def forward(self, hidden_states, layer_past=None, attention_mask=None,
                head_mask=None, use_cache=False, output_attentions=False, **kwargs):
        # Standard GPT-2 block forward pass with two residual connections.
        attn_input = self.ln_1(hidden_states)
        attn_output = self.attn(attn_input, attention_mask=attention_mask)
        hidden_states = hidden_states + attn_output

        mlp_input = self.ln_2(hidden_states)
        mlp_output = self.mlp(mlp_input)
        hidden_states = hidden_states + mlp_output

        # We do not use caching in our custom model.
        return (hidden_states, None, None)

In [17]:
# ------------------------------
# Custom GPT-2 LM Model that Uses the Custom Blocks
# ------------------------------
class CustomGPT2LM(GPT2LMHeadModel):
    """
    This custom language model replaces each Transformer block with our custom block
    (which uses the learnable attention mask). In addition, we share the word and
    positional embeddings, as well as the LM head, with the reference model.
    """
    def __init__(self, config, reference_model):
        super().__init__(config)
        # Disable caching for generation.
        self.config.use_cache = False

        # Share embeddings and the language model head from the reference model.
        self.transformer.wte = reference_model.transformer.wte
        self.transformer.wpe = reference_model.transformer.wpe
        self.lm_head = reference_model.lm_head

        # Replace all Transformer blocks with our custom blocks.
        self.transformer.h = nn.ModuleList([CustomGPT2Block(config) for _ in range(config.n_layer)])
        # Load weights from the reference model (allowing missing keys since our modules are modified).
        self.load_state_dict(reference_model.state_dict(), strict=False)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # If a 4D mask is passed in (e.g. from GPT‑2 generation), replace it with a 2D mask.
        if attention_mask is not None and attention_mask.dim() == 4:
            attention_mask = (input_ids != tokenizer.pad_token_id).long()
        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

In [18]:
# ------------------------------
# Prepare the Dataset and Dataloader using wikitext-2
# ------------------------------
# We create a simple Dataset that tokenizes the raw texts from wikitext.
class WikiTextDataset(Dataset):
    def __init__(self, tokenizer, texts, max_length=32):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = texts
        self.encodings = []

        for txt in texts:
            # Tokenize each text and pad/truncate to max_length.
            enc = tokenizer.encode_plus(
                txt,
                max_length=self.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            )
            self.encodings.append(enc)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.encodings[idx]
        input_ids = enc["input_ids"].squeeze(0)         # shape: (max_length,)
        attention_mask = enc["attention_mask"].squeeze(0)   # shape: (max_length,)
        return input_ids, attention_mask

In [19]:
# ------------------------------
# Hyperparameters and settings
# ------------------------------
MODEL_NAME = "gpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
LR = 6e-5
MAX_SEQ_LENGTH = 128      # maximum sequence length for training examples
NUM_EPOCHS = 50           # For demonstration we use few epochs (use more in practice)
SHOW_SAMPLE_OUTPUTS = True   # Whether to show sample text generations for comparison

print(f"Using device: {DEVICE}")

Using device: cuda


In [20]:
# ------------------------------
# Load the baseline (reference) GPT-2 model and tokenizer.
# This model is used only for generating target logits.
# ------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Ensure the tokenizer has a pad token (set to the EOS token if missing)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [21]:
# Load the wikitext-2 dataset using the Hugging Face datasets library.
from datasets import load_dataset
wikitext_data = load_dataset("wikitext", "wikitext-2-raw-v1")
# For demonstration, we take a small subset of the training split.
train_texts = [txt for txt in wikitext_data["train"]["text"] if len(txt) > 50][:1000]

# Create our dataset and dataloader.
dataset = WikiTextDataset(tokenizer, train_texts, max_length=MAX_SEQ_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
print("Dataset and dataloader ready for training!")

Dataset and dataloader ready for training!


In [22]:
# ------------------------------
# KL-Divergence Loss Function
# ------------------------------
def kl_divergence_loss(logits_custom, logits_ref, mask):
    """
    Compute a token-wise KL-divergence between the output distributions of the custom model
    and the reference model. The loss is averaged over the active (non-padded) tokens.

    logits_custom: (B, L, V)
    logits_ref:    (B, L, V)
    mask:          (B, L) with 1 for active tokens and 0 for padding.
    """
    log_probs_custom = F.log_softmax(logits_custom, dim=-1)
    # Detach the reference probabilities to avoid backprop into the reference model.
    probs_ref = F.softmax(logits_ref.detach(), dim=-1)
    # Compute the per-token KL divergence.
    kl = (probs_ref * (probs_ref.log() - log_probs_custom)).sum(-1)  # shape: (B, L)
    # Average the loss over the active tokens.
    active_tokens = mask.sum()
    return (kl * mask).sum() / (active_tokens + 1e-8)

In [23]:
# ------------------------------
# Initialize the Custom Model
# ------------------------------
print("Initializing custom model with learnable attention masks...")
# We re-load the reference model so that the custom model can copy its embeddings and head.
reference_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
reference_model.eval()
print(f"Reference model '{MODEL_NAME}' loaded successfully!")

custom_config = GPT2Config.from_pretrained(MODEL_NAME)
custom_model = CustomGPT2LM(custom_config, reference_model).to(DEVICE)

# Freeze parameters that are not part of the custom attention (or other parts we want fixed).
for name, param in custom_model.named_parameters():
    # Here we unfreeze only the parameters that include "attn" in their name.
    if "attn" in name:
        param.requires_grad_(True)
    else:
        param.requires_grad_(False)

# Set the custom model to train mode.
custom_model.train()

# Create the optimizer to update only parameters that require gradients.
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, custom_model.parameters()), lr=LR)

Initializing custom model with learnable attention masks...
Reference model 'gpt2' loaded successfully!


In [24]:
# ------------------------------
# Training Loop
# ------------------------------
print("Starting training loop...")
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for input_ids, attn_mask in dataloader:
        input_ids, attn_mask = input_ids.to(DEVICE), attn_mask.to(DEVICE)

        with torch.no_grad():
            ref_out = reference_model(input_ids=input_ids, attention_mask=attn_mask)
        ref_logits = ref_out.logits

        out_custom = custom_model(input_ids=input_ids, attention_mask=attn_mask)
        custom_logits = out_custom.logits

        loss = kl_divergence_loss(custom_logits, ref_logits, attn_mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} | Avg KL Loss: {total_loss / len(dataloader):.4f}")

Starting training loop...
Epoch 1 | Avg KL Loss: 3.1287
Epoch 2 | Avg KL Loss: 2.9838
Epoch 3 | Avg KL Loss: 2.8739
Epoch 4 | Avg KL Loss: 2.7766
Epoch 5 | Avg KL Loss: 2.7217
Epoch 6 | Avg KL Loss: 2.6723
Epoch 7 | Avg KL Loss: 2.6355
Epoch 8 | Avg KL Loss: 2.6146
Epoch 9 | Avg KL Loss: 2.5496
Epoch 10 | Avg KL Loss: 2.5313
Epoch 11 | Avg KL Loss: 2.5233
Epoch 12 | Avg KL Loss: 2.5008
Epoch 13 | Avg KL Loss: 2.5190
Epoch 14 | Avg KL Loss: 2.4935
Epoch 15 | Avg KL Loss: 2.4824
Epoch 16 | Avg KL Loss: 2.4537
Epoch 17 | Avg KL Loss: 2.4543
Epoch 18 | Avg KL Loss: 2.4342
Epoch 19 | Avg KL Loss: 2.4141
Epoch 20 | Avg KL Loss: 2.4021
Epoch 21 | Avg KL Loss: 2.4035
Epoch 22 | Avg KL Loss: 2.3994
Epoch 23 | Avg KL Loss: 2.3822
Epoch 24 | Avg KL Loss: 2.3959
Epoch 25 | Avg KL Loss: 2.3727
Epoch 26 | Avg KL Loss: 2.3844
Epoch 27 | Avg KL Loss: 2.3513
Epoch 28 | Avg KL Loss: 2.3367
Epoch 29 | Avg KL Loss: 2.3392
Epoch 30 | Avg KL Loss: 2.3456
Epoch 31 | Avg KL Loss: 2.3425
Epoch 32 | Avg KL Loss

In [25]:
# ------------------------------
# Text Generation Comparison
# ------------------------------
def generate_text(model, prompt, temperature=0.7, top_k=50, max_length=50):
    """
    Generate text using the provided model and prompt.
    Note: We disable caching (use_cache=False) for our custom model.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(
        inputs.input_ids,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2,
        no_repeat_ngram_size=2,
        use_cache=False  # Disable caching for compatibility with custom attention.
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# if SHOW_SAMPLE_OUTPUTS:
sample_prompts = [
    "Hello, my name is",
    "The meaning of life is",
    "In a shocking turn of events,",
    "The future of artificial intelligence"
]

longer_sample_prompts = [
    "As the sun set behind the towering mountains, the weary traveler finally caught sight of the distant village, its warm lights flickering like tiny stars",
    "In the year 2157, humanity had finally perfected interstellar travel. The first colony ship, brimming with hope and thousands of eager settlers",
    "The scientist stared at the glowing vial on the laboratory table, her fingers trembling with anticipation. After years of relentless experimentation",
    "The detective pushed open the heavy oak door, stepping into a room thick with the scent of old books and something more sinister—fear"
]

for prompt in sample_prompts:
    ref_text = generate_text(reference_model, prompt, temperature=0.8)
    custom_text = generate_text(custom_model, prompt, temperature=0.8)
    print(f"\nPrompt: {prompt}")
    print(f"Reference: {ref_text}")
    print(f"Custom:    {custom_text}")
    print("-" * 80)

for prompt in longer_sample_prompts:
    ref_text = generate_text(reference_model, prompt, temperature=0.8)
    custom_text = generate_text(custom_model, prompt, temperature=0.8)
    print(f"\nPrompt: {prompt}")
    print(f"Reference: {ref_text}")
    print(f"Custom:    {custom_text}")
    print("-" * 80)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Prompt: Hello, my name is
Reference: Hello, my name is Tom and I'm a high school student. If you're unfamiliar with me then please take note of what's happening to us," he said by telephone from his home in Ohio early last month after the class was abruptly canceled for
Custom:    Hello, my name is " when the other major of a lot if she did not so he asked about it's father would be seen as his own. They're actually being shown in an incident with those people was still can't give me to bring
--------------------------------------------------------------------------------

Prompt: The meaning of life is
Reference: The meaning of life is that God has created you to live according as he desires. (Deut 4:15.)
This means not only about living in a world without sin or death, but also with the right kind and loving nature we seek
Custom:    The meaning of life is also the use to be, where you can find out. The reason for an individual human beings and one who are a major problems with no o