In [1]:
# Install dependencies (if running in a notebook)
!pip install datasets transformers

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.w

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np

# Import Hugging Face modules
from transformers import (
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Config
)
# We also import the original GPT-2 block definition to subclass it.
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

In [3]:
class PerformerAttention(nn.Module):
    """
    A minimal Performer-style Multi-Head Attention module using random feature maps.
    - d_model: total embedding dimension
    - num_heads: number of attention heads
    - n_features: number of random features (sometimes denoted as 'r' or 'm')
    - causal: whether to apply a causal (lower-triangular) mask
    """
    def __init__(self, d_model, num_heads=8, n_features=64, causal=False):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.n_features = n_features
        self.causal = causal

        # Q, K, V projections
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

        # We create random weights for the feature maps.
        # shape: (num_heads, head_dim, n_features)
        # Typically Gaussian for the 'favor+' trick.
        self.register_buffer(
            "proj_matrix",
            torch.randn(self.num_heads, self.head_dim, self.n_features)
        )
        # Optionally a random bias
        self.register_buffer(
            "proj_bias",
            2 * torch.pi * torch.rand(self.num_heads, self.n_features)
        )

    def forward(self, x, attention_mask=None):
        """
        x: (batch_size, seq_len, d_model)
        attention_mask: optional; shape (batch_size, seq_len) or broadcastable to (B, 1, L)
              1 for valid tokens, 0 for masked.
        Returns: (batch_size, seq_len, d_model)
        """
        # Input validation and NaN handling
        if torch.isnan(x).any():
            # print("Input contains NaN values - replacing with zeros")
            x = torch.nan_to_num(x, nan=0.0)

        if torch.isinf(x).any():
            # print("Input contains Inf values - replacing with large finite values")
            x = torch.nan_to_num(x, posinf=1e6, neginf=-1e6)

        # Log input range for debugging
        input_min = x.min().item()
        input_max = x.max().item()
        if abs(input_max) > 1e3 or abs(input_min) > 1e3:
            # print(f"Warning: Input has extreme values: min={input_min}, max={input_max}")
            # Clip extreme values
            x = torch.clamp(x, -1e3, 1e3)

        B, L, D = x.shape

        # 1. Project to Q, K, V
        q = self.query_proj(x)  # (B, L, d_model)
        k = self.key_proj(x)
        v = self.value_proj(x)

        # Check for NaNs after projection
        for tensor, name in [(q, "q"), (k, "k"), (v, "v")]:
            if torch.isnan(tensor).any():
                # print(f"NaN values in {name} after projection")
                tensor.data = torch.nan_to_num(tensor, nan=0.0)
            if torch.isinf(tensor).any():
                # print(f"Inf values in {name} after projection")
                tensor.data = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6)

        # 2. Reshape into multiple heads
        #    (B, L, num_heads, head_dim) => then transpose to (B, num_heads, L, head_dim)
        q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # Log norms for debugging
        q_norm = q.detach().norm(dim=-1).mean().item()
        k_norm = k.detach().norm(dim=-1).mean().item()
        if q_norm > 10 or k_norm > 10:
            # print(f"Warning: Large norms - q_norm={q_norm}, k_norm={k_norm}")
            # Apply layer norm-like scaling if norms are too large
            q = q * (5.0 / max(q_norm, 1e-6))
            k = k * (5.0 / max(k_norm, 1e-6))

        # 3. Convert Q, K to random feature space: phi(Q), phi(K)
        #    shape of q_features, k_features => (B, num_heads, L, 2*n_features)
        q_features = self.random_feature_map(q)
        k_features = self.random_feature_map(k)

        # Check for NaNs after feature mapping
        for tensor, name in [(q_features, "q_features"), (k_features, "k_features")]:
            if torch.isnan(tensor).any():
                # print(f"NaN values in {name} after feature mapping")
                tensor.data = torch.nan_to_num(tensor, nan=0.0)
            if torch.isinf(tensor).any():
                # print(f"Inf values in {name} after feature mapping")
                tensor.data = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6)

        # If we have a padding mask, incorporate it by zeroing out k_features and v
        if attention_mask is not None:
            # Handle attention mask based on its dimensionality
            if attention_mask.dim() == 2:  # (B, L)
                # Handle fully masked rows
                num_zeros_per_row = (attention_mask == 0).sum(dim=-1)  # shape: (B,)
                fully_zero_rows = (num_zeros_per_row == attention_mask.size(1)).nonzero(as_tuple=True)[0]

                if len(fully_zero_rows) > 0:
                    # print(f"Found {len(fully_zero_rows)} rows with all zeros in mask")
                    # Modify mask to allow self-attention for these rows
                    for idx in fully_zero_rows:
                        # Give each token attention to itself
                        attention_mask[idx] = torch.eye(L, device=attention_mask.device)[0]

                attention_mask = attention_mask.unsqueeze(1)  # (B, 1, L)

            elif attention_mask.dim() == 4:  # (B, 1, L, L)
                # Convert to the format needed (B, 1, L)
                mask_3d = attention_mask.squeeze(1)  # shape: (B, L, L)
                sum_along_keys = mask_3d.sum(dim=-1)  # (B, L)

                # Handle fully masked queries
                zero_query_mask = (sum_along_keys == 0)
                if zero_query_mask.any():
                    # print(f"Found {zero_query_mask.sum().item()} queries with no valid keys")
                    # For fully masked queries, allow them to attend to themselves
                    safe_mask = mask_3d.clone()

                    # Create indices for batch and query dimensions
                    batch_indices = torch.arange(B, device=mask_3d.device).view(-1, 1).expand(-1, L)
                    query_indices = torch.arange(L, device=mask_3d.device).expand(B, -1)

                    # Set diagonal elements to 1 for rows with all zeros
                    safe_mask[batch_indices[zero_query_mask],
                              query_indices[zero_query_mask],
                              query_indices[zero_query_mask]] = 1.0

                    # Convert to format needed (B, 1, L)
                    attention_mask = (safe_mask.sum(dim=-1) > 0).unsqueeze(1)
                else:
                    attention_mask = attention_mask[:, 0, 0, :].unsqueeze(1)  # (B, 1, L)

            # Apply mask to k_features and v
            mask_4d = attention_mask.unsqueeze(-1).float()  # (B, 1, L, 1)
            k_features = k_features * mask_4d
            v = v * mask_4d

        # Handle causal attention if required
        if self.causal:
            # Create causal mask that will be applied later
            causal_mask = torch.tril(torch.ones(L, L, device=x.device)).unsqueeze(0).unsqueeze(1)  # (1, 1, L, L)
            # We'll apply this mask after computing the attention matrix if needed

        # 4. Compute "KV" => sum_{time} [k_features(t) * v(t)]
        kv = torch.einsum("bhlf,bhld->bhfd", k_features, v)

        # Check for NaNs in kv
        if torch.isnan(kv).any():
            # print("NaN values in kv computation")
            kv = torch.nan_to_num(kv, nan=0.0)
        if torch.isinf(kv).any():
            # print("Inf values in kv computation")
            kv = torch.nan_to_num(kv, posinf=1e6, neginf=-1e6)

        # 5. Numerator: q_features @ (k_features^T V)
        numerator = torch.einsum("bhlf,bhfd->bhld", q_features, kv)

        # Check for NaNs in numerator
        if torch.isnan(numerator).any():
            # print("NaN values in numerator")
            numerator = torch.nan_to_num(numerator, nan=0.0)
        if torch.isinf(numerator).any():
            # print("Inf values in numerator")
            numerator = torch.nan_to_num(numerator, posinf=1e6, neginf=-1e6)

        # 6. Denominator: q_features @ sum_{time}(k_features)
        k_sum = k_features.sum(dim=2)  # (B, H, 2*n_features)

        # Safety checks for the denominator components
        if torch.isnan(k_sum).any() or torch.isinf(k_sum).any():
            # print("NaN or Inf in k_sum")
            k_sum = torch.nan_to_num(k_sum, nan=0.0, posinf=1e6, neginf=-1e6)
            # Add a small epsilon to ensure numerical stability
            k_sum = k_sum + 1e-6

        denominator = torch.einsum("bhlf,bhf->bhl", q_features, k_sum)

        # More safety checks for the denominator
        if torch.isnan(denominator).any() or torch.isinf(denominator).any() or (denominator == 0).any():
            # print("NaN, Inf, or zero in denominator")
            denominator = torch.nan_to_num(denominator, nan=1e-6, posinf=1e6, neginf=-1e6)
            denominator = torch.where(denominator == 0, torch.ones_like(denominator) * 1e-6, denominator)

        denominator = denominator.unsqueeze(-1)  # (B, H, L, 1)

        # Add a larger epsilon to ensure stability
        denominator = denominator + 1e-4

        # Log denominator stats for debugging
        denom_min = denominator.min().item()
        denom_max = denominator.max().item()
        # if denom_min < 1e-5 or denom_max > 1e6:
            # print(f"Warning: Extreme denominator values: min={denom_min}, max={denom_max}")

        # 7. Final attention output
        out = numerator / denominator  # (B, H, L, head_dim)

        # Replace any remaining -inf or NaN with zeros
        if torch.isnan(out).any() or torch.isinf(out).any():
            # print("NaN or Inf in output - replacing with zeros")
            out = torch.where(torch.isnan(out) | torch.isinf(out), torch.zeros_like(out), out)

        # Optional: clip output to reasonable range to prevent downstream NaNs
        out = torch.clamp(out, -1e3, 1e3)

        # 8. Recombine heads
        out = out.transpose(1, 2).contiguous().view(B, L, self.num_heads * self.head_dim)
        out = self.out_proj(out)

        # Final NaN check
        if torch.isnan(out).any():
            # print("NaN in final output - replacing with zeros")
            out = torch.nan_to_num(out, nan=0.0)

        return out

    def random_feature_map(self, x):
        """
        x: (B, H, L, head_dim)
        Output: (B, H, L, 2*n_features)

        Applies FAVOR+ mapping:
        phi(x) = 1/sqrt(n_features) * [ cos(Wx + b), sin(Wx + b) ] * exp(-||x||^2/2)
        """
        # Project input
        proj = torch.einsum("bhld,hdf->bhlf", x, self.proj_matrix)  # (B, H, L, F)

        # Check for NaNs in projection
        if torch.isnan(proj).any() or torch.isinf(proj).any():
            print("NaN or Inf in projection")
            proj = torch.nan_to_num(proj, nan=0.0, posinf=1e6, neginf=-1e6)

        proj = proj + self.proj_bias.unsqueeze(0).unsqueeze(2)  # Add bias

        # Compute the exp(-||x||^2 / 2) factor - with safety checks
        norm_sq = (x * x).sum(dim=-1, keepdim=True)  # (B, H, L, 1)
        # Check if norm_sq has bad values
        if torch.isnan(norm_sq).any() or torch.isinf(norm_sq).any() or (norm_sq < 0).any():
            print("Bad values in norm_sq")
            norm_sq = torch.nan_to_num(norm_sq, nan=0.0, posinf=50.0)  # Cap at a reasonable value
            norm_sq = torch.clamp(norm_sq, 0.0, 50.0)  # Ensure non-negative and not too large

        # Apply exp with safety
        exp_term = torch.exp(-0.5 * norm_sq)  # (B, H, L, 1)

        # Check exp_term for NaNs or very small values
        if torch.isnan(exp_term).any() or torch.isinf(exp_term).any():
            print("NaN or Inf in exp_term")
            exp_term = torch.nan_to_num(exp_term, nan=1.0, posinf=1.0, neginf=1e-10)
            exp_term = torch.clamp(exp_term, 1e-10, 1.0)  # Ensure within reasonable range

        # Compute sine and cosine projections
        # Check for extreme values in proj
        if (proj.abs() > 1e6).any():
            print("Extreme values in projection before trig functions")
            proj = torch.clamp(proj, -1e6, 1e6)

        cos_proj = torch.cos(proj)
        sin_proj = torch.sin(proj)

        # Check for NaNs in trig projections
        if torch.isnan(cos_proj).any() or torch.isnan(sin_proj).any():
            print("NaN in trig projections")
            cos_proj = torch.nan_to_num(cos_proj, nan=1.0)
            sin_proj = torch.nan_to_num(sin_proj, nan=0.0)

        # Concatenate along feature dimension
        out = torch.cat([cos_proj, sin_proj], dim=-1)  # (B, H, L, 2*F)

        # Apply scaling with safety
        scale = 1.0 / ((2 * self.n_features) ** 0.5)
        out = out * scale * exp_term

        # Final safety check
        if torch.isnan(out).any() or torch.isinf(out).any():
            print("NaN or Inf in final feature map output")
            out = torch.nan_to_num(out, nan=0.0, posinf=1.0, neginf=-1.0)
            out = torch.clamp(out, -1.0, 1.0)  # Ensure reasonable range

        return out

    def _debug_avg_output(self, tensor):
        """Helper method to debug average values"""
        avg = tensor.abs().mean().item()
        print(f"Avg value: {avg}")
        if torch.isnan(tensor).any():
            print("Warning: tensor contains NaN values")
            nan_count = torch.isnan(tensor).sum().item()
            print(f"NaN count: {nan_count}/{tensor.numel()}")
        if torch.isinf(tensor).any():
            print("Warning: tensor contains inf values")

        return avg

