In [None]:
# Install dependencies (if running in a notebook)
!pip install datasets transformers

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np

# Import Hugging Face modules
from transformers import (
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Config
)
# We also import the original GPT-2 block definition to subclass it.
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

In [None]:
class PerformerAttention(nn.Module):
    """
    A minimal Performer-style multi-head attention module using positive random features.

    Softmax attention is replaced by kernelised (linear) attention:

        out_i = phi(q_i)^T [sum_j phi(k_j) v_j^T] / (phi(q_i)^T sum_j phi(k_j) + eps)

    where j runs over all valid keys, or over valid keys with j <= i when ``causal=True``.

    - d_model: total embedding dimension (must be divisible by num_heads)
    - num_heads: number of attention heads
    - n_features: number of random features per head (sometimes denoted 'r' or 'm')
    - causal: if True, each query only aggregates keys at positions <= its own
    - debug: if True, print value-range diagnostics on every forward pass
      (each diagnostic forces a device sync, so keep it off during real training)
    """

    # Added to the normaliser so that a query whose keys are all masked divides by eps, not 0.
    DENOM_EPS = 1e-4

    def __init__(self, d_model, num_heads=8, n_features=64, causal=False, debug=False):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.n_features = n_features
        self.causal = causal
        self.debug = debug

        # Q, K, V projections
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj   = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

        # Fixed (non-trainable) Gaussian projection for the feature map.
        # shape: (num_heads, head_dim, n_features)
        self.register_buffer(
            "proj_matrix",
            torch.randn(self.num_heads, self.head_dim, self.n_features)
        )
        # Fixed random bias, shape: (num_heads, n_features)
        self.register_buffer(
            "proj_bias",
            2 * torch.pi * torch.rand(self.num_heads, self.n_features)
        )

    def forward(self, x, attention_mask=None):
        """
        x: (batch_size, seq_len, d_model)
        attention_mask: optional key mask, never modified in place. Accepted forms:
              - (B, L) or (B, 1, L): 1 / True for valid tokens, 0 / False for masked;
              - (B, 1, Q, L) boolean or 0/1 mask with the same meaning;
              - (B, 1, Q, L) additive float mask (0 for valid, negative for masked),
                which is the form Hugging Face GPT-2 hands to its blocks.
              Only per-key masking can be expressed in linear attention, so for
              4-D masks the last query row is used as the key mask.
        Returns: (batch_size, seq_len, d_model)
        """
        B, L, _ = x.shape

        # 1. Project to Q, K, V
        q = self.query_proj(x)  # (B, L, d_model)
        k = self.key_proj(x)
        v = self.value_proj(x)

        if self.debug:
            with torch.no_grad():
                print("Input x range:", x.min().item(), x.max().item())
                print("Average Q norm:", q.norm(dim=-1).mean().item())
                print("Q range:", q.min().item(), q.max().item())
                print("K range:", k.min().item(), k.max().item())

        # 2. Reshape into multiple heads: (B, L, d_model) -> (B, num_heads, L, head_dim)
        q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # 3. Convert Q, K to random feature space: (B, num_heads, L, n_features)
        q_features = self.random_feature_map(q, self.proj_matrix, self.proj_bias)
        k_features = self.random_feature_map(k, self.proj_matrix, self.proj_bias)

        # Masked keys are removed by zeroing their features and values, so they
        # contribute to neither the numerator nor the normaliser.
        if attention_mask is not None:
            key_keep = self._key_padding_mask(attention_mask, L)        # (B, L) bool
            keep = key_keep[:, None, :, None].to(k_features.dtype)      # (B, 1, L, 1)
            k_features = k_features * keep
            v = v * keep

        if self.causal:
            # Prefix sums over the sequence axis give every query the statistics of
            # keys 0..i only. This materialises a (B, H, L, F, head_dim) tensor, which
            # is simple but memory-hungry for long sequences.
            kv_prefix = torch.einsum("bhlf,bhld->bhlfd", k_features, v).cumsum(dim=2)
            numerator = torch.einsum("bhlf,bhlfd->bhld", q_features, kv_prefix)
            denominator = torch.einsum("bhlf,bhlf->bhl", q_features, k_features.cumsum(dim=2))
        else:
            # 4. "KV" summary: sum over time of k_features(t) v(t)^T => (B, H, F, head_dim)
            kv = torch.einsum("bhlf,bhld->bhfd", k_features, v)
            # 5. Numerator: q_features @ KV => (B, H, L, head_dim)
            numerator = torch.einsum("bhlf,bhfd->bhld", q_features, kv)
            # 6. Normaliser: q_features @ sum over time of k_features => (B, H, L)
            denominator = torch.einsum("bhlf,bhf->bhl", q_features, k_features.sum(dim=2))

        denominator = denominator.unsqueeze(-1) + self.DENOM_EPS  # (B, H, L, 1)

        if self.debug:
            with torch.no_grad():
                print("denominator min:", denominator.min().item(),
                      "denominator max:", denominator.max().item())
                if torch.isnan(denominator).any():
                    print("NAN in denominator!")

        # 7. Final attention output
        out = numerator / denominator  # (B, H, L, head_dim)

        # Replace any remaining inf or NaN with zeros
        out = torch.where(torch.isnan(out) | torch.isinf(out), torch.zeros_like(out), out)

        # 8. Recombine heads
        out = out.transpose(1, 2).contiguous().view(B, L, self.num_heads * self.head_dim)
        return self.out_proj(out)

    @staticmethod
    def _key_padding_mask(attention_mask, seq_len):
        """
        Normalise any supported mask layout into a fresh (B, L) boolean tensor that
        is True where a key may be attended. The input tensor is left untouched.
        """
        mask = attention_mask
        additive = False
        if mask.dim() == 4:
            # A 4-D float mask without any positive entry can only be additive
            # (0 = keep, negative = drop); a 0/1 mask always has positive entries
            # unless everything is masked.
            additive = mask.is_floating_point() and not bool((mask > 0).any())
            # The last query row sees every non-padded key even when a causal
            # mask has been folded into the tensor.
            mask = mask[:, 0, -1, :]
        elif mask.dim() == 3:
            mask = mask[:, -1, :]
        elif mask.dim() != 2:
            raise ValueError(
                f"attention_mask must have 2, 3 or 4 dimensions, got {mask.dim()}"
            )

        if mask.shape[-1] != seq_len:
            raise ValueError(
                f"attention_mask covers {mask.shape[-1]} keys but the input has {seq_len} positions"
            )

        if mask.dtype == torch.bool:
            keep = mask.clone()
        elif additive:
            keep = mask == 0
        else:
            keep = mask != 0

        # A sequence with no valid key at all falls back to its first position so
        # that the aggregate is not empty.
        if keep.shape[-1] > 0:
            no_valid_key = ~keep.any(dim=-1)
            keep[:, 0] = keep[:, 0] | no_valid_key
        return keep

    def random_feature_map(self, x, W, b):
        """
        Map per-head queries/keys to strictly positive random features.

        x: (B, H, L, head_dim)
        W: (H, head_dim, n_features)
        b: (H, n_features)
        Output: (B, H, L, n_features)

        phi(x) = (elu(x W + b) + 1) / sqrt(n_features)

        elu(.) + 1 is positive, which keeps the attention normaliser positive.
        """
        # (B,H,L,D) x (H,D,F) => (B,H,L,F), one projection matrix per head.
        proj = torch.einsum("bhld,hdf->bhlf", x, W)
        proj = proj + b.unsqueeze(0).unsqueeze(2)  # broadcast bias over batch and time

        out = nn.functional.elu(proj, alpha=1.0) + 1.0
        out = out * (1.0 / (self.n_features ** 0.5))
        return out

