<a href="https://colab.research.google.com/github/jeroaranda/naturalattention/blob/main/Atenci%C3%B3n.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb transformers datasets torch tqdm

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from datasets import load_dataset
import wandb
import numpy as np
from tqdm import tqdm
from copy import deepcopy
import math

In [8]:
## efficiency
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from datasets import load_dataset
import wandb
import numpy as np
from tqdm import tqdm
import math
import os
from datetime import datetime
class WikiTextDataset(Dataset):
    def __init__(self, tokenizer, split='train', max_length=64):
        self.tokenizer = tokenizer
        self.max_length = max_length

        print(f"Loading WikiText-2 dataset ({split} split)...")
        # Load only 300 examples
        dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split=f'{split}[:300]')

        # Process in chunks
        self.chunks = []
        current_chunk = []
        current_length = 0

        for text in tqdm(dataset['text'], desc="Processing text"):
            if not text.strip():
                continue

            # Tokenize each text separately
            tokens = tokenizer.encode(text, truncation=True, max_length=max_length)

            if current_length + len(tokens) > max_length:
                if current_chunk:
                    self.chunks.append(current_chunk)
                current_chunk = tokens
                current_length = len(tokens)
            else:
                current_chunk.extend(tokens)
                current_length += len(tokens)

            if current_length >= max_length:
                self.chunks.append(current_chunk[:max_length])
                current_chunk = []
                current_length = 0

        if current_chunk:
            self.chunks.append(current_chunk)

        print(f"Created {len(self.chunks)} chunks of maximum length {max_length}")

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        chunk = self.chunks[idx]
        if len(chunk) < self.max_length + 1:
            chunk = chunk + [self.tokenizer.pad_token_id] * (self.max_length + 1 - len(chunk))
        return torch.tensor(chunk[:self.max_length + 1])
def train_epoch(model, optimizer, train_loader, device, model_type, global_step):
    model.train()
    total_loss = 0.0
    total_perplexity = 0.0
    num_batches = 0

    for batch in tqdm(train_loader, desc=f"Training {model_type} model"):
        try:
            input_ids = batch[:, :-1].to(device)
            labels = batch[:, 1:].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            perplexity = torch.exp(loss).item()

            # Log every batch
            wandb.log({
                "step": global_step,
                f"{model_type}/batch/loss": loss.item(),
                f"{model_type}/batch/perplexity": perplexity,
                f"{model_type}_loss": loss.item(),  # Additional metrics for direct comparison
                "global_step": global_step
            })

            total_loss += loss.item()
            total_perplexity += perplexity
            num_batches += 1
            global_step += 1

        except Exception as e:
            print(f"Error in batch: {str(e)}")
            continue

    metrics = {
        'loss': total_loss / num_batches,
        'perplexity': total_perplexity / num_batches,
    }

    return metrics, global_step


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class NaturalAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.hidden_size = config.n_embd
        self.head_dim = self.hidden_size // self.n_head

        # Linear projections
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, hidden_states, layer_past=None, attention_mask=None, head_mask=None,
                use_cache=False, output_attentions=False):
        batch_size, seq_length, _ = hidden_states.size()

        # Project Q, K, V
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_length, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_length, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_length, self.n_head, self.head_dim).transpose(1, 2)

        # Handle layer past if provided
        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        present = (k, v) if use_cache else None

        # Compute raw attention energies
        attention_energies = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attention_energies = attention_energies + attention_mask

        # Store raw energies for optimization
        self.last_attention_energies = attention_energies.detach()

        # Regular attention computation
        attention_probs = F.softmax(attention_energies, dim=-1)

        # Apply head mask if provided
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, v)

        # Reshape output
        context_layer = context_layer.transpose(1, 2).contiguous()
        context_layer = context_layer.view(batch_size, seq_length, self.hidden_size)

        # Project output
        output = self.out_proj(context_layer)

        outputs = (output, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs

class GPT2NaturalAttentionBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = NaturalAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd)
        )

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attentions tuples have varying length
        outputs = attn_outputs[1:]

        hidden_states = residual + attn_output

        # Store attention energies in parameters for optimizer
        for p in self.parameters():
            p._attention_energies = self.attn.last_attention_energies

        # Feed-forward block
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_output = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_output

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs

