# Jupyter Notebook: Custom Attention Optimization

Description:
------------
In this notebook, we will:
1. Load a pretrained language model (GPT-2).
2. Create a copy of the model architecture but replace its attention mechanism with a sparse one (`SparseAttention` from `native-sparse-attention-pytorch`) that attends to a short sliding window plus a few compressed and selected blocks, instead of all previous tokens.
3. Compare the outputs of both models and compute a KL-divergence loss.
4. Optimize the custom model's parameters by minimizing the KL divergence between the two models' output distributions.
5. Demonstrate how to evaluate and compare both models on sample data.
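
The comparison in steps 3 and 4 can be sketched without any framework. The snippet below is a minimal, dependency-free illustration of a temperature-scaled KL divergence between a reference and a custom distribution at a single token position. The helper names `log_softmax` and `kl_divergence` and the toy logits are assumptions made for illustration; they are not part of the notebook.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of a list of logits at a given temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_z for s in scaled]


def kl_divergence(ref_logits, custom_logits, temperature=1.0):
    """KL(p_ref || p_custom) for one token position, rescaled by T^2."""
    logp_ref = log_softmax(ref_logits, temperature)
    logp_cust = log_softmax(custom_logits, temperature)
    kl = sum(math.exp(lr) * (lr - lc) for lr, lc in zip(logp_ref, logp_cust))
    return kl * temperature ** 2


ref = [2.0, 1.0, 0.1]                        # toy reference logits
same = kl_divergence(ref, ref)               # identical distributions
diff = kl_divergence(ref, [0.1, 1.0, 2.0])   # reversed preferences
```

Identical logits give zero divergence, and adding a constant to every logit leaves the result unchanged because softmax is shift-invariant.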

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
!pip install datasets
!pip install native-sparse-attention-pytorch

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2024.12.0-py3-none-any.

In [None]:
############## # Code Block 1: Imports & Config ##############
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
from native_sparse_attention_pytorch import SparseAttention
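import time
import psutil
import os
import pandas as pd  # used by UsageTracker.plot
import matplotlib.pyplot as plt  # used by UsageTracker.plot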


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "gpt2"
SEQ_LEN = 128
BATCH_SIZE = 4
NUM_HEADS = 4
COMPRESS_RATIO = 0.25
WINDOW_SIZE = 64
NUM_EPOCHS = 10

# Native sparse attention configuration
SPARSE_CONFIG = {
    "dim": None,  # Will be set in the model
    "dim_head": 64,  # Dimension per head
    "heads": NUM_HEADS,
    "sliding_window_size": 2,  # Local attention window
    "compress_block_size": 4,  # Size of blocks to compress
    "selection_block_size": 4,  # Size of blocks to select from
    "num_selected_blocks": 2,  # Number of blocks to select
}

In [None]:
############## # Code Block 2: Sparse Attention Components ##############
class CompressedGlobalAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, compress_ratio):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.compress_ratio = compress_ratio

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)
        self.compression = nn.Linear(embed_dim, 1)
        self.expansion = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        keep_num = max(1, int(T * self.compress_ratio))

        # Token compression
        importance = self.compression(x).squeeze(-1)
        _, keep_idx = torch.topk(importance, k=keep_num, dim=-1)
        x_compressed = torch.gather(x, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))

        # Projections
        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = (
            self.Wk(x_compressed)
            .view(B, keep_num, self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )
        V = (
            self.Wv(x_compressed)
            .view(B, keep_num, self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )

        # Attention
        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Masking
        if attention_mask is not None:
            compressed_mask = torch.gather(attention_mask, 1, keep_idx)
            attn_scores = attn_scores.masked_fill(
                compressed_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        output = output.permute(0, 2, 1, 3).contiguous().view(B, T, D)

        output = self.expansion(output)  # Ensure output has correct embedding dimension
        output = output[:, : x.size(1), :]

        return output


class LocalWindowAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)

    def create_window_mask(self, seq_len, device):
        # Symmetric window: position i sees positions within window_size // 2
        # on either side (this includes future tokens).
        idx = torch.arange(seq_len, device=device)
        half = self.window_size // 2
        mask = ((idx[None, :] - idx[:, None]).abs() <= half).float()
        return mask.unsqueeze(0).unsqueeze(1)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        window_mask = self.create_window_mask(T, x.device)

        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.Wk(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.Wv(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn_scores = attn_scores.masked_fill(window_mask == 0, -1e10)

        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output.permute(0, 2, 1, 3).contiguous().view(B, T, D)


class HierarchicalSparseAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size, compress_ratio):
        super().__init__()
        self.num_heads = num_heads  # 🔹 Store num_heads
        self.local_attn = LocalWindowAttention(embed_dim, num_heads, window_size)
        self.global_attn = CompressedGlobalAttention(
            embed_dim, num_heads, compress_ratio
        )
        self.gate = nn.Sequential(
            nn.Linear(
                embed_dim, num_heads * 2
            ),  # Ensure output is [batch, seq_len, num_heads * 2]
            nn.Softmax(dim=-1),
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        # Get outputs from local and global attention modules.
        local_out = self.local_attn(x, attention_mask)  # Expected shape: (B, T, D)
        global_out = self.global_attn(x, attention_mask)  # Expected shape: (B, T, D)

        B, T, D = x.size()
        head_dim = D // self.num_heads  # Ensure D is divisible by num_heads

        # Compute gating weights.
        # self.gate should output a tensor of shape (B, T, num_heads*2)
        gate_out = self.gate(x)  # Shape: (B, T, num_heads*2)
        # Reshape to (B, T, num_heads, 2) where last dim holds [local_gate, global_gate]
        gates = gate_out.view(B, T, self.num_heads, 2)
        # Unbind the last dimension into two tensors
        local_gate = gates[..., 0]  # Shape: (B, T, num_heads)
        global_gate = gates[..., 1]  # Shape: (B, T, num_heads)

        # Reshape attention outputs to split heads: (B, T, num_heads, head_dim)
        local_out_heads = local_out.view(B, T, self.num_heads, head_dim)
        global_out_heads = global_out.view(B, T, self.num_heads, head_dim)

        # Ensure the gate tensors have an extra dimension for broadcasting: (B, T, num_heads, 1)
        local_gate = local_gate.unsqueeze(-1)
        global_gate = global_gate.unsqueeze(-1)

        # Element-wise multiply each head output by its corresponding gate weight
        combined = local_out_heads * local_gate + global_out_heads * global_gate
        # Reshape back to (B, T, D)
        combined = combined.view(B, T, D)
        return self.out_proj(combined)
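
The mask built by `create_window_mask` above is symmetric, so each position can also see tokens to its right. For next-token prediction one would usually want a causal window instead. The following is a minimal sketch of such a mask in plain Python lists; the function name `causal_window_mask` is hypothetical and it is not used anywhere in this notebook.

```python
def causal_window_mask(seq_len, window_size):
    """Return a seq_len x seq_len 0/1 mask.

    mask[i][j] == 1 when query position i may attend to key position j,
    i.e. j is not in the future (j <= i) and lies within the last
    `window_size` positions (i - j < window_size).
    """
    return [
        [1 if 0 <= i - j < window_size else 0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = causal_window_mask(seq_len=5, window_size=2)
for row in mask:
    print(row)
```

Each row has at most `window_size` ones, ending on the diagonal, so no position ever sees a later token.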

In [None]:
############## # Code Block 3: Custom GPT-2 Model ##############
class SparseGPT2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.drop = nn.Dropout(config.embd_pdrop)

        # Create the native sparse attention layer.
        # Note: this single instance is reused by every block below, so all
        # layers share one set of attention weights.
        sparse_config = SPARSE_CONFIG.copy()
        sparse_config["dim"] = config.hidden_size
        sparse_config["compress_block_sliding_stride"] = 2
        self.sparse_attn = SparseAttention(**sparse_config)

        self.h = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attn": self.sparse_attn,
                        "ln_1": nn.LayerNorm(config.hidden_size),
                        "mlp": nn.Sequential(
                            nn.Linear(config.hidden_size, 4 * config.hidden_size),
                            nn.GELU(),
                            nn.Linear(4 * config.hidden_size, config.hidden_size),
                        ),
                        "ln_2": nn.LayerNorm(config.hidden_size),
                    }
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.size()
        # Build position ids on the same device as the inputs
        pos_ids = torch.arange(T, device=input_ids.device).unsqueeze(0)

        x = self.drop(self.wte(input_ids) + self.wpe(pos_ids))

        for block in self.h:
            # Pre-norm sparse attention
            attn_out = block["attn"](block["ln_1"](x))

            # Zero the attention output at padded positions, if a mask is given
            if attention_mask is not None:
                attn_out = attn_out * attention_mask.unsqueeze(-1)

            # Residual connections around attention and the MLP
            x = x + attn_out
            x = x + block["mlp"](block["ln_2"](x))

        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits  # Return only the logits, not a tuple


############## # Code Block 4: Training Setup ##############
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Reference model
ref_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
ref_model.eval()

# Custom model with native sparse attention
cust_config = GPT2Config.from_pretrained(MODEL_NAME)
cust_model = SparseGPT2(cust_config).to(DEVICE)

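# Initialize with pretrained weights where names and shapes line up.
# GPT2LMHeadModel prefixes most keys with "transformer.", so strip that prefix
# to match SparseGPT2's parameter names. With strict=False, keys that still do
# not match (the original attention and MLP weights) are skipped.
pretrained_state_dict = {
    (k[len("transformer."):] if k.startswith("transformer.") else k): v
    for k, v in ref_model.state_dict().items()
}
cust_model.load_state_dict(pretrained_state_dict, strict=False)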

# AdamW with a cosine-annealed learning rate (the scheduler is stepped once per epoch)
optimizer = torch.optim.AdamW(cust_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
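
When weights are copied with `strict=False`, it helps to check which parameter names actually match, because non-matching keys are skipped silently. Hugging Face's `GPT2LMHeadModel` stores most parameters under a `transformer.` prefix, while `SparseGPT2` registers them at the top level. The sketch below works on plain lists of key names; the helpers `strip_prefix` and `match_report` and the toy key lists are illustrative assumptions, not an exhaustive listing of either model.

```python
def strip_prefix(keys, prefix="transformer."):
    """Drop `prefix` from every key that starts with it."""
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]


def match_report(source_keys, target_keys):
    """Classify keys the way a non-strict state-dict load would treat them."""
    source, target = set(source_keys), set(target_keys)
    return {
        "loaded": sorted(source & target),
        "missing": sorted(target - source),      # in target, absent from source
        "unexpected": sorted(source - target),   # in source, absent from target
    }


# Toy key names modelled on the two models (illustrative only)
ref_keys = [
    "transformer.wte.weight",
    "transformer.h.0.ln_1.weight",
    "transformer.h.0.mlp.c_fc.weight",
    "lm_head.weight",
]
custom_keys = ["wte.weight", "h.0.ln_1.weight", "h.0.mlp.0.weight", "lm_head.weight"]

before = match_report(ref_keys, custom_keys)
after = match_report(strip_prefix(ref_keys), custom_keys)
print("loaded without remapping:", before["loaded"])
print("loaded with remapping:   ", after["loaded"])
```

With real models, the object returned by `load_state_dict(..., strict=False)` exposes `missing_keys` and `unexpected_keys`, which gives the same information directly.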


In [None]:
############## # Code Block 5: Training Loop ##############
class UsageTracker:
    def __init__(self):
        self.start_time = None
        self.start_cpu_time = None
        self.train_stats = {
            "epoch": [], "step": [],
            "wall_time_s": [], "cpu_time_s": [],
            "cpu_mem_mb": [], "gpu_alloc_mb": [], "gpu_reserved_mb": []
        }
        self.infer_stats = {
            "max_token_length": [],
            "wall_time_s": [], "cpu_time_s": [],
            "cpu_mem_mb": [], "gpu_alloc_mb": [], "gpu_reserved_mb": []
        }

    def start_tracking(self):
        self.start_time = time.time()
        self.start_cpu_time = time.process_time()

    def stop_tracking(self, is_training, epoch=None, step=None, max_token_length=None):
        if self.start_time is None or self.start_cpu_time is None:
            raise ValueError("Tracking not started. Call start_tracking() first.")

        end_time = time.time()
        end_cpu_time = time.process_time()

        proc = psutil.Process(os.getpid())
        cpu_mem = proc.memory_info().rss / (1024**2)
        if torch.cuda.is_available():
            gpu_alloc    = torch.cuda.memory_allocated(DEVICE) / (1024**2)
            gpu_reserved = torch.cuda.memory_reserved(DEVICE) / (1024**2)
        else:
            gpu_alloc = gpu_reserved = 0.0

        obj = self.train_stats if is_training else self.infer_stats
        obj["wall_time_s"].append(end_time - self.start_time)
        obj["cpu_time_s"].append(end_cpu_time - self.start_cpu_time)
        obj["cpu_mem_mb"].append(cpu_mem)
        obj["gpu_alloc_mb"].append(gpu_alloc)
        obj["gpu_reserved_mb"].append(gpu_reserved)
        if is_training:
            obj["epoch"].append(epoch + 1)
            obj["step"].append(step + 1)
        else:
            obj["max_token_length"].append(max_token_length)

        self.start_time = None
        self.start_cpu_time = None

    def plot(self):
        df_train = pd.DataFrame(self.train_stats)
        df_inf   = pd.DataFrame(self.infer_stats)

        # --- Training plots (line plots vs. step) ---
        for metric in ["wall_time_s", "cpu_time_s", "cpu_mem_mb", "gpu_alloc_mb", "gpu_reserved_mb"]:
            plt.figure(figsize=(8, 4))
            plt.plot(df_train["step"], df_train[metric])
            plt.xlabel("Training Step")
            plt.ylabel(metric.replace("_", " ").title())
            plt.title(f"{metric.replace('_', ' ').title()} During Training")
            plt.show()

        # # --- Inference plots (bar plots vs. call index) ---
        # for metric in ["wall_time_s", "cpu_time_s", "cpu_mem_mb", "gpu_alloc_mb", "gpu_reserved_mb"]:
        #     plt.figure(figsize=(8, 4))
        #     plt.bar(range(len(df_inf)), df_inf[metric])
        #     plt.xlabel("Inference Call #")
        #     plt.ylabel(metric.replace("_", " ").title())
        #     plt.title(f"{metric.replace('_', ' ').title()} Per Inference")
        #     plt.show()

        prompt_lengths = self.infer_stats["max_token_length"]

        # --- Inference plots (bar plots vs. token length) ---
        for metric in ["wall_time_s", "cpu_time_s", "cpu_mem_mb", "gpu_alloc_mb", "gpu_reserved_mb"]:
            plt.figure(figsize=(8, 4))
            plt.scatter(prompt_lengths, self.infer_stats[metric])
            plt.xlabel("Prompt Token Length")
            plt.ylabel(metric.replace("_", " ").title())
            plt.title(f"{metric.replace('_', ' ').title()} vs. Prompt Length")
            plt.show()

def kl_divergence_loss(logits_custom, logits_ref, mask, temperature):
    # Masked, per-token alternative to the inline loss in train_step.
    # 1) soften both distributions by T (stay in log space for stability)
    logp_c = F.log_softmax(logits_custom / temperature, dim=-1)
    logp_r = F.log_softmax(logits_ref.detach() / temperature, dim=-1)

    # 2) per-token KL(p_ref || p_custom)
    kl = (logp_r.exp() * (logp_r - logp_c)).sum(-1)  # (B, L)

    # 3) average over real tokens and re-scale by T^2
    return (kl * mask).sum() / mask.sum().clamp(min=1) * (temperature * temperature)

def train_step(batch, epoch, step):
    tracker.start_tracking()

    inputs = batch.to(DEVICE)
    attention_mask = (inputs != tokenizer.pad_token_id).float()

    with torch.no_grad():
        ref_outputs = ref_model(inputs, attention_mask=attention_mask)
        if isinstance(ref_outputs, tuple):
            ref_logits = ref_outputs[0]
        else:
            ref_logits = ref_outputs.logits  # Extract logits from the output object

    cust_outputs = cust_model(inputs, attention_mask=attention_mask)
    if isinstance(cust_outputs, tuple):
        cust_logits = cust_outputs[0]
    else:
        cust_logits = cust_outputs

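    # KL divergence with temperature. reduction="sum" adds up every batch element,
    # position, and vocabulary entry, and padding positions are not masked out.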
    temperature = 0.7
    loss = F.kl_div(
        F.log_softmax(cust_logits / temperature, dim=-1),
        F.softmax(ref_logits / temperature, dim=-1).detach(),
        reduction="sum"
    ) * (temperature**2)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(cust_model.parameters(), 1.0)
    optimizer.step()

    tracker.stop_tracking(is_training=True, epoch=epoch, step=step)

    return loss.item()


def train_epoch(loader, epoch):
    cust_model.train()
    total_loss = 0
    loss_vals = []
    for i, batch in tqdm(enumerate(loader), desc="Training"):
        loss = train_step(batch, epoch, i)
        total_loss += loss
        loss_vals.append(loss)
    return total_loss / len(loader), loss_vals
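
Because `train_step` uses `reduction="sum"`, the reported loss grows with the number of tokens in a batch. A per-token figure is often easier to compare across batch sizes and sequence lengths. The sketch below shows a masked mean over per-token values and a conversion from a summed loss to a per-token one. Both helper names are hypothetical, and the conversion assumes every batch holds exactly `BATCH_SIZE * SEQ_LEN` unpadded tokens.

```python
def masked_mean(per_token_values, mask):
    """Average values where mask == 1; return 0.0 if nothing is unmasked."""
    total = sum(v * m for v, m in zip(per_token_values, mask))
    count = sum(mask)
    return total / max(count, 1)


def per_token_from_sum(summed_loss, batch_size, seq_len):
    """Convert a sum-reduced loss to a per-token loss (assumes no padding)."""
    return summed_loss / (batch_size * seq_len)


per_token_kl = [0.5, 1.5, 9.0]   # toy per-token KL values
mask = [1, 1, 0]                 # last position is padding
avg = masked_mean(per_token_kl, mask)

# Using the notebook's BATCH_SIZE = 4 and SEQ_LEN = 128 on a toy summed loss
approx = per_token_from_sum(512.0, batch_size=4, seq_len=128)
```

The padded position contributes nothing to `avg`, however large its value.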


In [None]:
############## # Code Block 6: Generation & Evaluation ##############
def generate(model, tokenizer, prompt, max_length=50, temperature=0.7, top_k=50, device="cuda" if torch.cuda.is_available() else "cpu"):
    input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    input_ids = input_ids[:1]  # Ensure we only have one batch dimension

    model.to(device)
    model.eval()

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs

            # Get logits for the last token
            if logits.dim() == 3:  # Standard shape [batch, seq, vocab]
                next_token_logits = logits[:, -1, :] / temperature
            elif logits.dim() == 2:  # If somehow we got [batch*seq, vocab]
                # We only care about the last token's logits
                next_token_logits = logits[-1:, :] / temperature
            else:
                raise ValueError(f"Unexpected logits shape: {logits.shape}")

            # Apply top-k filtering
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)

            # Convert to probabilities
            probs = F.softmax(top_k_logits, dim=-1)

            # Sample next token index from top-k logits
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices.gather(1, next_token_idx)

            # Ensure next_token has shape [1, 1]
            next_token = next_token[-1:, :]  # Take only the last row if needed

            # Concatenate to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
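
The sampling step inside `generate` (temperature scaling, top-k filtering, then a draw from the renormalized distribution) can be isolated as follows. This is a minimal sketch using only the standard library; `top_k_sample` is a hypothetical helper and the logits are toy values.

```python
import math
import random


def top_k_sample(logits, k, temperature=1.0, rng=random):
    """Sample one index from the k highest-scoring logits.

    Logits are divided by `temperature`, the top k indices are kept, and
    one index is drawn with probability proportional to exp(scaled logit).
    """
    scaled = [z / temperature for z in logits]
    top = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    m = max(scaled[i] for i in top)
    weights = [math.exp(scaled[i] - m) for i in top]  # unnormalized softmax
    return rng.choices(top, weights=weights, k=1)[0]


logits = [0.1, 3.0, -1.0, 2.5]
rng = random.Random(0)  # seeded generator for repeatable draws
draws = [top_k_sample(logits, k=2, temperature=0.7, rng=rng) for _ in range(100)]
```

With `k=1` the function reduces to greedy decoding, and with `k=2` only the two highest-scoring indices can ever be returned.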


In [None]:
############## # Code Block 7: Dataset Preparation ##############
from datasets import load_dataset


class WikiDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.samples = []
        for text in texts:
            # Tokenize each text separately, without adding special tokens
            token_ids = tokenizer.encode(text, add_special_tokens=False)
            # Split token_ids into chunks of length seq_len
            for i in range(0, len(token_ids), seq_len):
                chunk = token_ids[i : i + seq_len]
                # Only add full chunks to avoid very short sequences
                if len(chunk) == seq_len:
                    self.samples.append(torch.tensor(chunk))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Load the full train split of WikiText-2 (shrink the slice, e.g. "train[:1%]", for a quick test)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100%]")
texts = dataset["text"]
wiki_dataset = WikiDataset(texts, tokenizer, SEQ_LEN)

# Create a DataLoader for training
train_loader = DataLoader(wiki_dataset, batch_size=BATCH_SIZE, shuffle=True)
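
`WikiDataset` keeps only full-length chunks from each text, so short lines and trailing remainders are dropped and no padding is needed. The chunking rule on its own looks like this; `chunk_full` is a hypothetical helper operating on a plain list of token ids.

```python
def chunk_full(token_ids, seq_len):
    """Split token_ids into consecutive chunks of exactly seq_len items.

    A trailing remainder shorter than seq_len is discarded.
    """
    return [
        token_ids[i : i + seq_len]
        for i in range(0, len(token_ids) - seq_len + 1, seq_len)
    ]


chunks = chunk_full(list(range(10)), seq_len=4)
print(chunks)
```

A text shorter than `seq_len` yields no chunks at all, which is why many short WikiText lines contribute nothing to training.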

In [None]:
############## # Code Block 8: Training Execution ##############
tracker = UsageTracker()

CHECKPOINT_DIR = '/content/drive/MyDrive/attention_optimization/model_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(NUM_EPOCHS):

    avg_loss, loss_vals = train_epoch(train_loader, epoch)
    scheduler.step()

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Average Loss: {avg_loss:.4f}")

    # Generate sample text every 2 epochs
    if (epoch + 1) % 2 == 0:
        print("\nGenerating sample text:")
        prompt = "Artificial intelligence"
        print("Reference:", generate(ref_model, tokenizer, prompt, temperature=0.7, top_k=50))
        print("Custom:", generate(cust_model, tokenizer, prompt, temperature=0.7, top_k=50))
        print("\n")

    # checkpoint model vars
    checkpoint = {
        'epoch': epoch + 1, # next epoch to start from
        'model_state_dict': cust_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_vals': loss_vals,
        # 'val_loss_vals': val_loss_vals,
        # 'all_alphas': all_alphas
    }
    torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'ckpt_epoch_{epoch+1}.pth'))
    print(f"Saved checkpoint for epoch {epoch+1}")

# Final generation comparison
prompt = "Artificial intelligence"
print("\nFinal generation comparison:")
print("Reference:", generate(ref_model, tokenizer, prompt, temperature=0.7, top_k=50))
print("Custom:", generate(cust_model, tokenizer, prompt, temperature=0.7, top_k=50))


Training: 2571it [07:00,  6.11it/s]


Epoch 1/10 - Average Loss: 715.3863
Saved checkpoint for epoch 1


Training: 2571it [07:02,  6.09it/s]


Epoch 2/10 - Average Loss: 481.0767

Generating sample text:
Reference: Artificial intelligence will be the next big thing, and I think there's a lot of excitement about it. I think we're going to see it for a very long time."

To hear a full conversation with Google about the future of robotics, click here
Custom: Artificial intelligence (a Dienius. Singh was the only one one hour ending up the city, which time the world world world where they it possible possible possible possible possible possible possible any other other other regions and to the mainland, the enemy power that when the


Saved checkpoint for epoch 2


Training: 2571it [07:02,  6.09it/s]


Epoch 3/10 - Average Loss: 393.5711
Saved checkpoint for epoch 3


Training: 2571it [07:03,  6.07it/s]


Epoch 4/10 - Average Loss: 333.3250

Generating sample text:
Reference: Artificial intelligence can also help us understand the world without having to go through a much more complicated process of training.

"We know in this area that we can use artificial intelligence to learn from the human mind," says Gershman. "In our
Custom: Artificial intelligence is the reason for a series of different combinations systems:-t. that a higher resolution of that sentence spacing has to a given the form. .
XWY , the system has been carried over from other factors. The original design is used


Saved checkpoint for epoch 4


Training: 2571it [07:01,  6.09it/s]


Epoch 5/10 - Average Loss: 284.8573
Saved checkpoint for epoch 5


Training: 2571it [07:01,  6.10it/s]


Epoch 6/10 - Average Loss: 243.1686

Generating sample text:
Reference: Artificial intelligence is a complex subject. But the goal of Artificial Intelligence is to be able to answer questions that all humans would have to ask before we can truly understand the nature of our world.

Today, Artificial Intelligence is being used to solve problems like climate
Custom: Artificial intelligence and a very good deal with his family and father, have been a member of five children aged and married couples. and children three children , and a family of English children. on on a farm were made of a girl who just played with at the


Saved checkpoint for epoch 6


Training: 2571it [07:03,  6.08it/s]


Epoch 7/10 - Average Loss: 207.8062
Saved checkpoint for epoch 7


Training: 2571it [07:00,  6.11it/s]


Epoch 8/10 - Average Loss: 179.4301

Generating sample text:
Reference: Artificial intelligence, known as AI, is a fascinating subject that has become more controversial. The idea of artificial intelligence is so well known to many that it has also been suggested that scientists are looking into the possibility of AI. The issue is more complicated than that.
Custom: Artificial intelligence, which is in the midstententency- the debate is an "fridiazza in the Land of Light Justice (VNG ) ] ] ] ] ] ] ] ] ] shall be a "matter " ( [C ), )


Saved checkpoint for epoch 8


Training: 2571it [07:00,  6.12it/s]


Epoch 9/10 - Average Loss: 159.2151
Saved checkpoint for epoch 9


Training: 2571it [07:01,  6.11it/s]


Epoch 10/10 - Average Loss: 147.9481

Generating sample text:
Reference: Artificial intelligence is a major new field in tech, but the technology is not yet ready for commercial use.

The company behind the project, DeepMind, says it has developed a prototype for "an intelligent machine that can solve real-world problems in a
Custom: Artificial intelligence was a "wah-style " to describe a man of great authority to ' to the right-ing and the need for good . . I've Gog , his nephew, the "master , a good guy play-ing him into the


Saved checkpoint for epoch 10

Final generation comparison:
Reference: Artificial intelligence has a history of being used to achieve very particular goals, and we see it often in the search for a solution to a problem.

The main problem is not the technology itself - it is the amount of data that can be processed and stored
Custom: Artificial intelligence and the other major factors that do not apply the fundamental conclusions, it may be a useful consequence

In [None]:
############## # Code Block 1: Imports & Config ##############
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
from native_sparse_attention_pytorch import SparseAttention
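import time
import psutil
import os
import pandas as pd  # used by UsageTracker.plot
import matplotlib.pyplot as plt  # used by UsageTracker.plot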


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "gpt2"
SEQ_LEN = 128
BATCH_SIZE = 4
NUM_HEADS = 4
COMPRESS_RATIO = 0.25
WINDOW_SIZE = 64
NUM_EPOCHS = 20

# Native sparse attention configuration
SPARSE_CONFIG = {
    "dim": None,  # Will be set in the model
    "dim_head": 64,  # Dimension per head
    "heads": NUM_HEADS,
    "sliding_window_size": 2,  # Local attention window
    "compress_block_size": 4,  # Size of blocks to compress
    "selection_block_size": 4,  # Size of blocks to select from
    "num_selected_blocks": 2,  # Number of blocks to select
}

In [None]:
############## # Code Block 2: Sparse Attention Components ##############
class CompressedGlobalAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, compress_ratio):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.compress_ratio = compress_ratio

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)
        self.compression = nn.Linear(embed_dim, 1)
        self.expansion = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        keep_num = max(1, int(T * self.compress_ratio))

        # Token compression
        importance = self.compression(x).squeeze(-1)
        _, keep_idx = torch.topk(importance, k=keep_num, dim=-1)
        x_compressed = torch.gather(x, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))

        # Projections
        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = (
            self.Wk(x_compressed)
            .view(B, keep_num, self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )
        V = (
            self.Wv(x_compressed)
            .view(B, keep_num, self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )

        # Attention
        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Masking
        if attention_mask is not None:
            compressed_mask = torch.gather(attention_mask, 1, keep_idx)
            attn_scores = attn_scores.masked_fill(
                compressed_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        output = output.permute(0, 2, 1, 3).contiguous().view(B, T, D)

        output = self.expansion(output)  # Ensure output has correct embedding dimension
        output = output[:, : x.size(1), :]

        return output


class LocalWindowAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size

        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)

    def create_window_mask(self, seq_len, device):
        # Symmetric window: position i sees positions within window_size // 2
        # on either side (this includes future tokens).
        idx = torch.arange(seq_len, device=device)
        half = self.window_size // 2
        mask = ((idx[None, :] - idx[:, None]).abs() <= half).float()
        return mask.unsqueeze(0).unsqueeze(1)

    def forward(self, x, attention_mask=None):
        B, T, D = x.shape
        window_mask = self.create_window_mask(T, x.device)

        Q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.Wk(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.Wv(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn_scores = attn_scores.masked_fill(window_mask == 0, -1e10)

        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e10
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output.permute(0, 2, 1, 3).contiguous().view(B, T, D)


class HierarchicalSparseAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size, compress_ratio):
        super().__init__()
        self.num_heads = num_heads  # 🔹 Store num_heads
        self.local_attn = LocalWindowAttention(embed_dim, num_heads, window_size)
        self.global_attn = CompressedGlobalAttention(
            embed_dim, num_heads, compress_ratio
        )
        self.gate = nn.Sequential(
            nn.Linear(
                embed_dim, num_heads * 2
            ),  # Ensure output is [batch, seq_len, num_heads * 2]
            nn.Softmax(dim=-1),
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        # Get outputs from local and global attention modules.
        local_out = self.local_attn(x, attention_mask)  # Expected shape: (B, T, D)
        global_out = self.global_attn(x, attention_mask)  # Expected shape: (B, T, D)

        B, T, D = x.size()
        head_dim = D // self.num_heads  # Ensure D is divisible by num_heads

        # Compute gating weights.
        # self.gate should output a tensor of shape (B, T, num_heads*2)
        gate_out = self.gate(x)  # Shape: (B, T, num_heads*2)
        # Reshape to (B, T, num_heads, 2) where last dim holds [local_gate, global_gate]
        gates = gate_out.view(B, T, self.num_heads, 2)
        # Unbind the last dimension into two tensors
        local_gate = gates[..., 0]  # Shape: (B, T, num_heads)
        global_gate = gates[..., 1]  # Shape: (B, T, num_heads)

        # Reshape attention outputs to split heads: (B, T, num_heads, head_dim)
        local_out_heads = local_out.view(B, T, self.num_heads, head_dim)
        global_out_heads = global_out.view(B, T, self.num_heads, head_dim)

        # Ensure the gate tensors have an extra dimension for broadcasting: (B, T, num_heads, 1)
        local_gate = local_gate.unsqueeze(-1)
        global_gate = global_gate.unsqueeze(-1)

        # Element-wise multiply each head output by its corresponding gate weight
        combined = local_out_heads * local_gate + global_out_heads * global_gate
        # Reshape back to (B, T, D)
        combined = combined.view(B, T, D)
        return self.out_proj(combined)
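
The `create_window_mask` helper above builds a symmetric band, so each position can also see up to `window_size // 2` future tokens. For autoregressive language modeling, a causal window is the more common choice. The following is a minimal sketch of such a mask in plain Python. `causal_window_mask` is a hypothetical helper, not part of the notebook, and it assumes the window counts the current token.

```python
def causal_window_mask(seq_len, window_size):
    """Return a seq_len x seq_len 0/1 mask (hypothetical helper).

    Row i may attend to positions i - window_size + 1 .. i,
    and never to a position in the future (j > i).
    """
    return [
        [1 if 0 <= i - j < window_size else 0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


# Each row is a query position, each column a key position.
for row in causal_window_mask(seq_len=4, window_size=2):
    print(row)
```

In PyTorch the same rule could be written without loops by comparing two `torch.arange` index vectors through broadcasting.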


In [None]:
############## # Code Block 3: Custom GPT-2 Model ##############
class SparseGPT2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.drop = nn.Dropout(config.embd_pdrop)

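        # Create one native sparse attention layer with the correct parameters.
        # Note: this single instance is reused by every block in self.h below,
        # so all layers share the same attention weights.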
        sparse_config = SPARSE_CONFIG.copy()
        sparse_config["dim"] = config.hidden_size
        sparse_config["compress_block_sliding_stride"] = 2
        self.sparse_attn = SparseAttention(**sparse_config)

        self.h = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attn": self.sparse_attn,
                        "ln_1": nn.LayerNorm(config.hidden_size),
                        "mlp": nn.Sequential(
                            nn.Linear(config.hidden_size, 4 * config.hidden_size),
                            nn.GELU(),
                            nn.Linear(4 * config.hidden_size, config.hidden_size),
                        ),
                        "ln_2": nn.LayerNorm(config.hidden_size),
                    }
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.size()
        pos_ids = torch.arange(T, device=input_ids.device).unsqueeze(0)

        x = self.drop(self.wte(input_ids) + self.wpe(pos_ids))

        attn_out = torch.zeros_like(x)

        for block in self.h:
            # Pre-norm: apply layer norm before attention
            normed_x = block["ln_1"](x)
            attention_result = block["attn"](normed_x)

            # Keep the previous attn_out if the attention layer returns None
            if attention_result is not None:
                attn_out = attention_result

            # Apply mask after attention if provided
            if attention_mask is not None:
                attn_out = attn_out * attention_mask.unsqueeze(-1)

            x = x + attn_out
            x = x + block["mlp"](block["ln_2"](x))

        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits  # Return only the logits, not a tuple


In [None]:
############## # Code Block 4: Training Setup ##############
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Reference model
ref_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)
ref_model.eval()

# Custom model with native sparse attention
cust_config = GPT2Config.from_pretrained(MODEL_NAME)
cust_model = SparseGPT2(cust_config).to(DEVICE)

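# Initialize with pretrained weights. GPT2LMHeadModel prefixes its backbone keys with
# "transformer.", so strip it; otherwise strict=False silently skips everything
# except lm_head.weight. Attention and MLP keys still differ by name and stay random.
pretrained_state_dict = {
    k.removeprefix("transformer."): v for k, v in ref_model.state_dict().items()
}
cust_model.load_state_dict(pretrained_state_dict, strict=False)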

# AdamW at lr=1e-3 (fairly high for fine-tuning), cosine-annealed over NUM_EPOCHS
optimizer = torch.optim.AdamW(cust_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
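
`load_state_dict(..., strict=False)` silently skips any key whose name does not match, so it helps to check which parameter names the reference and custom models actually share. The sketch below uses a handful of toy key names: the `transformer.`-prefixed ones follow the `GPT2LMHeadModel` layout, while `h.0.attn.sparse.weight` is a made-up placeholder for whatever the sparse layer registers. `strip_prefix` and `shared_keys` are hypothetical helpers, not part of the notebook.

```python
def strip_prefix(keys, prefix="transformer."):
    """Drop a leading prefix from each key, leaving other keys unchanged."""
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]


def shared_keys(source_keys, target_keys):
    """Keys present in both lists, i.e. candidates for copying by name.

    Shapes must also match for a real load; this only compares names.
    """
    return sorted(set(source_keys) & set(target_keys))


# Toy subset of reference-model keys (GPT2LMHeadModel naming).
ref_keys = [
    "transformer.wte.weight",
    "transformer.wpe.weight",
    "transformer.h.0.ln_1.weight",
    "transformer.h.0.attn.c_attn.weight",
    "lm_head.weight",
]
# Toy subset of custom-model keys; the attention key is a placeholder name.
cust_keys = [
    "wte.weight",
    "wpe.weight",
    "h.0.ln_1.weight",
    "h.0.attn.sparse.weight",
    "lm_head.weight",
    "lm_head.bias",
]

print("without remap:", shared_keys(ref_keys, cust_keys))
print("with remap:   ", shared_keys(strip_prefix(ref_keys), cust_keys))
```

On the real models, the value returned by `load_state_dict` lists the missing and unexpected keys, which gives the same information directly.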


In [None]:
############## # Code Block 5: Training Loop ##############
class UsageTracker:
    def __init__(self):
        self.start_time = None
        self.start_cpu_time = None
        self.train_stats = {
            "epoch": [], "step": [],
            "wall_time_s": [], "cpu_time_s": [],
            "cpu_mem_mb": [], "gpu_alloc_mb": [], "gpu_reserved_mb": []
        }
        self.infer_stats = {
            "max_token_length": [],
            "wall_time_s": [], "cpu_time_s": [],
            "cpu_mem_mb": [], "gpu_alloc_mb": [], "gpu_reserved_mb": []
        }

    def start_tracking(self):
        self.start_time = time.time()
        self.start_cpu_time = time.process_time()

    def stop_tracking(self, is_training, epoch=None, step=None, max_token_length=None):
        if self.start_time is None or self.start_cpu_time is None:
            raise ValueError("Tracking not started. Call start_tracking() first.")

        end_time = time.time()
        end_cpu_time = time.process_time()

        proc = psutil.Process(os.getpid())
        cpu_mem = proc.memory_info().rss / (1024**2)
        if torch.cuda.is_available():
            gpu_alloc    = torch.cuda.memory_allocated(DEVICE) / (1024**2)
            gpu_reserved = torch.cuda.memory_reserved(DEVICE) / (1024**2)
        else:
            gpu_alloc = gpu_reserved = 0.0

        obj = self.train_stats if is_training else self.infer_stats
        obj["wall_time_s"].append(end_time - self.start_time)
        obj["cpu_time_s"].append(end_cpu_time - self.start_cpu_time)
        obj["cpu_mem_mb"].append(cpu_mem)
        obj["gpu_alloc_mb"].append(gpu_alloc)
        obj["gpu_reserved_mb"].append(gpu_reserved)
        if is_training:
            obj["epoch"].append(epoch + 1)
            obj["step"].append(step + 1)
        else:
            obj["max_token_length"].append(max_token_length)

        self.start_time = None
        self.start_cpu_time = None

    def plot(self):
        df_train = pd.DataFrame(self.train_stats)
        df_inf   = pd.DataFrame(self.infer_stats)

        # --- Training plots (line plots vs. step) ---
        for metric in ["wall_time_s", "cpu_time_s", "cpu_mem_mb", "gpu_alloc_mb", "gpu_reserved_mb"]:
            plt.figure(figsize=(8, 4))
            plt.plot(df_train["step"], df_train[metric])
            plt.xlabel("Training Step")
            plt.ylabel(metric.replace("_", " ").title())
            plt.title(f"{metric.replace('_', ' ').title()} During Training")
            plt.show()

        # # --- Inference plots (bar plots vs. call index) ---
        # for metric in ["wall_time_s", "cpu_time_s", "cpu_mem_mb", "gpu_alloc_mb", "gpu_reserved_mb"]:
        #     plt.figure(figsize=(8, 4))
        #     plt.bar(range(len(df_inf)), df_inf[metric])
        #     plt.xlabel("Inference Call #")
        #     plt.ylabel(metric.replace("_", " ").title())
        #     plt.title(f"{metric.replace('_', ' ').title()} Per Inference")
        #     plt.show()

        prompt_lengths = self.infer_stats["max_token_length"]

        # --- Inference plots (scatter plots vs. prompt token length) ---
        for metric in ["wall_time_s", "cpu_time_s", "cpu_mem_mb", "gpu_alloc_mb", "gpu_reserved_mb"]:
            plt.figure(figsize=(8, 4))
            plt.scatter(prompt_lengths, self.infer_stats[metric])
            plt.xlabel("Prompt Token Length")
            plt.ylabel(metric.replace("_", " ").title())
            plt.title(f"{metric.replace('_', ' ').title()} vs. Prompt Length")
            plt.show()

def kl_divergence_loss(logits_custom, logits_ref, mask, temperature):
    # 1) soften both distributions by T (kept in log-space for numerical stability)
    logp_c = F.log_softmax(logits_custom / temperature, dim=-1)
    logp_r = F.log_softmax(logits_ref.detach() / temperature, dim=-1)

    # 2) per-token KL(ref || custom); using logp_r avoids log(0) if a probability underflows
    kl = (logp_r.exp() * (logp_r - logp_c)).sum(-1)  # (B, L)

    # 3) average over real tokens and re-scale by T^2
    return (kl * mask).sum() / mask.sum() * (temperature * temperature)

def train_step(batch, epoch, step):
    tracker.start_tracking()

    inputs = batch.to(DEVICE)
    attention_mask = (inputs != tokenizer.pad_token_id).float()

    with torch.no_grad():
        ref_outputs = ref_model(inputs, attention_mask=attention_mask)
        if isinstance(ref_outputs, tuple):
            ref_logits = ref_outputs[0]
        else:
            ref_logits = ref_outputs.logits  # Extract logits from the output object

    cust_outputs = cust_model(inputs, attention_mask=attention_mask)
    if isinstance(cust_outputs, tuple):
        cust_logits = cust_outputs[0]
    else:
        cust_logits = cust_outputs

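    # KL divergence with temperature, summed over every position in the batch.
    # The attention mask is not applied here and kl_divergence_loss above is not used.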
    temperature = 0.7
    loss = F.kl_div(
        F.log_softmax(cust_logits / temperature, dim=-1),
        F.softmax(ref_logits / temperature, dim=-1).detach(),
        reduction="sum"
    ) * (temperature**2)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(cust_model.parameters(), 1.0)
    optimizer.step()

    tracker.stop_tracking(is_training=True, epoch=epoch, step=step)

    return loss.item()


def train_epoch(loader, epoch):
    cust_model.train()
    total_loss = 0
    loss_vals = []
    for i, batch in tqdm(enumerate(loader), desc="Training"):
        loss = train_step(batch, epoch, i)
        total_loss += loss
        loss_vals.append(loss)
    return total_loss / len(loader), loss_vals
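
The `kl_divergence_loss` helper above averages the per-token KL over unmasked positions and rescales by the squared temperature. The following plain-Python sketch reproduces that arithmetic on toy logits so the effect of the mask is easy to see. The names `softmax`, `kl`, and `masked_mean_kl` are illustrative and not part of the notebook.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions with q > 0."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


def masked_mean_kl(ref_logits, cust_logits, mask, temperature):
    """Mean per-token KL over positions where mask == 1, scaled by T^2."""
    per_token = [
        kl(softmax(r, temperature), softmax(c, temperature))
        for r, c in zip(ref_logits, cust_logits)
    ]
    kept = sum(k * m for k, m in zip(per_token, mask))
    return kept / sum(mask) * temperature ** 2


# Two token positions, three-word vocabulary. The models agree on the
# first position and disagree on the second.
ref = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
cust = [[2.0, 0.0, 0.0], [5.0, 0.0, 0.0]]

print(masked_mean_kl(ref, cust, mask=[1, 1], temperature=0.7))  # both positions count
print(masked_mean_kl(ref, cust, mask=[1, 0], temperature=0.7))  # second position ignored
```

With the second position masked out, only the position where the two models agree contributes, so the loss is zero.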

In [None]:
############## # Code Block 6: Generation & Evaluation ##############
def generate(model, tokenizer, prompt, max_length=50, temperature=0.7, top_k=50, device="cuda" if torch.cuda.is_available() else "cpu"):
    input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    input_ids = input_ids[:1]  # Ensure we only have one batch dimension

    model.to(device)
    model.eval()

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs

            # Get logits for the last token
            if logits.dim() == 3:  # Standard shape [batch, seq, vocab]
                next_token_logits = logits[:, -1, :] / temperature
            elif logits.dim() == 2:  # If somehow we got [batch*seq, vocab]
                # We only care about the last token's logits
                next_token_logits = logits[-1:, :] / temperature
            else:
                raise ValueError(f"Unexpected logits shape: {logits.shape}")

            # Apply top-k filtering
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)

            # Convert to probabilities
            probs = F.softmax(top_k_logits, dim=-1)

            # Sample next token index from top-k logits
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices.gather(1, next_token_idx)

            # Ensure next_token has shape [1, 1]
            next_token = next_token[-1:, :]  # Take only the last row if needed

            # Concatenate to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

In [None]:
############## # Code Block 7: Dataset Preparation ##############
from datasets import load_dataset


class WikiDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.samples = []
        for text in texts:
            # Tokenize each text separately, without adding special tokens
            token_ids = tokenizer.encode(text, add_special_tokens=False)
            # Split token_ids into chunks of length seq_len
            for i in range(0, len(token_ids), seq_len):
                chunk = token_ids[i : i + seq_len]
                # Only add full chunks to avoid very short sequences
                if len(chunk) == seq_len:
                    self.samples.append(torch.tensor(chunk))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Load the full WikiText-2 train split (use e.g. "train[:1%]" for a small subset)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100%]")
texts = dataset["text"]
wiki_dataset = WikiDataset(texts, tokenizer, SEQ_LEN)

# Create a DataLoader for training
train_loader = DataLoader(wiki_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
############## # Code Block 8: Training Execution ##############
tracker = UsageTracker()

CHECKPOINT_DIR = '/content/drive/MyDrive/attention_optimization/model_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

start_epoch = 6
ckpt_path = os.path.join(CHECKPOINT_DIR, f"ckpt_epoch_{start_epoch}.pth")
checkpoint = torch.load(ckpt_path, map_location=DEVICE)

cust_model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']
loss_vals = checkpoint['loss_vals']
# val_loss_vals = checkpoint['val_loss_vals']
# all_alphas = checkpoint['all_alphas']

cust_model.train()

print(f"Resumed from epoch {start_epoch}")

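# Note: if the checkpoint's 'epoch' field follows the saving convention below
# (it already holds the next epoch index), starting at start_epoch + 1 skips one epoch.
# The scheduler state is not restored either, so the cosine schedule restarts on resume.
for epoch in range(start_epoch + 1, NUM_EPOCHS):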

    avg_loss, loss_vals = train_epoch(train_loader, epoch)
    scheduler.step()

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Average Loss: {avg_loss:.4f}")

    # Generate sample text after each epoch
    if (epoch + 1) % 2 == 0:  # Generate every 2 epochs
        print("\nGenerating sample text:")
        prompt = "Artificial intelligence"
        print("Reference:", generate(ref_model, tokenizer, prompt, temperature=0.7, top_k=50))
        print("Custom:", generate(cust_model, tokenizer, prompt, temperature=0.7, top_k=50))
        print("\n")

    # checkpoint model vars
    if (epoch + 1) % 5 == 0:
        checkpoint = {
            'epoch': epoch + 1, # next epoch to start from
            'model_state_dict': cust_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss_vals': loss_vals,
            # 'val_loss_vals': val_loss_vals,
            # 'all_alphas': all_alphas
        }
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'ckpt_epoch_{epoch+1}.pth'))
        print(f"Saved checkpoint for epoch {epoch+1}")

# Final generation comparison
prompt = "Artificial intelligence"
print("\nFinal generation comparison:")
print("Reference:", generate(ref_model, tokenizer, prompt, temperature=0.7, top_k=50))
print("Custom:", generate(cust_model, tokenizer, prompt, temperature=0.7, top_k=50))


Resumed from epoch 6


Training: 2571it [09:09,  4.68it/s]


Epoch 8/20 - Average Loss: 207.8354

Generating sample text:
Reference: Artificial intelligence is a big step forward in this area of research.

You can connect to a human's brain after they're born, and even see a person's face from a distance, but your sense of what they're wearing and hearing will not be
Custom: Artificial intelligence as the first to be built in by the United States Environmental Protection Agency (BOTOTIIANOSANAR ) as the first major C EEAED ETSSTOSANDAOSOSIM in 2013 and 2007 BC BC-




Training: 2571it [09:05,  4.72it/s]


Epoch 9/20 - Average Loss: 188.6792


Training: 2571it [09:04,  4.72it/s]


Epoch 10/20 - Average Loss: 174.0925

Generating sample text:
Reference: Artificial intelligence can now take over the world, with the aim of bringing about the ultimate goal of global economic and political change.[1]

The research has also been performed by the UK Economic and Social Research Council (ESRC), which has been funded by
Custom: Artificial intelligence, you.

The and 2: the user-key, and & and MacLeodfish @ 3- @-:-Loids on the Space One. The Effect-1 (H ) )


-S ( OAS


Saved checkpoint for epoch 10


Training: 2571it [09:09,  4.68it/s]


Epoch 11/20 - Average Loss: 161.1676


Training: 2571it [09:08,  4.69it/s]


Epoch 12/20 - Average Loss: 149.6092

Generating sample text:
Reference: Artificial intelligence (AI) will be developed in earnest, as our technology's ability to understand what humans are saying is a fundamental skill, and will become more and more important, as AI becomes more and more important to the business world. This will lead to the
Custom: Artificial intelligence report was also investigated as evidence suggests that this is a "clear-dorf model with another three types-linked series. cit. The New York Times reporter had "found a high-tonstoneton'sululine. (17 ), Valley




Training: 2571it [09:03,  4.73it/s]


Epoch 13/20 - Average Loss: 139.0573


Training: 2571it [08:58,  4.77it/s]


Epoch 14/20 - Average Loss: 129.5528

Generating sample text:
Reference: Artificial intelligence (AI) will be the next step in this conversation. The next step is to see if it can handle an infinite number of tasks for which it has no real experience. The next step is to see if AI can create a better or worse AI
Custom: Artificial intelligence policy. the public press conference to explain that the media group have been looking at the government'sians, and to those who are actually found in that they are more than in the case cases and in the media, and in the media coverage program news




Training: 2571it [09:09,  4.68it/s]


Epoch 15/20 - Average Loss: 120.7946
Saved checkpoint for epoch 15


Training: 2571it [09:08,  4.69it/s]


Epoch 16/20 - Average Loss: 112.8696

Generating sample text:
Reference: Artificial intelligence has been used to train animals for more than a century, and most of those trained to do so have been trained to be intelligent. It has been known for some time that humans are "super intelligent" and that many other primates have been created "
Custom: Artificial intelligence needs a very large percentage of the food. or other people.





- #6 CXTIE.I is used in as a whole, you, it, you . [ 1 , if a whole offer way




Training: 2571it [09:07,  4.69it/s]


Epoch 17/20 - Average Loss: 105.5966


Training: 2571it [09:07,  4.69it/s]


Epoch 18/20 - Average Loss: 99.0282

Generating sample text:
Reference: Artificial intelligence (AI) is the future of computing, and it will revolutionize the way we think about the world. Artificial intelligence is a major innovation in today's technological world.

AI is also a way for our society to live and work without having
Custom: Artificial intelligence-0-1A-1, @980V/ (AT,@ or @,@ 16491525-/-W,@ 4,@ 3,@ 3888M /)-100@2 @-,/ )




Training: 2571it [09:07,  4.69it/s]


Epoch 19/20 - Average Loss: 93.2158


Training: 2571it [09:07,  4.70it/s]


Epoch 20/20 - Average Loss: 88.0451

Generating sample text:
Reference: Artificial intelligence. The process is described by a group of researchers at the University of California, Berkeley. The team aims to use machine learning to better understand how robots will interact with humans, with the aim of developing tools that enable social networks to communicate with each other
Custom: Artificial intelligence service station CCTV. A. (J. . CPT, M. D. (director , M6HOXC/GECCTRM and PEGX2XXXXC.TUYITMC.


Saved checkpoint for epoch 20

Final generation comparison:
Reference: Artificial intelligence is already one of the top 10 technologies in the business, but the number is growing every year.

The company has now confirmed that it is the first company to offer the full range of artificial intelligence and artificial intelligence tools.

The firm
Custom: Artificial intelligence agents have long been able-tearicke , and may exist. it'seticizedization , where there is a substantial number of

# Conclusion

We have demonstrated:
1) Loading a reference GPT-2 model from Hugging Face.
2) Creating a custom GPT-2-like model that replaces full self-attention with a sparse attention mechanism (`SparseAttention`).
3) Setting up a dataset and training loop that optimizes the custom model to match the reference distribution via KL-divergence.
4) Comparing generated text from both models.

This notebook is purely for demonstration and educational purposes, and many improvements could be made:
- More elaborate data loading
- Proper scheduling, regularization
- Additional GPT-2 intricacies (like caching attention states, etc.)
- More advanced generation strategies (beam search, top-p, etc.) beyond the top-k sampling used here

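In the logged run, the average KL loss fell from 207.8 at epoch 8 to 88.0 at epoch 20, yet the custom model's samples remained far less coherent than the reference's.
Even so, this workflow shows how one could begin to experiment with custom attention
mechanisms and align them to a known distribution via KL divergence.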
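
Of the generation strategies listed above, top-p (nucleus) sampling is the closest relative of the top-k filter in `generate`: instead of keeping a fixed number of tokens, it keeps the smallest set whose cumulative probability reaches a threshold. The following is a minimal plain-Python sketch of the filtering step only. `top_p_filter` is a hypothetical helper and the probabilities are toy values.

```python
def top_p_filter(probs, top_p):
    """Keep the most likely tokens until their cumulative mass reaches top_p.

    Returns a dict {token_index: renormalized_probability}.
    """
    # Token indices sorted from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # Renormalize so the kept probabilities sum to 1 before sampling.
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


# Toy next-token distribution over a four-word vocabulary.
filtered = top_p_filter([0.5, 0.3, 0.15, 0.05], top_p=0.75)
print(filtered)
```

Sampling would then draw from the renormalized distribution, for example with `random.choices` over the kept indices.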