In [None]:
# ------------------------------
# Custom GPT-2 Block with our PerformerAttention
# ------------------------------
class CustomGPT2Block(GPT2Block):
    """
    GPT2Block whose attention layer is replaced by PerformerAttention.

    The Q, K, V and output projection weights of the attention module built by
    GPT2Block.__init__ are copied into the replacement so both start from the
    same projections. At construction time those are whatever weights the parent
    constructor initialised.
    """
    def __init__(self, config):
        super().__init__(config)
        # Replace the default attention with our custom attention.
        self.attn = self._create_custom_attention(config)

    def _create_custom_attention(self, config):
        """Build a PerformerAttention initialised from this block's GPT2Attention."""
        original_attn = self.attn  # the old GPT2Attention instance

        # GPT-2 is an autoregressive LM, so the replacement attention is requested
        # in causal mode.
        custom_attn = PerformerAttention(
            d_model=config.n_embd,
            num_heads=config.n_head,
            n_features=64,
            causal=True
        )

        with torch.no_grad():
            # c_attn / c_proj are Conv1D layers whose weight is stored as
            # (in_features, out_features); nn.Linear stores (out_features, in_features).
            # Every weight therefore has to be transposed before it is copied.
            qkv_w = original_attn.c_attn.weight.T  # (3 * d_model, d_model): Q, K, V stacked
            qkv_b = original_attn.c_attn.bias      # (3 * d_model,)

            q_w, k_w, v_w = qkv_w.chunk(3, dim=0)
            q_b, k_b, v_b = qkv_b.chunk(3, dim=0)

            custom_attn.query_proj.weight.copy_(q_w)
            custom_attn.query_proj.bias.copy_(q_b)
            custom_attn.key_proj.weight.copy_(k_w)
            custom_attn.key_proj.bias.copy_(k_b)
            custom_attn.value_proj.weight.copy_(v_w)
            custom_attn.value_proj.bias.copy_(v_b)

            # c_proj is the final linear after attention. Loading its state dict
            # directly would silently install the transposed matrix, because the
            # weight is square and the shape check cannot catch the layout mismatch.
            custom_attn.out_proj.weight.copy_(original_attn.c_proj.weight.T)
            custom_attn.out_proj.bias.copy_(original_attn.c_proj.bias)

        return custom_attn

    def forward(self, hidden_states, layer_past=None, attention_mask=None,
                head_mask=None, use_cache=False, output_attentions=False, **kwargs):
        """
        Pre-LayerNorm GPT-2 block: attention and MLP, each wrapped in a residual connection.

        Only hidden_states and attention_mask are used; layer_past, head_mask,
        use_cache, output_attentions and any extra keyword arguments are accepted
        for call compatibility and ignored.

        Returns a 3-tuple (hidden_states, None, None).
        """
        attn_input = self.ln_1(hidden_states)
        attn_output = self.attn(attn_input, attention_mask=attention_mask)
        hidden_states = hidden_states + attn_output

        mlp_input = self.ln_2(hidden_states)
        mlp_output = self.mlp(mlp_input)
        hidden_states = hidden_states + mlp_output

        # We do not use caching in our custom model.
        return (hidden_states, None, None)