class AttentionInformedOptimizer(torch.optim.AdamW):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, energy_scale=0.1):
        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.energy_scale = energy_scale

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Get attention energies if available
                if hasattr(p, '_attention_energies'):
                    # Scale gradient based on attention energies
                    energy_factor = torch.tanh(p._attention_energies.abs().mean() * self.energy_scale)
                    p.grad.data *= (1.0 + energy_factor)

        # Perform regular Adam update
        return super().step(closure)

In [22]:
def train_both_models(standard_model, natural_atention_model,
                     standard_optimizer, natural_atention_optimizer,
                     train_loader, device, global_step):
    standard_model.train()
    natural_atention_model.train()

    total_standard_loss = 0.0
    total_natural_atention_loss = 0.0
    num_batches = 0

    for batch in tqdm(train_loader, desc=f"Training both models"):
        try:
            # Debug prints for batch shape


            input_ids = batch[:, :-1].to(device)
            labels = batch[:, 1:].to(device)



            # Train standard model
            standard_optimizer.zero_grad()
            standard_outputs = standard_model(input_ids, labels=labels)
            standard_loss = standard_outputs.loss

            standard_loss.backward()
            torch.nn.utils.clip_grad_norm_(standard_model.parameters(), 1.0)
            standard_optimizer.step()

            # Train natural_atention model
            natural_atention_optimizer.zero_grad()
            natural_atention_outputs = natural_atention_model(input_ids, labels=labels)
            natural_atention_loss = natural_atention_outputs.loss

            natural_atention_loss.backward()
            torch.nn.utils.clip_grad_norm_(natural_atention_model.parameters(), 1.0)
            natural_atention_optimizer.step()

            # Calculate perplexities
            standard_perplexity = torch.exp(standard_loss).item()
            natural_atention_perplexity = torch.exp(natural_atention_loss).item()

            # Log with same step for both models
            wandb.log({
                "step": global_step,
                "standard/batch/loss": standard_loss.item(),
                "standard/batch/perplexity": standard_perplexity,
                "natural_atention/batch/loss": natural_atention_loss.item(),
                "natural_atention/batch/perplexity": natural_atention_perplexity,
                "global_step": global_step
            })

            total_standard_loss += standard_loss.item()
            total_natural_atention_loss += natural_atention_loss.item()
            num_batches += 1
            global_step += 1

        except Exception as e:
            print(f"Error in batch: {str(e)}")
            print(f"Full error traceback:")
            import traceback
            traceback.print_exc()
            continue

    metrics = {
        'standard_loss': total_standard_loss / num_batches,
        'natural_atention_loss': total_natural_atention_loss / num_batches,
    }

    return metrics, global_step

def train_models(config_dict):

    run = wandb.init(project="enhancednaturalattention", config=config_dict,reinit=True)

    # Define metrics for better visualization
    wandb.define_metric("step")
    wandb.define_metric("global_step")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    model_config = GPT2Config(
        vocab_size=tokenizer.vocab_size,
        n_positions=config_dict['max_length'],
        n_ctx=config_dict['max_length'],
        n_embd=config_dict['n_embd'],
        n_layer=config_dict['n_layer'],
        n_head=config_dict['n_head']
    )

    standard_model = GPT2LMHeadModel(model_config).to(device)
    # Initialize model with natural attention
    natural_atention_model = GPT2LMHeadModel(model_config)
    for i, block in enumerate(natural_atention_model.transformer.h):
        natural_atention_model.transformer.h[i] = GPT2NaturalAttentionBlock(model_config)

    # Use the attention-informed optimizer
    natural_atention_optimizer = AttentionInformedOptimizer(
        natural_atention_model.parameters(),
        lr=config_dict['learning_rate'],
        energy_scale=0.1  # Adjust this to control attention influence
    )

    train_dataset = WikiTextDataset(
        tokenizer,
        split='train',
        max_length=config_dict['max_length']
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config_dict['batch_size'],
        shuffle=True,
        num_workers=2,
        pin_memory=False
    )

    standard_optimizer = torch.optim.AdamW(
        standard_model.parameters(),
        lr=config_dict['learning_rate'],
        weight_decay=0.01
    )

    checkpoint_dir = os.path.join("checkpoints", run.id)
    os.makedirs(checkpoint_dir, exist_ok=True)

    global_step = 0
    for epoch in range(config_dict['epochs']):
        print(f"\nEpoch {epoch+1}/{config_dict['epochs']}")

        metrics, global_step = train_both_models(
            standard_model, natural_atention_model,
            standard_optimizer, natural_atention_optimizer,
            train_loader, device, global_step
        )

        if (epoch + 1) % config_dict['save_every'] == 0:
            for model_type, model, optimizer in [
                ("standard", standard_model, standard_optimizer),
                ("natural_atention", natural_atention_model, natural_atention_optimizer)
            ]:
                checkpoint_path = os.path.join(
                    checkpoint_dir,
                    f"gpt2_{model_type}_epoch_{epoch+1}.pt"
                )
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'metrics': metrics,
                }, checkpoint_path)

    print("Training completed successfully!")
    run.finish()
    return standard_model, natural_atention_model

