# Forgetting Transformer (FoX) Implementation

**Paper**: "Forgetting Transformer: Softmax Attention with a Forget Gate"  
**Authors**: Zhixuan Lin, Evgenii Nikishin, Xu Owen He, Aaron Courville (2025)

---

## Mathematical Foundation

### Standard Attention
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot V$$

### Forgetting Attention (Novel)
$$\text{ForgetAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} \odot F\right) \cdot V$$

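where $\odot$ denotes the element-wise product, $\sigma$ is the logistic sigmoid, and the **forget gate** $F$ holds one value in $(0, 1)$ per query–key pair:
$$F_{ij} = \sigma(W_q \cdot q_i + W_k \cdot k_j + b)$$

> **Note:** this pairwise gate is the variant visualized in this notebook. The paper itself defines a scalar gate per timestep, $f_t = \sigma(w_f^\top x_t + b_f)$, and adds the cumulative log-gates $D_{ij} = \sum_{l=j+1}^{i} \log f_l$ to the attention logits before the softmax, rather than multiplying the logits by $F$.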

---

In [None]:
# Setup
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Our implementations
from src.models.standard_attention import ScaledDotProductAttention, MultiHeadAttention
from src.models.forgetting_attention import ForgetGate, ForgettingAttention, MultiHeadForgettingAttention
from src.models.transformer_blocks import LanguageModel

# Device setup: prefer CUDA, then Apple-Silicon MPS, and fall back to the CPU,
# so the notebook uses whichever accelerator the host actually provides.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Plot styling: start from matplotlib's defaults, then widen the figures and
# raise the base font size for every plot that follows.
plt.style.use('default')
plt.rcParams.update({
    'figure.figsize': (12, 4),
    'font.size': 11,
})

## 1. Understanding the Forget Gate

The forget gate is the core innovation. Let's visualize how it works.

In [None]:
# Sample inputs for the forget-gate demo: one sequence of 20 positions with
# feature size 16.
batch_size, seq_len, d_k = 1, 20, 16

# Fixed seed; query, key and value are drawn in exactly that order, so the
# random stream matches three consecutive randn calls.
torch.manual_seed(42)
query, key, value = (
    torch.randn(batch_size, seq_len, d_k, device=device) for _ in range(3)
)

# Freshly constructed forget gate, evaluated on the sampled query/key tensors
forget_gate = ForgetGate(d_k=d_k).to(device)
gate_values = forget_gate(query, key)

print(f"Forget gate shape: {gate_values.shape}")
print(f"Gate value range: [{gate_values.min().item():.4f}, {gate_values.max().item():.4f}]")
print(f"Gate mean: {gate_values.mean().item():.4f}")

In [None]:
# Visualize the forget gate next to standard and forgetting attention weights
from pathlib import Path

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Panel 1: raw forget gate values, with the colour scale pinned to [0, 1]
im1 = axes[0].imshow(gate_values[0].detach().cpu().numpy(), cmap='RdYlBu_r', vmin=0, vmax=1)
axes[0].set_title('Forget Gate Values\n(higher = remember more)', fontsize=12)
axes[0].set_xlabel('Key position (j)')
axes[0].set_ylabel('Query position (i)')
plt.colorbar(im1, ax=axes[0])

# Panel 2: standard attention weights (no forget gate) on the same inputs
std_attention = ScaledDotProductAttention().to(device)
_, std_weights = std_attention(query, key, value, return_attention_weights=True)

im2 = axes[1].imshow(std_weights[0].detach().cpu().numpy(), cmap='Blues')
axes[1].set_title('Standard Attention Weights', fontsize=12)
axes[1].set_xlabel('Key position')
axes[1].set_ylabel('Query position')
plt.colorbar(im2, ax=axes[1])

# Panel 3: forgetting attention. Its gate is loaded from `forget_gate` so
# the weights shown here correspond to the gate values plotted in panel 1.
fox_attention = ForgettingAttention(d_k=d_k).to(device)
fox_attention.forget_gate.load_state_dict(forget_gate.state_dict())

_, fox_weights, _ = fox_attention(query, key, value, return_attention_weights=True)

im3 = axes[2].imshow(fox_weights[0].detach().cpu().numpy(), cmap='Blues')
axes[2].set_title('Forgetting Attention Weights\n(with forget gate)', fontsize=12)
axes[2].set_xlabel('Key position')
axes[2].set_ylabel('Query position')
plt.colorbar(im3, ax=axes[2])

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists
Path('../results').mkdir(parents=True, exist_ok=True)
plt.savefig('../results/attention_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# The emoji was stored as mis-decoded UTF-8 bytes; restored to the intended character
print("\n📊 Saved: results/attention_comparison.png")

## 2. Comparing Model Architectures

In [None]:
# Hyperparameters shared by both language models
vocab_size = 5000
d_model = 128
num_heads = 4
num_layers = 4


def _build_language_model(use_forgetting):
    """Build a LanguageModel with the shared hyperparameters and move it to `device`."""
    return LanguageModel(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        use_forgetting=use_forgetting,
    ).to(device)


# The standard model is built first so parameter initialization consumes the
# random stream in the same order as two back-to-back constructor calls.
std_model = _build_language_model(False)
fox_model = _build_language_model(True)

std_params = std_model.count_parameters()
fox_params = fox_model.count_parameters()
diff = fox_params - std_params

# Assemble the table row by row and emit it in one go
divider = "-" * 50
report_lines = [
    "Model Comparison:",
    "=" * 50,
    f"{'Model':<25} {'Parameters':>20}",
    divider,
    f"{'Standard Transformer':<25} {std_params:>20,}",
    f"{'Forgetting Transformer':<25} {fox_params:>20,}",
    divider,
    f"{'Difference (forget gates)':<25} {diff:>20,}",
    f"{'Overhead':<25} {100*diff/std_params:>19.2f}%",
]
print("\n".join(report_lines))

## 3. Training Experiment

Let's train both models on a simple task and compare their learning curves.

In [None]:
# Generate synthetic data (copy task)
def generate_copy_data(batch_size, seq_len, vocab_size, target_device=None):
    """Generate a batch for the copy task: a random prefix followed by its copy.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Total number of tokens in each returned sequence.
        vocab_size: Exclusive upper bound for sampled token ids. Id 0 is
            never sampled (it is reserved for padding).
        target_device: Device the batch is moved to. When omitted, the
            notebook-level `device` is used.

    Returns:
        Integer tensor of shape (batch_size, seq_len). The second half
        repeats the first half; for an odd `seq_len` the repeated part is
        one token shorter than the prefix.
    """
    # Round the prefix length up so an odd seq_len still yields exactly
    # seq_len tokens once the doubled sequence is truncated.
    half_len = (seq_len + 1) // 2
    tokens = torch.randint(1, vocab_size, (batch_size, half_len))
    # Prefix followed by its copy, serving as both input and shifted target
    input_ids = tokens.repeat(1, 2)[:, :seq_len]
    return input_ids.to(device if target_device is None else target_device)

# Training function
def train_model(model, num_steps=500, batch_size=32, seq_len=64, lr=1e-3):
    """Train `model` on freshly sampled copy-task batches and record the loss.

    Args:
        model: Module whose call returns a pair `(logits, extra)`; the
            logits are scored as next-token predictions.
        num_steps: Number of optimizer steps to run.
        batch_size: Sequences per sampled batch.
        seq_len: Tokens per sampled sequence.
        lr: Learning rate passed to AdamW.

    Returns:
        List with the scalar loss of every step, in order.

    Note:
        Token ids are sampled using the notebook-level `vocab_size`. The
        model is switched to training mode and left in that mode.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()

    losses = []
    pbar = tqdm(range(num_steps), desc="Training")

    for _step in pbar:
        # A new random batch is sampled at every step
        input_ids = generate_copy_data(batch_size, seq_len, vocab_size)

        logits, _ = model(input_ids)

        # Next-token prediction: position t is scored against token t + 1.
        # The class count is read from the logits themselves so the reshape
        # does not depend on the notebook-level vocab_size.
        num_classes = logits.size(-1)
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, num_classes),
            input_ids[:, 1:].reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before stepping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Read the scalar once; it is used for both the log and the progress bar
        loss_value = loss.item()
        losses.append(loss_value)
        pbar.set_postfix({'loss': f'{loss_value:.4f}'})

    return losses

In [None]:
# Train both models with the same step budget
NUM_TRAIN_STEPS = 300

print("Training Standard Transformer...")
std_losses = train_model(std_model, num_steps=NUM_TRAIN_STEPS)

print("\nTraining Forgetting Transformer...")
fox_losses = train_model(fox_model, num_steps=NUM_TRAIN_STEPS)

In [None]:
# Plot training curves: raw per-step loss on the left, moving average on the right
from pathlib import Path

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw loss curves
axes[0].plot(std_losses, label='Standard Transformer', alpha=0.7)
axes[0].plot(fox_losses, label='Forgetting Transformer', alpha=0.7)
axes[0].set_xlabel('Training Step')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss Curves')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Smoothed loss curves (moving average over `window` consecutive steps)
window = 20
kernel = np.ones(window) / window
std_smooth = np.convolve(std_losses, kernel, mode='valid')
fox_smooth = np.convolve(fox_losses, kernel, mode='valid')

for label, smooth in (
    ('Standard Transformer', std_smooth),
    ('Forgetting Transformer', fox_smooth),
):
    # mode='valid' drops the first window - 1 points, so shift the x values
    # to place each average on the last training step it covers.
    steps = np.arange(window - 1, window - 1 + len(smooth))
    axes[1].plot(steps, smooth, label=label, linewidth=2)
axes[1].set_xlabel('Training Step')
axes[1].set_ylabel('Loss (smoothed)')
axes[1].set_title(f'Smoothed Training Loss (window={window})')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists
Path('../results').mkdir(parents=True, exist_ok=True)
plt.savefig('../results/training_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# "Final loss" is the mean over the last 50 recorded steps.
# The emoji was stored as mis-decoded UTF-8 bytes; restored to the intended character.
print("\n📊 Results:")
print(f"  Standard final loss: {np.mean(std_losses[-50:]):.4f}")
print(f"  Forgetting final loss: {np.mean(fox_losses[-50:]):.4f}")

## 4. Analyzing Forget Gate Behavior

Let's inspect the forget gate values of the trained model's first layer.

In [None]:
# Get forget gate values from trained model
# Runs the first layer's attention by hand on one copy-task sample and keeps
# the attention weights and forget gate values it returns.
fox_model.eval()

# Generate a sample (a single sequence of 32 tokens)
sample_input = generate_copy_data(1, 32, vocab_size)

# Get attention weights and forget gates
with torch.no_grad():
    # Access the first layer's attention
    first_layer = fox_model.encoder.layers[0]
    
    # Get embeddings: token embedding scaled by sqrt(d_model) plus a learned
    # position embedding.
    # NOTE(review): presumably mirrors LanguageModel's own embedding step;
    # verify against src/models/transformer_blocks.
    x = fox_model.embedding(sample_input) * (fox_model.d_model ** 0.5)
    positions = torch.arange(sample_input.size(1), device=device).unsqueeze(0)
    x = x + fox_model.pos_embedding(positions)
    
    # Forward through first layer with outputs
    # Self-attention: the normalized input serves as query, key and value.
    # NOTE(review): no attention mask is passed here; if the model applies
    # one in its own forward pass, these weights differ from what it uses.
    # NOTE(review): `forget_gate` is rebound from the ForgetGate module
    # created earlier in the notebook to the values returned here.
    normed = first_layer.norm1(x)
    _, attn_weights, forget_gate = first_layer.attention(
        normed, normed, normed,
        return_attention_weights=True,
        return_forget_gate=True
    )

In [None]:
# Visualize forget gate patterns: one heatmap per attention head, at most four
from pathlib import Path

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
num_shown = min(4, fox_model.encoder.layers[0].attention.num_heads)

# axes.flat walks the 2x2 grid row by row, i.e. panel k is axes[k // 2, k % 2]
for head_idx, ax in enumerate(axes.flat):
    if head_idx >= num_shown:
        # Fewer than four heads: hide the unused panel instead of leaving
        # an empty set of axes in the figure.
        ax.axis('off')
        continue

    gate_vals = forget_gate[0, head_idx].cpu().numpy()
    # Colour scale pinned to [0, 1] so heads are directly comparable
    im = ax.imshow(gate_vals, cmap='RdYlBu_r', vmin=0, vmax=1)
    ax.set_title(f'Head {head_idx + 1} Forget Gate (after training)', fontsize=12)
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')
    plt.colorbar(im, ax=ax)

plt.suptitle('Forget Gate Patterns Across Attention Heads', fontsize=14, y=1.02)
plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists
Path('../results').mkdir(parents=True, exist_ok=True)
plt.savefig('../results/forget_gate_patterns.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Analyze forget gate statistics: per-head summary of the gate values for the
# first sample in the batch.
# The emoji was stored as mis-decoded UTF-8 bytes; restored to the intended character.
print("\n📊 Forget Gate Statistics (after training):")
print("=" * 50)

for head_idx in range(forget_gate.size(1)):
    gate_vals = forget_gate[0, head_idx]
    total = gate_vals.numel()
    print(f"\nHead {head_idx + 1}:")
    print(f"  Mean:     {gate_vals.mean().item():.4f}")
    print(f"  Std:      {gate_vals.std().item():.4f}")
    print(f"  Min:      {gate_vals.min().item():.4f}")
    print(f"  Max:      {gate_vals.max().item():.4f}")
    # Counts of entries below 0.3 and above 0.7, out of all entries
    print(f"  Low (<0.3): {(gate_vals < 0.3).sum().item()} / {total}")
    print(f"  High (>0.7): {(gate_vals > 0.7).sum().item()} / {total}")

## 5. Critical Evaluation Summary

### Observations:
1. **Parameter Overhead**: Minimal (~0.1% extra parameters for forget gates)
2. **Training Behavior**: [Observe from plots above]
3. **Forget Gate Patterns**: Heads learn different forgetting strategies

### Paper Claims to Verify:
- [ ] O(1) memory complexity (requires recurrent formulation)
- [ ] Better length extrapolation
- [ ] No positional embeddings needed

### Next Steps:
1. Test on longer sequences
2. Measure actual memory usage
3. Compare on real language modeling benchmarks

In [None]:
# Save models: write both state dicts to ../results
from pathlib import Path

# torch.save does not create missing directories, so make sure the target exists
Path('../results').mkdir(parents=True, exist_ok=True)
torch.save(std_model.state_dict(), '../results/std_model.pt')
torch.save(fox_model.state_dict(), '../results/fox_model.pt')
print("Models saved to results/")