In [None]:
# ------------------------------
# Custom GPT-2 LM Model that Uses the Custom Blocks
# ------------------------------
class CustomGPT2LM(GPT2LMHeadModel):
    """
    GPT-2 language model whose Transformer blocks are replaced by CustomGPT2Block.

    The word embeddings, positional embeddings and LM head are the same module
    objects as in the reference model (shared, not copied). All remaining weights
    are copied from the reference model.

    Args:
        config: GPT2Config used to build the model.
        reference_model: GPT2LMHeadModel providing shared modules and weights.
        pad_token_id: id treated as padding when a 4-D mask has to be rebuilt in
            forward(). Defaults to config.pad_token_id, then config.eos_token_id.
    """
    def __init__(self, config, reference_model, pad_token_id=None):
        super().__init__(config)
        # Disable caching for generation.
        self.config.use_cache = False

        # Resolve the padding id once so forward() does not depend on a
        # notebook-level tokenizer variable.
        if pad_token_id is None:
            pad_token_id = config.pad_token_id
        if pad_token_id is None:
            pad_token_id = config.eos_token_id
        self._pad_token_id = pad_token_id

        # Share embeddings and the language model head from the reference model.
        self.transformer.wte = reference_model.transformer.wte
        self.transformer.wpe = reference_model.transformer.wpe
        self.lm_head = reference_model.lm_head

        # Replace all Transformer blocks with our custom blocks.
        self.transformer.h = nn.ModuleList([CustomGPT2Block(config) for _ in range(config.n_layer)])
        # Load every weight whose name matches (layer norms, MLPs, embeddings, head).
        self.load_state_dict(reference_model.state_dict(), strict=False)

        # The custom attention uses different parameter names (query_proj, ...), so
        # the non-strict load above skips it. Copy those weights explicitly;
        # otherwise the attention layers would keep their random initialisation.
        with torch.no_grad():
            for custom_block, ref_block in zip(self.transformer.h, reference_model.transformer.h):
                self._copy_attention_weights(ref_block.attn, custom_block.attn)

    @staticmethod
    def _copy_attention_weights(ref_attn, custom_attn):
        """Copy GPT-2 attention projections (Conv1D, weight stored (in, out)) into nn.Linear layers."""
        q_w, k_w, v_w = ref_attn.c_attn.weight.T.chunk(3, dim=0)
        q_b, k_b, v_b = ref_attn.c_attn.bias.chunk(3, dim=0)

        custom_attn.query_proj.weight.copy_(q_w)
        custom_attn.query_proj.bias.copy_(q_b)
        custom_attn.key_proj.weight.copy_(k_w)
        custom_attn.key_proj.bias.copy_(k_b)
        custom_attn.value_proj.weight.copy_(v_w)
        custom_attn.value_proj.bias.copy_(v_b)

        custom_attn.out_proj.weight.copy_(ref_attn.c_proj.weight.T)
        custom_attn.out_proj.bias.copy_(ref_attn.c_proj.bias)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """
        Run the LM. A 4-D attention mask is replaced by a 2-D padding mask rebuilt
        from input_ids; any other mask is passed through unchanged.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            if self._pad_token_id is None:
                # No padding id known: treat every position as valid.
                attention_mask = None
            else:
                attention_mask = (input_ids != self._pad_token_id).long()
        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

In [None]:
# ------------------------------
# Prepare the Dataset and Dataloader using wikitext-2
# ------------------------------
# We create a simple Dataset that tokenizes the raw texts from wikitext.
class WikiTextDataset(Dataset):
    """
    Map-style dataset that tokenizes every text up front.

    Each item is an (input_ids, attention_mask) pair of 1-D tensors of length
    max_length, produced by tokenizer.encode_plus with truncation and
    "max_length" padding.
    """

    def __init__(self, tokenizer, texts, max_length=32):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = texts
        # Eagerly encode all texts; encode_plus returns batched (1, max_length) tensors.
        self.encodings = [
            tokenizer.encode_plus(
                sample,
                max_length=self.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            )
            for sample in texts
        ]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.encodings[idx]
        # Drop the leading batch dimension of size 1 => shape (max_length,)
        return (
            encoded["input_ids"].squeeze(0),
            encoded["attention_mask"].squeeze(0),
        )

In [None]:
# ------------------------------
# Hyperparameters and settings
# ------------------------------
MODEL_NAME = "gpt2"       # Hugging Face model id used for the tokenizer, config and reference model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"   # prefer the GPU when one is visible
BATCH_SIZE = 4            # examples per DataLoader batch
LR = 6e-5                 # learning rate passed to AdamW
MAX_SEQ_LENGTH = 128      # maximum sequence length for training examples
NUM_EPOCHS = 3           # For demonstration we use few epochs (use more in practice)
SHOW_SAMPLE_OUTPUTS = True   # Whether to show sample text generations for comparison
GRAD_CLIP = 1.0           # presumably the max gradient norm for clipping; confirm against the training loop
L1_COEFF = 1e-5          # Coefficient for L1 penalty on attention mask weights
                         # NOTE(review): no cell visible in this notebook reads L1_COEFF; it looks like a
                         # leftover from a learnable-mask variant of the experiment. Confirm before removing.

print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# ------------------------------
# Load the GPT-2 tokenizer.
# The baseline (reference) model itself, which is used only for generating
# target logits, is loaded in a later cell.
# ------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Ensure the tokenizer has a pad token (set to the EOS token if missing);
# WikiTextDataset pads every example with padding="max_length".
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Fetch wikitext-2 (raw variant) through the Hugging Face datasets library.
from datasets import load_dataset
wikitext_data = load_dataset("wikitext", "wikitext-2-raw-v1")
# Keep only lines longer than 50 characters, then take the first 1000 of them
# as a small demonstration subset of the training split.
train_texts = list(
    filter(lambda txt: len(txt) > 50, wikitext_data["train"]["text"])
)[:1000]

# Tokenize the subset and wrap it in a shuffling DataLoader.
dataset = WikiTextDataset(tokenizer, train_texts, max_length=MAX_SEQ_LENGTH)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
print("Dataset and dataloader ready for training!")

Dataset and dataloader ready for training!


In [None]:
# ------------------------------
# KL-Divergence Loss Function
# ------------------------------
def kl_divergence_loss(logits_custom, logits_ref, mask):
    """
    Compute a token-wise KL-divergence KL(ref || custom) between the output distributions
    of the reference model and the custom model. The loss is averaged over the active
    (non-padded) tokens.

    logits_custom: (B, L, V)
    logits_ref:    (B, L, V)
    mask:          (B, L) with 1 for active tokens and 0 for padding.

    Returns a scalar tensor. Gradients flow only into logits_custom.
    """
    log_probs_custom = F.log_softmax(logits_custom, dim=-1)
    # Detach the reference to avoid backprop into the reference model, and stay in
    # log space: softmax(...).log() yields -inf wherever a probability underflows
    # to 0, and 0 * -inf is NaN. With log_softmax the log-probability stays finite
    # and the underflowed term contributes exactly 0.
    log_probs_ref = F.log_softmax(logits_ref.detach(), dim=-1)
    probs_ref = log_probs_ref.exp()
    # Per-token KL divergence.
    kl = (probs_ref * (log_probs_ref - log_probs_custom)).sum(-1)  # shape: (B, L)
    # Average the loss over the active tokens; the epsilon guards an all-padding batch.
    active_tokens = mask.sum()
    return (kl * mask).sum() / (active_tokens + 1e-8)

In [None]:
# ------------------------------
# Initialize the Custom Model
# ------------------------------
print("Initializing custom model with learnable attention masks...")
# The reference model supplies the shared embeddings / LM head and, during
# training, the target logits. It stays in eval mode throughout.
reference_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
reference_model.eval()
print(f"Reference model '{MODEL_NAME}' loaded successfully!")

custom_config = GPT2Config.from_pretrained(MODEL_NAME)
custom_model = CustomGPT2LM(custom_config, reference_model).to(DEVICE)

# Train only parameters whose name contains "attn"; freeze everything else.
for name, param in custom_model.named_parameters():
    param.requires_grad_("attn" in name)

# Put the custom model into train mode.
custom_model.train()

# The optimizer only receives the parameters left trainable above.
optimizer = optim.AdamW(
    (p for p in custom_model.parameters() if p.requires_grad),
    lr=LR,
)

Initializing custom model with learnable attention masks...


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# ------------------------------
# Training Loop
# ------------------------------
print("Starting training loop...")
# Distil the reference model into the custom model: for every batch the frozen
# reference produces target logits and the custom model is trained to match them
# under a masked KL divergence.
trainable_params = [p for p in custom_model.parameters() if p.requires_grad]
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for input_ids, attn_mask in dataloader:
        input_ids, attn_mask = input_ids.to(DEVICE), attn_mask.to(DEVICE)

        with torch.no_grad():
            ref_out = reference_model(input_ids=input_ids, attention_mask=attn_mask)
        ref_logits = ref_out.logits

        out_custom = custom_model(input_ids=input_ids, attention_mask=attn_mask)
        custom_logits = out_custom.logits

        loss = kl_divergence_loss(custom_logits, ref_logits, attn_mask)
        # Fail fast: stepping on a NaN/inf loss would write NaN into every trainable
        # weight and make all later batches meaningless.
        if not torch.isfinite(loss):
            raise RuntimeError(
                f"Non-finite loss ({loss.item()}) in epoch {epoch + 1}; aborting before the optimizer step."
            )

        optimizer.zero_grad()
        loss.backward()
        # Apply the GRAD_CLIP hyperparameter so one bad batch cannot blow up the weights.
        nn.utils.clip_grad_norm_(trainable_params, GRAD_CLIP)
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} | Avg KL Loss: {total_loss / max(len(dataloader), 1):.4f}")

Starting training loop...
Input x range: -0.8565549254417419 0.811599850654602
Average Q norm: 1.757869005203247
Q range: -0.3124525547027588 0.3196738362312317
K range: -0.30481231212615967 0.3313223421573639
Found queries with no valid keys:
denominator min: 9.99999993922529e-09 denominator max: 25.773473739624023
Input x range: -1.9046885967254639 1.701981544494629
Average Q norm: 2.799513101577759
Q range: -0.48976975679397583 0.4748806655406952
K range: -0.5300219058990479 0.5951771140098572
Found queries with no valid keys:
denominator min: 9.99999993922529e-09 denominator max: 26.9306583404541
Input x range: -6.16571044921875 2.432818651199341
Average Q norm: 3.577327251434326
Q range: -0.6181554794311523 0.6300800442695618
K range: -0.6583061814308167 0.5700312852859497
Found queries with no valid keys:
denominator min: 9.99999993922529e-09 denominator max: 24.63714027404785
Input x range: -9.290886878967285 3.50590443611145
Average Q norm: 3.0000436305999756
Q range: -0.798006

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Input x range: nan nan
Average Q norm: nan
Q range: nan nan
K range: nan nan
Found queries with no valid keys:
denominator min: nan denominator max: nan
NAN in denominator!
Input x range: nan nan
Average Q norm: nan
Q range: nan nan
K range: nan nan
Found queries with no valid keys:
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-43-7ef41e052b43>", line 14, in <cell line: 0>
    out_custom = custom_model(input_ids=input_ids, attention_mask=attn_mask)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_imp

In [None]:
# ------------------------------
# Text Generation Comparison
# ------------------------------
def generate_text(model, prompt, temperature=0.7, top_k=50, max_length=50):
    """
    Generate text using the provided model and prompt.

    model: a model exposing .generate()
    prompt: text to continue
    temperature, top_k: sampling parameters forwarded to generate()
    max_length: maximum total length (prompt + continuation) in tokens

    Returns the decoded string with special tokens removed.

    The model is switched to eval mode for the duration of the call, so dropout
    does not perturb sampling, and its previous train/eval mode is restored
    afterwards. Caching is disabled (use_cache=False) for compatibility with
    the custom attention.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    was_training = model.training
    model.eval()
    try:
        outputs = model.generate(
            inputs.input_ids,
            # Pass the mask explicitly: pad and EOS share an id here, so it
            # cannot be inferred reliably from the ids.
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
            no_repeat_ngram_size=2,
            use_cache=False  # Disable caching for compatibility with custom attention.
        )
    finally:
        model.train(was_training)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