# Configuration with more epochs
config_dict = {
    'max_length': 32,
    'batch_size': 4,
    'n_embd': 64,
    'n_layer': 2,
    'n_head': 2,
    'learning_rate': 1e-3,
    'epochs': 10,     # Increased epochs
    'save_every': 2
}

In [23]:
standard_model, naturall_model = train_models(config_dict)

Using device: cpu
Loading WikiText-2 dataset (train split)...


Processing text: 100%|██████████| 300/300 [00:00<00:00, 1350.72it/s]


Created 181 chunks of maximum length 32

Epoch 1/10


Training both models: 100%|██████████| 46/46 [00:16<00:00,  2.71it/s]



Epoch 2/10


Training both models: 100%|██████████| 46/46 [00:17<00:00,  2.69it/s]



Epoch 3/10


Training both models: 100%|██████████| 46/46 [00:17<00:00,  2.69it/s]



Epoch 4/10


Training both models: 100%|██████████| 46/46 [00:18<00:00,  2.44it/s]



Epoch 5/10


Training both models: 100%|██████████| 46/46 [00:16<00:00,  2.73it/s]



Epoch 6/10


Training both models: 100%|██████████| 46/46 [00:16<00:00,  2.72it/s]



Epoch 7/10


Training both models: 100%|██████████| 46/46 [00:17<00:00,  2.68it/s]



Epoch 8/10


Training both models: 100%|██████████| 46/46 [00:17<00:00,  2.69it/s]



Epoch 9/10


Training both models: 100%|██████████| 46/46 [00:16<00:00,  2.73it/s]



Epoch 10/10


Training both models: 100%|██████████| 46/46 [00:16<00:00,  2.72it/s]


Training completed successfully!


0,1
global_step,▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇█
natural_atention/batch/loss,█▆▆▆▆▄▃▄▃▄▄▃▂▄▃▄▄▃▂▃▄▄▃▃▃▂▂▁▃▂▃▂▂▂▃▂▁▁▂▁
natural_atention/batch/perplexity,▇█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
standard/batch/loss,▇█▇▆▅▅▅▄▃▃▄▃▅▃▄▄▃▄▂▂▄▃▃▄▁▂▂▂▂▁▃▃▄▄▃▃▁▄▂▃
standard/batch/perplexity,█▄▄▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇██

0,1
global_step,459.0
natural_atention/batch/loss,3.08678
natural_atention/batch/perplexity,21.90648
standard/batch/loss,3.95251
standard/batch/perplexity,52.0659
step,459.0


# Natural attention is natural gradient

In [24]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm

def analyze_attention_training(standard_model, natural_model, train_loader, epochs=5):
    """Collect metrics during training for both models"""
    metrics = {
        'epoch': [], 'batch': [], 'model_type': [],
        'loss': [], 'perplexity': [],
        'attention_energy_mean': [], 'attention_energy_std': [],
        'gradient_norm': []
    }

    for epoch in range(epochs):
        for batch_idx, batch in enumerate(tqdm(train_loader)):
            # Get attention energies from both models
            with torch.no_grad():
                # Standard model metrics
                std_outputs = standard_model(batch)
                std_attention = standard_model.transformer.h[0].attn.last_attention_energies
                std_grad_norm = torch.norm(torch.stack([p.grad.norm() for p in standard_model.parameters() if p.grad is not None]))

                metrics['epoch'].append(epoch)
                metrics['batch'].append(batch_idx)
                metrics['model_type'].append('standard')
                metrics['loss'].append(std_outputs.loss.item())
                metrics['perplexity'].append(torch.exp(std_outputs.loss).item())
                metrics['attention_energy_mean'].append(std_attention.mean().item())
                metrics['attention_energy_std'].append(std_attention.std().item())
                metrics['gradient_norm'].append(std_grad_norm.item())

                # Natural attention model metrics
                nat_outputs = natural_model(batch)
                nat_attention = natural_model.transformer.h[0].attn.last_attention_energies
                nat_grad_norm = torch.norm(torch.stack([p.grad.norm() for p in natural_model.parameters() if p.grad is not None]))

                metrics['epoch'].append(epoch)
                metrics['batch'].append(batch_idx)
                metrics['model_type'].append('natural')
                metrics['loss'].append(nat_outputs.loss.item())
                metrics['perplexity'].append(torch.exp(nat_outputs.loss).item())
                metrics['attention_energy_mean'].append(nat_attention.mean().item())
                metrics['attention_energy_std'].append(nat_attention.std().item())
                metrics['gradient_norm'].append(nat_grad_norm.item())

    return pd.DataFrame(metrics)