In [4]:
# ------------------------------
# Custom GPT-2 Block with our CustomLearnableAttention
# ------------------------------
class CustomGPT2Block(GPT2Block):
    """
    Subclass the original GPT2Block to replace the attention layer with our custom one.
    We also copy the original Q, K, V, and output projection weights so that the custom
    attention begins with a similar behavior to the baseline.
    """
    def __init__(self, config):
        super().__init__(config)
        # Replace the default attention with our custom attention.
        self.attn = self._create_custom_attention(config)

    def _create_custom_attention(self, config):
        # Original GPT2Attention:
        original_attn = self.attn  # the old GPT2Attention instance

        # Create the new kernel attention
        custom_attn = PerformerAttention(
            d_model=config.n_embd,
            num_heads=config.n_head,
            n_features=64
        )

        # We can copy part of the old c_attn weights (which was a big linear that had Q,K,V).
        with torch.no_grad():
            old_w = original_attn.c_attn.weight # 768, 2304
            old_b = original_attn.c_attn.bias # 2304

            old_w = old_w.T

            # old_w has shape (3 * d_model, d_model), corresponding to Q, K, V stacked.
            q_w, k_w, v_w = old_w.chunk(3, dim=0)
            q_b, k_b, v_b = old_b.chunk(3, dim=0)

            # Copy into the new Q, K, V
            custom_attn.query_proj.weight.copy_(q_w)
            custom_attn.query_proj.bias.copy_(q_b)
            custom_attn.key_proj.weight.copy_(k_w)
            custom_attn.key_proj.bias.copy_(k_b)
            custom_attn.value_proj.weight.copy_(v_w)
            custom_attn.value_proj.bias.copy_(v_b)

            # c_proj is the final linear after attention; we can copy that to our out_proj
            custom_attn.out_proj.load_state_dict(original_attn.c_proj.state_dict())

        return custom_attn

    def forward(self, hidden_states, layer_past=None, attention_mask=None,
                head_mask=None, use_cache=False, output_attentions=False, **kwargs):
        # Standard GPT-2 block forward pass with two residual connections.
        attn_input = self.ln_1(hidden_states)
        attn_output = self.attn(attn_input, attention_mask=attention_mask)
        hidden_states = hidden_states + attn_output

        mlp_input = self.ln_2(hidden_states)
        mlp_output = self.mlp(mlp_input)
        hidden_states = hidden_states + mlp_output

        # We do not use caching in our custom model.
        return (hidden_states, None, None)

