# 🔄 Transformer-XL

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/transformer_architectures/05_transformer_xl/demo.ipynb)

![Architecture](architecture.png)

### Key Innovation
- **Segment-Level Recurrence**: Cache hidden states from previous segments
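- **Relative Positional Encoding**: Encode the distance between tokens instead of their absolute index, so cached states stay meaningful when reused and longer dependencies can be learned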
- **Extended Context**: Break fixed-length limitation

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Seed PyTorch's global RNG with a fixed value before anything random runs.
torch.manual_seed(42)

# Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## Problem: Fixed Context in Standard Transformer

In [None]:
# Standard transformer: Each segment is processed independently
# This causes context fragmentation!

def visualize_context_problem():
    """Draw a two-panel figure contrasting isolated and linked segments.

    Left panel ("Standard Transformer"): four segment boxes with nothing
    connecting them. Right panel ("Transformer-XL"): the same four boxes with
    an arrow from each segment to the next one. The figure is displayed with
    ``plt.show()``; nothing is returned.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    segments = ['Seg 1', 'Seg 2', 'Seg 3', 'Seg 4']

    # Standard Transformer: boxes only, deliberately no arrows between them.
    ax = axes[0]
    for i, seg in enumerate(segments):
        rect = plt.Rectangle((i*1.1, 0), 1, 0.5, color=plt.cm.Blues(0.3+i*0.15))
        ax.add_patch(rect)
        ax.text(i*1.1+0.5, 0.25, seg, ha='center', va='center', fontsize=10)
    ax.set_xlim(-0.2, 5)
    ax.set_ylim(-0.5, 1)
    ax.set_title('Standard Transformer: No Context Across Segments', fontsize=12)
    ax.axis('off')
    ax.text(2.2, -0.3, 'X No memory between segments', fontsize=11, color='red', ha='center')

    # Transformer-XL: same boxes, plus an arrow from segment i-1 to segment i.
    ax = axes[1]
    for i, seg in enumerate(segments):
        rect = plt.Rectangle((i*1.1, 0), 1, 0.5, color=plt.cm.Greens(0.3+i*0.15))
        ax.add_patch(rect)
        ax.text(i*1.1+0.5, 0.25, seg, ha='center', va='center', fontsize=10)
        if i > 0:
            ax.annotate('', xy=(i*1.1+0.1, 0.5), xytext=((i-1)*1.1+0.9, 0.5),
                       arrowprops=dict(arrowstyle='->', color='green', lw=2))
    ax.set_xlim(-0.2, 5)
    ax.set_ylim(-0.5, 1)
    ax.set_title('Transformer-XL: Recurrent Memory', fontsize=12)
    ax.axis('off')
    # The check mark is written as an escape so a wrong file encoding cannot
    # turn it into garbage characters again.
    ax.text(2.2, -0.3, '\u2713 Hidden states cached and reused', fontsize=11, color='green', ha='center')

    plt.tight_layout()
    plt.show()

visualize_context_problem()

## Relative Positional Encoding

In [None]:
class RelativePositionalEncoding(nn.Module):
    """Fixed sinusoidal table used as positional encodings for Transformer-XL.

    Row ``k`` of the ``pe`` buffer is the sinusoidal encoding of the integer
    ``k`` (sine in even columns, cosine in odd columns). How ``k`` is
    interpreted (e.g. as a query-key distance) is up to the caller.

    Args:
        d_model: Width of each encoding vector. Odd values are supported.
        max_len: Number of rows precomputed in the table.
    """
    def __init__(self, d_model, max_len=1024):
        super().__init__()
        self.d_model = d_model

        # Create sinusoidal positional encodings for relative positions
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) columns, so trim the angles to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Registered as a buffer: moves with .to(device) but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, seq_len, mem_len):
        """Return the first ``seq_len + mem_len`` rows of the table.

        Raises:
            ValueError: If more rows are requested than were precomputed;
                plain slicing would silently return a shorter table.
        """
        total_len = seq_len + mem_len
        if total_len > self.pe.size(0):
            raise ValueError(
                f'seq_len + mem_len = {total_len} exceeds max_len = {self.pe.size(0)}'
            )
        return self.pe[:total_len]

# Visualize relative vs absolute positioning
def compare_positions():
    """Plot absolute vs. relative position labels for an 8-token sequence.

    Left panel: one horizontal bar per token, labelled with its absolute
    index. Right panel: one bar per token, labelled with
    ``query_pos - index`` for a fixed query at index 4, with a dashed line
    marking the query row. The figure is displayed with ``plt.show()``;
    nothing is returned.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))

    # Absolute positions (problematic for long sequences)
    ax = axes[0]
    positions = list(range(8))
    ax.barh(positions, [1]*8, color='coral')
    for i, p in enumerate(positions):
        ax.text(0.5, i, f'pos={p}', va='center', ha='center', fontsize=10)
    ax.set_title('Absolute Position: Fixed for each location', fontsize=11)
    # barh lays tokens out along the y-axis, so that is the axis to label;
    # every bar has width 1, so the x ticks carry no information.
    ax.set_ylabel('Token Index')
    ax.set_xticks([])
    ax.invert_yaxis()

    # Relative positions
    ax = axes[1]
    query_pos = 4
    rel_positions = [query_pos - i for i in range(8)]
    # Map each signed distance to a colour around the middle of the colormap.
    colors = plt.cm.RdYlGn([0.5 + 0.05*p for p in rel_positions])
    ax.barh(range(8), [1]*8, color=colors)
    for i, p in enumerate(rel_positions):
        ax.text(0.5, i, f'rel={p}', va='center', ha='center', fontsize=10)
    ax.axhline(y=query_pos, color='blue', linestyle='--', linewidth=2, label='Query')
    ax.set_title(f'Relative Position: Distance from query (pos={query_pos})', fontsize=11)
    ax.set_ylabel('Token Index')
    ax.set_xticks([])
    ax.legend()
    ax.invert_yaxis()

    plt.tight_layout()
    plt.show()

compare_positions()
print('Relative positions enable: Same pattern at any location!')

## Transformer-XL Architecture

In [None]:
class TransformerXLAttention(nn.Module):
    """Multi-head attention with segment-level recurrence and relative positions.

    Queries come from the current segment only; keys and values come from the
    cached memory concatenated with the current segment. The attention score
    is the sum of a content term ``(Q + u) . K`` and a position term
    ``(Q + v) . R[distance]``, where ``distance`` is how far the key lies
    behind the query.

    Simplification relative to the Transformer-XL paper: the positional
    encodings are split across heads directly instead of going through a
    separate learned projection.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model ({d_model}) must be divisible by n_heads ({n_heads})')
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = self.d_k ** -0.5

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_kv = nn.Linear(d_model, 2 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)

        # Relative position biases (from Transformer-XL paper)
        self.u = nn.Parameter(torch.randn(n_heads, self.d_k) * 0.02)
        self.v = nn.Parameter(torch.randn(n_heads, self.d_k) * 0.02)

        self.dropout = nn.Dropout(dropout)

    def forward(self, h, memory=None, pos_embed=None, mask=None):
        """Attend from the current segment over memory + current segment.

        Args:
            h: Current-segment hidden states, shape ``(B, T, d_model)``.
            memory: Optional cached states, shape ``(B, M, d_model)``.
            pos_embed: Optional encoding table of shape ``(P, d_model)`` with
                ``P >= M + T``; row ``k`` is used for query-key distance ``k``.
            mask: Optional visibility mask, non-zero = may attend. Either a
                full ``(>= M+T, >= M+T)`` mask over the concatenated sequence
                (the query rows ``M .. M+T-1`` are selected) or an already
                query-aligned mask with ``T`` rows.

        Returns:
            Tensor of shape ``(B, T, d_model)``.

        Raises:
            ValueError: If ``pos_embed`` has fewer than ``M + T`` rows.
        """
        B, T, C = h.shape

        # Concatenate memory (cached hidden states from previous segment)
        cat = h if memory is None else torch.cat([memory, h], dim=1)  # (B, M+T, C)
        S = cat.size(1)  # Total key length
        M = S - T        # Memory length

        # Q from current segment, K/V from memory + current
        Q = self.W_q(h).view(B, T, self.n_heads, self.d_k)          # (B, T, H, D)
        kv = self.W_kv(cat).view(B, S, 2, self.n_heads, self.d_k)   # (B, S, 2, H, D)
        K, V = kv[:, :, 0], kv[:, :, 1]                             # (B, S, H, D)

        # Content-based attention: (Q + u) . K
        attn = torch.einsum('bthd,bshd->btsh', Q + self.u, K)       # (B, T, S, H)

        # Position-based attention: (Q + v) . R[distance]
        if pos_embed is not None:
            if pos_embed.size(0) < S:
                raise ValueError(
                    f'pos_embed has {pos_embed.size(0)} rows but {S} are needed'
                )
            # Query t of the current segment sits at index M + t of the
            # concatenated sequence, so its distance to key s is (M + t) - s.
            q_idx = torch.arange(M, S, device=pos_embed.device).unsqueeze(1)  # (T, 1)
            k_idx = torch.arange(S, device=pos_embed.device).unsqueeze(0)     # (1, S)
            # Keys ahead of the query give negative distances; they are
            # expected to be masked out, so clamp them to a valid row.
            dist = (q_idx - k_idx).clamp(min=0)                               # (T, S)
            R = pos_embed[dist].view(T, S, self.n_heads, self.d_k)            # (T, S, H, D)
            attn = attn + torch.einsum('bthd,tshd->btsh', Q + self.v, R)

        attn = attn * self.scale

        # Apply causal mask
        if mask is not None:
            # A full mask covers memory + current rows: take the rows that
            # belong to the current-segment queries, not the first T rows.
            rows = mask[M:S] if mask.size(0) >= S else mask[:T]
            visible = rows[:, :S]
            attn = attn.masked_fill(visible.unsqueeze(0).unsqueeze(-1) == 0, float('-inf'))

        attn = F.softmax(attn, dim=2)  # normalise over the key axis
        attn = self.dropout(attn)

        out = torch.einsum('btsh,bshd->bthd', attn, V).reshape(B, T, C)
        return self.W_o(out)

class TransformerXLBlock(nn.Module):
    """Pre-norm Transformer-XL layer: attention sub-layer then feed-forward.

    Both sub-layers are applied as residual updates to a LayerNorm-ed input.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network; defaults to
            ``4 * d_model`` when omitted (or falsy).
        dropout: Dropout probability used in attention and feed-forward.
    """
    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.attn = TransformerXLAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, memory=None, pos_embed=None, mask=None):
        """Run the layer on ``x``, attending over ``memory`` as extra context.

        Args:
            x: Current-segment input to this layer.
            memory: Optional cached input of this layer from earlier segments.
            pos_embed: Optional positional-encoding table passed to attention.
            mask: Optional attention mask passed to attention.

        Returns:
            The layer output, same shape as ``x``.
        """
        # The memory is concatenated with the normalised segment before the
        # key/value projection, so it must go through the same LayerNorm;
        # otherwise cached and current positions are on different scales.
        normed_memory = self.norm1(memory) if memory is not None else None
        x = x + self.attn(self.norm1(x), normed_memory, pos_embed, mask)
        x = x + self.ff(self.norm2(x))
        return x

class TransformerXL(nn.Module):
    """Small Transformer-XL language model with per-layer segment memory.

    Args:
        vocab_size: Number of token ids.
        d_model: Model width.
        n_heads: Attention heads per layer.
        n_layers: Number of ``TransformerXLBlock`` layers.
        mem_len: Maximum number of past positions cached per layer; ``0``
            (or less) disables the memory.
        dropout: Dropout probability.
    """
    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3, mem_len=64, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.mem_len = mem_len
        self.n_layers = n_layers

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = RelativePositionalEncoding(d_model)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([TransformerXLBlock(d_model, n_heads, dropout=dropout) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def _update_memory(self, old_memory, hidden):
        """Return the memory to carry forward for one layer.

        The previous memory and the new hidden states are joined along the
        time axis and only the last ``mem_len`` positions are kept, so history
        is preserved when a segment is shorter than ``mem_len``. The result is
        detached from the autograd graph.
        """
        if self.mem_len <= 0:
            return None
        hidden = hidden.detach()
        if old_memory is not None:
            hidden = torch.cat([old_memory.detach(), hidden], dim=1)
        return hidden[:, -self.mem_len:]

    def forward(self, x, memories=None):
        """Score one segment of token ids.

        Args:
            x: Token ids of shape ``(B, T)``.
            memories: ``None`` for a fresh start, or the list returned by the
                previous call (one entry per layer, each a tensor or ``None``).

        Returns:
            ``(logits, new_memories)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``new_memories`` has one entry per layer.

        Raises:
            ValueError: If ``memories`` does not have one entry per layer.
        """
        _, T = x.shape

        h = self.dropout(self.embed(x))

        if memories is None:
            memories = [None] * self.n_layers
        elif len(memories) != self.n_layers:
            raise ValueError(
                f'expected {self.n_layers} memory entries, got {len(memories)}'
            )

        # All layers are assumed to share the memory length of the first one.
        prev_len = memories[0].size(1) if memories[0] is not None else 0
        # Go through the encoding module rather than reaching into its buffer.
        pos = self.pos_embed(T, prev_len)

        # Create causal mask over memory + current positions
        mask = torch.tril(torch.ones(T + prev_len, T + prev_len, device=x.device))

        new_memories = []
        for layer, layer_memory in zip(self.layers, memories):
            # Cache this layer's *input* for the next segment
            new_memories.append(self._update_memory(layer_memory, h))
            h = layer(h, layer_memory, pos, mask)

        h = self.norm(h)
        return self.head(h), new_memories

# Build a throwaway model with a 1000-token vocabulary just to report its size.
model = TransformerXL(vocab_size=1000, d_model=64, n_heads=4, n_layers=3, mem_len=32)
model = model.to(device)
print(f'Transformer-XL Parameters: {sum(p.numel() for p in model.parameters()):,}')

## Training with Segment-Level Recurrence

In [None]:
# Toy character-level corpus: one sentence repeated 500 times.
text = 'the quick brown fox jumps over the lazy dog ' * 500
chars = sorted(set(text))
vocab_size = len(chars)
char_to_idx = dict(zip(chars, range(vocab_size)))
idx_to_char = dict(enumerate(chars))

data = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
print(f'Vocab: {len(chars)}, Data: {len(data)}')

# Fresh model sized to the character vocabulary, trained with memory.
model = TransformerXL(vocab_size=vocab_size, d_model=64, n_heads=4, n_layers=2, mem_len=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

seg_len = 32
n_steps = 300
losses = []

print('\nTraining Transformer-XL with recurrent memory...')
for step in range(n_steps):
    # Pick a random window that has room for three consecutive segments
    # plus the one-character shift needed for the targets.
    start = torch.randint(0, len(data) - seg_len * 3, (1,)).item()

    total_loss = 0
    memories = None  # every window starts without any cached context

    # Walk through three back-to-back segments, carrying the memory along.
    for seg_idx in range(3):
        seg_start = start + seg_idx * seg_len
        seg_end = seg_start + seg_len
        x = data[seg_start:seg_end].unsqueeze(0).to(device)
        y = data[seg_start + 1:seg_end + 1].unsqueeze(0).to(device)  # next-character targets

        logits, memories = model(x, memories)
        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
        total_loss += loss

        # Cut the graph at the memory so gradients stay within one segment.
        memories = [None if m is None else m.detach() for m in memories]

    # One optimizer update per window, on the summed loss of its segments.
    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()

    losses.append(total_loss.item() / 3)  # mean loss per segment

    if step % 50 == 49:
        print(f'Step {step+1}: Loss = {losses[-1]:.4f}')

# Loss curve over training steps.
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Transformer-XL Training with Recurrent Memory')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Compare: Standard vs XL context
def visualize_effective_context():
    """Draw two rows of segment boxes annotated with their context length.

    The upper row (standard Transformer) labels every segment with a context
    of ``seg_len``. The lower row (Transformer-XL) labels segment ``i`` with
    ``(min(i, 2) + 1) * seg_len``. The figure is displayed with
    ``plt.show()``; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(12, 4))

    seg_len = 32
    n_segs = 4
    half = seg_len/2

    # One colour per segment, shared by both rows.
    colors = plt.cm.tab10(range(n_segs))

    # Left edge of every box: segments are seg_len wide with a gap of 5.
    lefts = [i * (seg_len + 5) for i in range(n_segs)]

    # Upper row: standard Transformer, context is always the current segment.
    for i, left in enumerate(lefts):
        centre = left + half
        ax.add_patch(plt.Rectangle((left, 0.5), seg_len, 0.4, color=colors[i], alpha=0.6))
        ax.text(centre, 0.7, f'Seg {i+1}', ha='center', va='center')
        ax.arrow(centre, 0.3, 0, -0.1, head_width=2, head_length=0.02, fc='gray', ec='gray')
        ax.text(centre, 0.1, f'Context: {seg_len}', ha='center', fontsize=8)

    ax.text(65, 0.1, 'Standard Transformer', fontsize=10, color='gray')

    # Lower row: Transformer-XL, context grows with the segment index and is
    # capped at three segments' worth.
    for i, left in enumerate(lefts):
        centre = left + half
        ax.add_patch(plt.Rectangle((left, -0.5), seg_len, 0.4, color=colors[i], alpha=0.6))
        ax.text(centre, -0.3, f'Seg {i+1}', ha='center', va='center')

        mem_ctx = (min(i, 2) + 1) * seg_len
        ax.arrow(centre, -0.7, 0, -0.1, head_width=2, head_length=0.02, fc='green', ec='green')
        ax.text(centre, -0.9, f'Context: {mem_ctx}', ha='center', fontsize=8, color='green')

    ax.text(65, -0.9, 'Transformer-XL', fontsize=10, color='green')

    ax.set_xlim(-5, 150)
    ax.set_ylim(-1.1, 1.1)
    ax.set_title('Effective Context Length Comparison')
    ax.axis('off')
    plt.show()

visualize_effective_context()

# Closing summary. The emoji is written as an escape so a wrong file encoding
# cannot turn it into garbage characters again.
print('\n\U0001F3AF Key Takeaways:')
print('1. Transformer-XL caches hidden states from previous segments')
print('2. Relative positional encoding enables unbounded context')
print('3. Effective context grows with each segment')
print('4. Essential for long document understanding')