def plot_training_metrics(metrics_df):
    """Create a suite of plots comparing model performance"""
    plt.style.use('seaborn')
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Loss curves
    sns.lineplot(
        data=metrics_df, x='epoch', y='loss', hue='model_type',
        ax=axes[0,0], errorbar='sd'
    )
    axes[0,0].set_title('Training Loss')

    # 2. Attention energy distribution
    sns.boxplot(
        data=metrics_df, x='epoch', y='attention_energy_mean',
        hue='model_type', ax=axes[0,1]
    )
    axes[0,1].set_title('Attention Energy Distribution')

    # 3. Gradient norm evolution
    sns.lineplot(
        data=metrics_df, x='epoch', y='gradient_norm',
        hue='model_type', ax=axes[1,0]
    )
    axes[1,0].set_title('Gradient Norm Evolution')

    # 4. Perplexity comparison
    sns.violinplot(
        data=metrics_df, x='epoch', y='perplexity',
        hue='model_type', ax=axes[1,1], split=True
    )
    axes[1,1].set_title('Perplexity Distribution')

    plt.tight_layout()
    return fig

def analyze_attention_patterns(model, test_loader):
    """Analyze attention pattern stability and structure"""
    attention_patterns = []

    for batch in test_loader:
        with torch.no_grad():
            outputs = model(batch)
            attention = model.transformer.h[0].attn.last_attention_energies
            attention_patterns.append(attention.cpu().numpy())

    patterns = np.stack(attention_patterns)

    # Plot attention pattern analysis
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # 1. Average attention pattern
    sns.heatmap(
        patterns.mean(axis=0).mean(axis=0),
        ax=axes[0], cmap='viridis'
    )
    axes[0].set_title('Average Attention Pattern')

    # 2. Attention stability across samples
    stability = patterns.std(axis=0).mean(axis=0)
    sns.heatmap(stability, ax=axes[1], cmap='rocket')
    axes[1].set_title('Attention Pattern Stability')

    # 3. Attention sparsity distribution
    sparsity = (patterns > patterns.mean() + patterns.std()).mean(axis=(0,1))
    sns.barplot(x=range(len(sparsity)), y=sparsity, ax=axes[2])
    axes[2].set_title('Attention Sparsity by Position')

    plt.tight_layout()
    return fig

# Example usage
def run_analysis(standard_model, natural_model, train_loader, test_loader):
    # Collect training metrics
    metrics_df = analyze_attention_training(standard_model, natural_model, train_loader)

    # Generate plots
    training_fig = plot_training_metrics(metrics_df)
    standard_patterns_fig = analyze_attention_patterns(standard_model, test_loader)
    natural_patterns_fig = analyze_attention_patterns(natural_model, test_loader)

    return {
        'metrics': metrics_df,
        'training_plot': training_fig,
        'standard_patterns': standard_patterns_fig,
        'natural_patterns': natural_patterns_fig
    }

In [25]:
# After training both models
test_loader = DataLoader(
    WikiTextDataset(tokenizer, split='test', max_length=config_dict['max_length']),
    batch_size=config_dict['batch_size'],
    shuffle=False
)

results = run_analysis(standard_model, natural_model, train_loader, test_loader)

# Save the plots
results['training_plot'].savefig('training_comparison.png')
results['standard_patterns'].savefig('standard_attention_patterns.png')
results['natural_patterns'].savefig('natural_attention_patterns.png')

# You can also export metrics to CSV
results['metrics'].to_csv('training_metrics.csv')

NameError: name 'tokenizer' is not defined