In [5]:
# ------------------------------
# Custom GPT-2 LM Model that Uses the Custom Blocks
# ------------------------------
class CustomGPT2LM(GPT2LMHeadModel):
    """
    This custom language model replaces each Transformer block with our custom block
    (which uses the learnable attention mask). In addition, we share the word and
    positional embeddings, as well as the LM head, with the reference model.
    """
    def __init__(self, config, reference_model):
        super().__init__(config)
        # Disable caching for generation.
        self.config.use_cache = False

        # Share embeddings and the language model head from the reference model.
        self.transformer.wte = reference_model.transformer.wte
        self.transformer.wpe = reference_model.transformer.wpe
        self.lm_head = reference_model.lm_head

        # Replace all Transformer blocks with our custom blocks.
        self.transformer.h = nn.ModuleList([CustomGPT2Block(config) for _ in range(config.n_layer)])
        # Load weights from the reference model (allowing missing keys since our modules are modified).
        self.load_state_dict(reference_model.state_dict(), strict=False)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # If a 4D mask is passed in (e.g. from GPT‑2 generation), replace it with a 2D mask.
        if attention_mask is not None and attention_mask.dim() == 4:
            attention_mask = (input_ids != tokenizer.pad_token_id).long()
        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

In [6]:
# ------------------------------
# Prepare the Dataset and Dataloader using wikitext-2
# ------------------------------
# We create a simple Dataset that tokenizes the raw texts from wikitext.
class WikiTextDataset(Dataset):
    def __init__(self, tokenizer, texts, max_length=32):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = texts
        self.encodings = []

        for txt in texts:
            # Tokenize each text and pad/truncate to max_length.
            enc = tokenizer.encode_plus(
                txt,
                max_length=self.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            )
            self.encodings.append(enc)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.encodings[idx]
        input_ids = enc["input_ids"].squeeze(0)         # shape: (max_length,)
        attention_mask = enc["attention_mask"].squeeze(0)   # shape: (max_length,)
        return input_ids, attention_mask