if SHOW_SAMPLE_OUTPUTS:
    # Short prompts first, then longer narrative prompts.
    sample_prompts = [
        "Hello, my name is",
        "The meaning of life is",
        "In a shocking turn of events,",
        "The future of artificial intelligence"
    ]

    longer_sample_prompts = [
        "As the sun set behind the towering mountains, the weary traveler finally caught sight of the distant village, its warm lights flickering like tiny stars",
        "In the year 2157, humanity had finally perfected interstellar travel. The first colony ship, brimming with hope and thousands of eager settlers",
        "The scientist stared at the glowing vial on the laboratory table, her fingers trembling with anticipation. After years of relentless experimentation",
        "The detective pushed open the heavy oak door, stepping into a room thick with the scent of old books and something more sinister—fear"
    ]

    # Both groups get the same side-by-side comparison; a blank line separates them.
    for group_index, prompt_group in enumerate((sample_prompts, longer_sample_prompts)):
        if group_index > 0:
            print()
        for prompt in prompt_group:
            ref_text = generate_text(reference_model, prompt, temperature=0.8)
            custom_text = generate_text(custom_model, prompt, temperature=0.8)
            print(f"\nPrompt: {prompt}")
            print(f"Reference: {ref_text}")
            print(f"Custom:    {custom_text}")
            print("-" * 80)