In [7]:
# ------------------------------
# Hyperparameters and settings
# ------------------------------
MODEL_NAME = "gpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
LR = 6e-5
MAX_SEQ_LENGTH = 128      # maximum sequence length for training examples
NUM_EPOCHS = 50           # For demonstration we use few epochs (use more in practice)
SHOW_SAMPLE_OUTPUTS = True   # Whether to show sample text generations for comparison
GRAD_CLIP = 1.0
L1_COEFF = 1e-5          # Coefficient for L1 penalty on attention mask weights

print(f"Using device: {DEVICE}")

Using device: cuda


In [8]:
# ------------------------------
# Load the baseline (reference) GPT-2 model and tokenizer.
# This model is used only for generating target logits.
# ------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Ensure the tokenizer has a pad token (set to the EOS token if missing)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [9]:
# Load the wikitext-2 dataset using the Hugging Face datasets library.
from datasets import load_dataset
wikitext_data = load_dataset("wikitext", "wikitext-2-raw-v1")
# For demonstration, we take a small subset of the training split.
train_texts = [txt for txt in wikitext_data["train"]["text"] if len(txt) > 50][:1000]

# Create our dataset and dataloader.
dataset = WikiTextDataset(tokenizer, train_texts, max_length=MAX_SEQ_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
print("Dataset and dataloader ready for training!")

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Dataset and dataloader ready for training!


In [10]:
# ------------------------------
# KL-Divergence Loss Function
# ------------------------------
def kl_divergence_loss(logits_custom, logits_ref, mask):
    """
    Compute a token-wise KL-divergence between the output distributions of the custom model
    and the reference model. The loss is averaged over the active (non-padded) tokens.

    logits_custom: (B, L, V)
    logits_ref:    (B, L, V)
    mask:          (B, L) with 1 for active tokens and 0 for padding.
    """
    log_probs_custom = F.log_softmax(logits_custom, dim=-1)
    # Detach the reference probabilities to avoid backprop into the reference model.
    probs_ref = F.softmax(logits_ref.detach(), dim=-1)
    # Compute the per-token KL divergence.
    kl = (probs_ref * (probs_ref.log() - log_probs_custom)).sum(-1)  # shape: (B, L)
    # Average the loss over the active tokens.
    active_tokens = mask.sum()
    return (kl * mask).sum() / (active_tokens + 1e-8)

In [11]:
# ------------------------------
# Initialize the Custom Model
# ------------------------------
print("Initializing custom model with learnable attention masks...")
# We re-load the reference model so that the custom model can copy its embeddings and head.
reference_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
reference_model.eval()
print(f"Reference model '{MODEL_NAME}' loaded successfully!")

custom_config = GPT2Config.from_pretrained(MODEL_NAME)
custom_model = CustomGPT2LM(custom_config, reference_model).to(DEVICE)

# Freeze parameters that are not part of the custom attention (or other parts we want fixed).
for name, param in custom_model.named_parameters():
    # Here we unfreeze only the parameters that include "attn" in their name.
    if "attn" in name:
        param.requires_grad_(True)
    else:
        param.requires_grad_(False)

# Set the custom model to train mode.
custom_model.train()

# Create the optimizer to update only parameters that require gradients.
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, custom_model.parameters()), lr=LR)