As an alpha approaches -infinity, its candidate mask contributes less and less; as it approaches +infinity, the mask contributes more and more. (The coefficient of each candidate mask is the sigmoid of its alpha.)

In [None]:
# One figure per row of the alpha matrices, with one line per column.
# `all_alphas` is the list of per-step alpha matrices built by the
# "Recovering alphas" cell further down, so that cell has to run first.
print(all_alphas[0].shape)
alpha_history = np.stack(all_alphas)  # (n_steps, n_rows, n_cols)
for row_idx in range(all_alphas[0].shape[0]):
    fig, ax = plt.subplots(figsize=(8, 5))

    # History of this row's entries across all recorded matrices.
    time_series = alpha_history[:, row_idx]

    for col_idx in range(all_alphas[0].shape[1]):
        ax.plot(time_series[:, col_idx], label=f'Alpha {col_idx + 1}')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Value')
    ax.set_title(f'Evolution of Alphas for Attention Block {row_idx}')
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
# ------------------------------
# Conclusion:
# ------------------------------
# In this notebook we demonstrated how to replace the standard attention mechanism with
# a learnable attention mask that is a weighted combination of several candidate masks.
# By optimizing the weights (with an L1 penalty to encourage sparsity) and minimizing
# the KL divergence between our custom model and a baseline GPT-2 model, we aim to preserve
# generation quality while potentially reducing the computational cost (by selecting only
# the most relevant tokens). This approach can be extended and refined for further research
# into efficient attention mechanisms.