Initializing custom model with learnable attention masks...


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Reference model 'gpt2' loaded successfully!


In [12]:
# ------------------------------
# Training Loop
# ------------------------------
print("Starting training loop...")
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for input_ids, attn_mask in dataloader:
        input_ids, attn_mask = input_ids.to(DEVICE), attn_mask.to(DEVICE)

        with torch.no_grad():
            ref_out = reference_model(input_ids=input_ids, attention_mask=attn_mask)
        ref_logits = ref_out.logits

        out_custom = custom_model(input_ids=input_ids, attention_mask=attn_mask)
        custom_logits = out_custom.logits

        loss = kl_divergence_loss(custom_logits, ref_logits, attn_mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} | Avg KL Loss: {total_loss / len(dataloader):.4f}")

Starting training loop...
Epoch 1 | Avg KL Loss: 3.2546
Epoch 2 | Avg KL Loss: 3.1549
Epoch 3 | Avg KL Loss: 3.0521
Epoch 4 | Avg KL Loss: 2.9681
Epoch 5 | Avg KL Loss: 2.9644
Epoch 6 | Avg KL Loss: 2.9956
Epoch 7 | Avg KL Loss: 2.9608
Epoch 8 | Avg KL Loss: 2.9223
Epoch 9 | Avg KL Loss: 2.9958
Epoch 10 | Avg KL Loss: 2.9535
Epoch 11 | Avg KL Loss: 2.9480
Epoch 12 | Avg KL Loss: 2.9471