In [None]:
# Recovering alphas from disconnected runtime

import re
import numpy as np

# Expected layout of every logged alpha matrix: R rows by C columns.
R, C = 12, 5

with open("alphas.txt", "r") as f:
    text = f.read()

# Use a regex to find blocks that look like "[[ ... ]]".
# This assumes that each alpha block starts with '[[' and ends with ']]'.
blocks = re.findall(r'\[\[.*?\]\]', text, re.DOTALL)

all_alphas = []
skipped_blocks = 0
for block in blocks:
    s_clean = block.replace('[', '').replace(']', '')

    # Both failure modes raise ValueError: a token that is not a number, and a
    # block whose element count is not R * C (e.g. a truncated print that ran
    # into the next block). Either way the block is unusable, so skip it
    # instead of aborting the whole recovery.
    try:
        numbers = [float(x) for x in s_clean.split()]
        arr = np.array(numbers).reshape((R, C))
    except ValueError:
        skipped_blocks += 1
        continue

    all_alphas.append(arr)

if skipped_blocks:
    print(f"Skipped {skipped_blocks} malformed block(s) out of {len(blocks)}.")
if not all_alphas:
    raise ValueError(f"No complete {R}x{C} alpha block found in alphas.txt")

print("Loaded alphas. The shape of the first alpha block is:", all_alphas[0].shape)
print(len(all_alphas))
print(all_alphas[-1])

In [None]:
import torch
import numpy as np

# Take the most recently logged alpha matrix; `all_alphas` is the list built by
# the "Recovering alphas" cell above.
final_alphas_np = all_alphas[-1]

final_alphas_tensor = torch.tensor(final_alphas_np, dtype=torch.float32)

# Iterate through custom_model to set the alpha parameters: row `module_index`
# of the matrix is copied into the `alpha` of the module_index-th matching
# module, in custom_model.modules() traversal order.
# NOTE(review): `CustomLearnableAttention` is not defined in any cell visible in
# this notebook (the attention class defined here is PerformerAttention, which
# has no `alpha`). Unless it is defined elsewhere, the isinstance check raises
# NameError on the first module. Confirm which model variant this cell targets.
module_index = 0
for module in custom_model.modules():
    if isinstance(module, CustomLearnableAttention):
        module.alpha.data.copy_(final_alphas_tensor[module_index])
        module_index += 1

print("Custom model alphas updated.")