KeyboardInterrupt: 

In [13]:
# ------------------------------
# Text Generation Comparison
# ------------------------------
def generate_text(model, prompt, temperature=0.7, top_k=50, max_length=50):
    """
    Generate text using the provided model and prompt.
    Note: We disable caching (use_cache=False) for our custom model.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(
        inputs.input_ids,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2,
        no_repeat_ngram_size=2,
        use_cache=False  # Disable caching for compatibility with custom attention.
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# if SHOW_SAMPLE_OUTPUTS:
sample_prompts = [
    "Hello, my name is",
    "The meaning of life is",
    "In a shocking turn of events,",
    "The future of artificial intelligence"
]

longer_sample_prompts = [
    "As the sun set behind the towering mountains, the weary traveler finally caught sight of the distant village, its warm lights flickering like tiny stars",
    "In the year 2157, humanity had finally perfected interstellar travel. The first colony ship, brimming with hope and thousands of eager settlers",
    "The scientist stared at the glowing vial on the laboratory table, her fingers trembling with anticipation. After years of relentless experimentation",
    "The detective pushed open the heavy oak door, stepping into a room thick with the scent of old books and something more sinister—fear"
]

for prompt in sample_prompts:
    ref_text = generate_text(reference_model, prompt, temperature=0.8)
    custom_text = generate_text(custom_model, prompt, temperature=0.8)
    print(f"\nPrompt: {prompt}")
    print(f"Reference: {ref_text}")
    print(f"Custom:    {custom_text}")
    print("-" * 80)
print()
for prompt in longer_sample_prompts:
    ref_text = generate_text(reference_model, prompt, temperature=0.8)
    custom_text = generate_text(custom_model, prompt, temperature=0.8)
    print(f"\nPrompt: {prompt}")
    print(f"Reference: {ref_text}")
    print(f"Custom:    {custom_text}")
    print("-" * 80)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Prompt: Hello, my name is
Reference: Hello, my name is Kipi (I think of you as his friend) and I'm looking for a new job at the company. My only problem with your application was that there were no applications from other employees so we had to come back
Custom:    Hello, my name is to the city of a big and his personal information that he were in an important. The White House (20 on this time during their new equipment or we have been seen as much for free , 5 years ago because it was
--------------------------------------------------------------------------------

Prompt: The meaning of life is
Reference: The meaning of life is that it's going to be good. And you've been there, and now we know a lot more about your past than ever before."
"He wasn't talking here at all," said one woman in the audience who
Custom:    The meaning of life is also known to "I have a lot on this year, and all the U.
After the other for their own , not like J- which was found it's role when times as they 

As the alphas approach -infinity, this means the masks are less and less important, and as they approach +infinity, they are more and more important. (The coefficients of the candidate masks are actually the sigmoid of the alphas.)

In [None]:
# Iterate over each row index
print(all_alphas[0].shape)
for row_idx in range(all_alphas[0].shape[0]):
    plt.figure(figsize=(8, 5))

    # Extract the time series for each of the entries
    time_series = np.array([matrix[row_idx] for matrix in all_alphas])

    # Plot each of the n columns over time
    for col_idx in range(all_alphas[0].shape[1]):
        plt.plot(time_series[:, col_idx], label=f'Alpha {col_idx + 1}')

    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(f'Evolution of Alphas for Attention Block {row_idx}')
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
# ------------------------------
# Conclusion:
# ------------------------------
# In this notebook we demonstrated how to replace the standard attention mechanism with
# a learnable attention mask that is a weighted combination of several candidate masks.
# By optimizing the weights (with an L1 penalty to encourage sparsity) and minimizing
# the KL divergence between our custom model and a baseline GPT-2 model, we aim to preserve
# generation quality while potentially reducing the computational cost (by selecting only
# the most relevant tokens). This approach can be extended and refined for further research
# into efficient attention mechanisms.

In [None]:
# Recovering alphas from disconnected runtime

import re
import numpy as np

with open("alphas.txt", "r") as f:
    text = f.read()

# Use a regex to find blocks that look like "[[ ... ]]".
# This assumes that each alpha block starts with '[[' and ends with ']]'.
blocks = re.findall(r'\[\[.*?\]\]', text, re.DOTALL)

all_alphas = []
for block in blocks:
    s_clean = block.replace('[', '').replace(']', '')

    try:
        numbers = [float(x) for x in s_clean.split()]
    except ValueError:
        continue

    R, C = 12, 5
    arr = np.array(numbers).reshape((R, C))
    all_alphas.append(arr)

print("Loaded alphas. The shape of the first alpha block is:", all_alphas[0].shape)
print(len(all_alphas))
print(all_alphas[-1])

In [None]:
import torch
import numpy as np

final_alphas_np = all_alphas[-1]

final_alphas_tensor = torch.tensor(final_alphas_np, dtype=torch.float32)

# Iterate through custom_model to set the alpha parameters
module_index = 0
for module in custom_model.modules():
    if isinstance(module, CustomLearnableAttention):
        module.alpha.data.copy_(final_alphas_tensor[module_index])
        module_index += 1

print("Custom model alphas updated.")