# F-Regularization Experiment (Phase 5)

**Goal**: Test the causal hypothesis: does minimizing geDIG F during training improve performance?

- Baseline: standard CrossEntropy fine-tuning
- Treatment: L_total = L_CE + α * F_mean
- α sweep: [0, 0.001, 0.01, 0.1, 1.0] × 3 seeds

In [None]:
# Check GPU: show the NVIDIA driver/GPU status so a missing accelerator is noticed before training
!nvidia-smi

In [None]:
# Install dependencies (-q keeps pip output quiet)
# NOTE(review): versions are unpinned, so behaviour may depend on whichever releases get installed
!pip install -q transformers datasets accelerate

In [None]:
# Inline the training script (to avoid file upload issues)
TRAIN_SCRIPT = '''
#!/usr/bin/env python3
"""
F-Regularized Training Experiment

Tests the causal hypothesis: Does minimizing geDIG F during training improve performance?
"""

from __future__ import annotations

import argparse
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.modeling_outputs import SequenceClassifierOutput


@dataclass
class DifferentiableGeDIG:
    """Compute the geDIG score F from attention maps using tensor operations.

    For every (batch, head) attention map the score is

        F = delta_epc - lambda_param * (delta_h + gamma * delta_sp)

    where delta_epc is a soft edge density, delta_h a normalized entropy and
    delta_sp a multi-hop path-efficiency term.

    NOTE(review): delta_sp is built from a hard ``> 0.5`` comparison, so it
    carries no gradient; only delta_epc and delta_h can influence training.
    Confirm this is intended before interpreting the alpha sweep.
    """

    # Weight of the (entropy + path efficiency) term subtracted in F.
    lambda_param: float = 1.0
    # Relative weight of path efficiency inside that term.
    gamma: float = 0.5
    # Sigmoid temperature for the soft edge threshold (smaller = sharper).
    temperature: float = 0.1
    # Fraction of attention entries that fall below the edge threshold.
    percentile: float = 0.9
    # Longest path length (in hops) considered by the path-efficiency term.
    max_path_length: int = 4

    def compute_F(self, attention: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Return F and its components for a 4-D attention tensor.

        Args:
            attention: tensor unpacked as (batch, heads, seq_len, seq_len).
            attention_mask: optional (batch, seq_len) mask; entries where either
                the row or the column position is masked out are zeroed.

        Returns:
            Dict with per-(batch, head) tensors ``F``, ``delta_epc``,
            ``delta_h``, ``delta_sp`` and the scalar ``F_mean``.
        """
        if attention_mask is not None:
            # Outer product of the mask with itself: (batch, 1, seq, seq)
            mask_2d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(3)
            attention = attention * mask_2d.float()
        delta_epc = self._compute_soft_density(attention)
        delta_h = self._compute_entropy(attention, attention_mask)
        delta_sp = self._compute_soft_path_efficiency(attention, attention_mask)
        F_values = delta_epc - self.lambda_param * (delta_h + self.gamma * delta_sp)
        return {"F": F_values, "F_mean": F_values.mean(), "delta_epc": delta_epc, "delta_h": delta_h, "delta_sp": delta_sp}

    def _soft_edges(self, attention: torch.Tensor) -> torch.Tensor:
        """Return sigmoid-relaxed edge indicators around the per-head percentile threshold."""
        batch_size, num_heads, seq_len, _ = attention.shape
        # reshape (not view) so non-contiguous attention tensors are accepted
        flat = attention.reshape(batch_size, num_heads, -1)
        # kthvalue needs 1 <= k <= number of entries; clamp so very short
        # sequences or extreme percentiles do not raise
        k = min(max(int(self.percentile * seq_len * seq_len), 1), flat.shape[-1])
        threshold = torch.kthvalue(flat, k, dim=-1).values.unsqueeze(-1).unsqueeze(-1)
        return torch.sigmoid((attention - threshold) / self.temperature)

    def _compute_soft_density(self, attention: torch.Tensor) -> torch.Tensor:
        """Return the mean soft edge probability per (batch, head)."""
        seq_len = attention.shape[2]
        return self._soft_edges(attention).sum(dim=(-2, -1)) / (seq_len * seq_len)

    def _compute_entropy(self, attention: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the entropy of each flattened attention map, divided by its maximum.

        The maximum is log(valid_count ** 2) when a mask is given, otherwise
        log(seq_len ** 2). The 1e-10 terms guard against log(0) and division by zero.
        """
        batch_size, num_heads, seq_len, _ = attention.shape
        flat = attention.reshape(batch_size, num_heads, -1)
        # Renormalize the whole map into a single distribution
        probs = flat / (flat.sum(dim=-1, keepdim=True) + 1e-10)
        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)
        if attention_mask is not None:
            valid_count = attention_mask.sum(dim=-1).float()
            # (batch, 1) so it broadcasts over heads
            max_entropy = torch.log(valid_count * valid_count + 1e-10).unsqueeze(1)
        else:
            max_entropy = math.log(seq_len * seq_len)
        return entropy / (max_entropy + 1e-10)

    def _compute_soft_path_efficiency(self, attention: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return a length-weighted reachability score per (batch, head).

        For each path length 1..max_path_length the fraction of node pairs
        whose clamped adjacency power exceeds 0.5 is weighted by 1 / length;
        the sum is divided by max_path_length. ``attention_mask`` is accepted
        for signature symmetry but is not used here.
        """
        batch_size, num_heads, seq_len, _ = attention.shape
        eye = torch.eye(seq_len, device=attention.device).unsqueeze(0).unsqueeze(0)
        # Self-loops let longer paths include all shorter ones
        adj = self._soft_edges(attention) + eye
        path_efficiency = torch.zeros(batch_size, num_heads, device=attention.device)
        adj_power = adj.clone()
        for path_len in range(1, self.max_path_length + 1):
            if path_len > 1:
                adj_power = torch.clamp(torch.matmul(adj_power, adj), 0, 1)
            # Hard comparison: this term has no gradient with respect to attention
            reachable = (adj_power > 0.5).float().mean(dim=(-2, -1))
            path_efficiency = path_efficiency + (1.0 / path_len) * reachable
        return path_efficiency / self.max_path_length


class FRegularizedModel(nn.Module):
    """Wrap a classifier so that its training loss becomes ``L_CE + alpha * F_mean``.

    F_mean is the geDIG F averaged over heads and batch within each layer,
    then averaged over all attention layers returned by the wrapped model.
    """

    def __init__(self, base_model: nn.Module, alpha: float = 0.1, gedig_config: Optional[Dict[str, Any]] = None):
        """Store the wrapped model, the regularization weight and the geDIG settings.

        Args:
            base_model: model called with ``output_attentions=True``; its output
                must expose ``loss``, ``logits`` and ``attentions``.
            alpha: weight of the F term; no regularization is applied when it is not positive.
            gedig_config: optional keyword overrides for ``DifferentiableGeDIG``.
        """
        super().__init__()
        self.base_model = base_model
        self.alpha = alpha
        self.gedig = DifferentiableGeDIG(**(gedig_config or {}))
        # Scalars from the most recent regularized forward pass, kept for logging
        self._last_gedig_metrics: Optional[Dict[str, float]] = None

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, **kwargs) -> SequenceClassifierOutput:
        """Run the wrapped model and add ``alpha * F_mean`` to its loss when labels are given.

        Raises:
            RuntimeError: if regularization is requested but the wrapped model
                returned no attention weights.
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, output_attentions=True, **kwargs)
        loss = outputs.loss
        if labels is not None and self.alpha > 0:
            attentions = outputs.attentions
            if attentions is None or any(layer_attn is None for layer_attn in attentions):
                # Fail with an actionable message instead of an opaque error deeper in compute_F
                raise RuntimeError(
                    "The wrapped model returned no attention weights although output_attentions=True was requested; "
                    "F-regularization needs them (try loading the model with attn_implementation='eager')."
                )
            per_layer_f = [self.gedig.compute_F(layer_attn, attention_mask)["F_mean"] for layer_attn in attentions]
            f_mean = torch.stack(per_layer_f).mean()
            ce_loss = outputs.loss
            loss = ce_loss + self.alpha * f_mean
            self._last_gedig_metrics = {"f_mean": f_mean.item(), "ce_loss": ce_loss.item(), "total_loss": loss.item()}
        # Attentions and hidden states are dropped from the returned output
        return SequenceClassifierOutput(loss=loss, logits=outputs.logits, hidden_states=None, attentions=None)


class FRegularizedTrainer(Trainer):
    """Trainer that also logs the geDIG metrics recorded by ``FRegularizedModel``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the model loss (and outputs if requested), logging geDIG metrics during training.

        Metrics are logged only in training mode and only when the global step
        is a multiple of ``args.logging_steps``, so evaluation batches and
        intermediate steps do not flood the log history. Extra keyword
        arguments passed by the Trainer are accepted and ignored.
        """
        outputs = model(**inputs)
        loss = outputs.loss
        # The Trainer may hand over a wrapper (e.g. DataParallel) exposing the real model as .module
        unwrapped = getattr(model, "module", model)
        metrics = getattr(unwrapped, "_last_gedig_metrics", None)
        # Fractional logging_steps truncate to 0 here, which falls back to logging every step
        interval = max(int(self.args.logging_steps), 1)
        if metrics and model.training and self.state.global_step % interval == 0:
            self.log(dict(metrics))
        return (loss, outputs) if return_outputs else loss


def compute_final_gedig_metrics(model, eval_dataset, tokenizer, data_collator):
    """Average geDIG F and its components over every layer and evaluation batch.

    Args:
        model: a trained model, optionally an ``FRegularizedModel`` wrapper.
        eval_dataset: dataset yielding items the collator can batch.
        tokenizer: unused; kept so existing callers keep working.
        data_collator: collate function passed to the DataLoader.

    Returns:
        Dict of Python floats: ``f_mean``, ``f_std``, ``delta_epc_mean``,
        ``delta_h_mean``, ``delta_sp_mean``. Each statistic is taken over the
        per-(batch, layer) means, so a smaller final batch weighs the same as a full one.

    Side effect: the model is left in eval mode.
    """
    from torch.utils.data import DataLoader

    device = next(model.parameters()).device
    model.eval()
    dataloader = DataLoader(eval_dataset, batch_size=32, collate_fn=data_collator)
    # Score with the configuration the model was trained with when it carries one,
    # so training-time and reported F values are comparable
    gedig = getattr(model, "gedig", None)
    if gedig is None:
        gedig = DifferentiableGeDIG()
    # NOTE(review): Hugging Face models presumably expose a `base_model` attribute too
    # (the headless backbone), so this branch may also be taken for the unwrapped
    # baseline; only the attentions are used below - confirm they are returned there.
    base = model.base_model if hasattr(model, "base_model") else model
    all_f, all_epc, all_h, all_sp = [], [], [], []
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            mask = batch.get("attention_mask")
            outputs = base(input_ids=batch["input_ids"], attention_mask=mask, output_attentions=True)
            for layer_attn in outputs.attentions:
                metrics = gedig.compute_F(layer_attn, mask)
                all_f.append(metrics["F"].mean().item())
                all_epc.append(metrics["delta_epc"].mean().item())
                all_h.append(metrics["delta_h"].mean().item())
                all_sp.append(metrics["delta_sp"].mean().item())
    # Plain floats keep the result independent of numpy scalar types when serialized
    return {
        "f_mean": float(np.mean(all_f)),
        "f_std": float(np.std(all_f)),
        "delta_epc_mean": float(np.mean(all_epc)),
        "delta_h_mean": float(np.mean(all_h)),
        "delta_sp_mean": float(np.mean(all_sp)),
    }


def run_experiment(alpha, model_name="distilbert-base-uncased", train_samples=1000, eval_samples=500, epochs=3, batch_size=16, learning_rate=2e-5, seed=42, output_dir=None):
    """Fine-tune on an SST-2 subset with F-regularization weight ``alpha`` and save the outcome.

    Args:
        alpha: weight of the geDIG F term; 0 trains the plain baseline model.
        model_name: checkpoint name passed to the tokenizer and model loaders.
        train_samples: number of leading training examples to use.
        eval_samples: number of leading validation examples to use.
        epochs: number of training epochs.
        batch_size: per-device batch size for training and evaluation.
        learning_rate: optimizer learning rate.
        seed: random seed for ``set_seed`` and the training arguments.
        output_dir: where checkpoints and ``result.json`` go (str or Path);
            defaults to ``results/f_reg/alpha_<alpha>_seed_<seed>``.

    Returns:
        Dict with alpha, seed, final accuracy and loss, training runtime and
        the final geDIG metrics; the same dict is written to ``result.json``.
    """
    import inspect

    set_seed(seed)
    # Accept plain strings as well as Path objects
    output_dir = Path(f"results/f_reg/alpha_{alpha}_seed_{seed}") if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Build the rule outside the f-string: same-type quotes nested inside an f-string only parse on Python 3.12+
    banner = "=" * 60
    print(f"\\n{banner}\\nRunning: alpha={alpha}, seed={seed}\\n{banner}")

    ds_train = load_dataset("glue", "sst2", split=f"train[:{train_samples}]")
    ds_eval = load_dataset("glue", "sst2", split=f"validation[:{eval_samples}]")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_fn(ex):
        return tokenizer(ex["sentence"], truncation=True, max_length=128)

    train_ds = ds_train.map(tokenize_fn, batched=True)
    eval_ds = ds_eval.map(tokenize_fn, batched=True)
    keep = ("input_ids", "attention_mask", "label")
    cols = [c for c in train_ds.column_names if c not in keep]
    train_ds = train_ds.remove_columns(cols).with_format("torch")
    eval_ds = eval_ds.remove_columns(cols).with_format("torch")

    # Eager attention returns the attention weights the geDIG terms are computed from;
    # it is used for the baseline as well so every alpha runs the same attention code path
    base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2, attn_implementation="eager")
    model = FRegularizedModel(base_model, alpha=alpha) if alpha > 0 else base_model

    training_args = TrainingArguments(
        output_dir=str(output_dir), eval_strategy="steps", eval_steps=50, logging_steps=10,
        save_strategy="epoch", per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs, learning_rate=learning_rate, weight_decay=0.01, report_to=[], seed=seed,
    )
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    # Newer transformers releases rename the Trainer `tokenizer` argument to `processing_class`;
    # pick whichever the installed version accepts
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer = FRegularizedTrainer(
        model=model, args=training_args, train_dataset=train_ds, eval_dataset=eval_ds,
        data_collator=data_collator,
        compute_metrics=lambda p: {"accuracy": (np.argmax(p.predictions, axis=1) == p.label_ids).mean()},
        **{tokenizer_kwarg: tokenizer},
    )
    train_result = trainer.train()
    eval_result = trainer.evaluate()
    final_f = compute_final_gedig_metrics(model, eval_ds, tokenizer, data_collator)

    result = {"alpha": alpha, "seed": seed, "final_accuracy": eval_result.get("eval_accuracy"),
              "final_loss": eval_result.get("eval_loss"), "train_runtime": train_result.metrics.get("train_runtime"), "final_gedig": final_f}
    (output_dir / "result.json").write_text(json.dumps(result, indent=2))
    accuracy = result["final_accuracy"]
    # A missing accuracy is reported as N/A instead of crashing the format spec
    acc_text = "N/A" if accuracy is None else f"{accuracy:.4f}"
    f_text = final_f.get("f_mean", "N/A")
    print(f"Result: acc={acc_text}, F={f_text}")
    return result


def run_alpha_sweep(alphas=(0.0, 0.001, 0.01, 0.1, 1.0), seeds=(42, 123, 456), **kwargs):
    """Run ``run_experiment`` for every (alpha, seed) pair and collect the results.

    Args:
        alphas: regularization weights to try, iterated in the outer loop.
        seeds: random seeds to repeat each alpha with.
        **kwargs: forwarded unchanged to ``run_experiment``.

    Returns:
        List of result dicts ordered by alpha, then by seed.
    """
    # Defaults are tuples so the shared default objects cannot be mutated between calls
    return [
        run_experiment(alpha=alpha, seed=seed, **kwargs)
        for alpha in alphas
        for seed in seeds
    ]
'''

# Persist the inlined script so the next cell can load it from disk
with open('train_f_reg.py', mode='w') as script_file:
    script_file.write(TRAIN_SCRIPT)
print('Script written to train_f_reg.py')

In [None]:
# Import and run
import json
import numpy as np
from pathlib import Path

# Execute the generated script in the notebook namespace so its classes and
# functions become available to later cells. read_text() closes the file
# deterministically instead of leaving an open handle behind.
exec(Path('train_f_reg.py').read_text())

# Configuration
ALPHAS = [0.0, 0.001, 0.01, 0.1, 1.0]
SEEDS = [42, 123, 456]
TRAIN_SAMPLES = 2000  # Increase for better results
EVAL_SAMPLES = 500
EPOCHS = 3
BATCH_SIZE = 16

print(f"Running {len(ALPHAS)} alphas x {len(SEEDS)} seeds = {len(ALPHAS)*len(SEEDS)} experiments")

In [None]:
# Run the full alpha x seed sweep with the configuration from the previous cell
sweep_settings = dict(
    alphas=ALPHAS,
    seeds=SEEDS,
    train_samples=TRAIN_SAMPLES,
    eval_samples=EVAL_SAMPLES,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)
results = run_alpha_sweep(**sweep_settings)

# Store every run in one JSON file next to the per-run outputs
results_dir = Path('results')
results_dir.mkdir(exist_ok=True)
(results_dir / 'all_results.json').write_text(json.dumps(results, indent=2))
print(f"\nSaved {len(results)} results to results/all_results.json")

In [None]:
# Analyze results
import pandas as pd

df = pd.DataFrame(results)

# Per-alpha mean and standard deviation (across seeds) of accuracy and loss
seed_stats = ['mean', 'std']
summary = (
    df.groupby('alpha')
    .agg({'final_accuracy': seed_stats, 'final_loss': seed_stats})
    .round(4)
)

divider = "=" * 60
print("\n" + divider)
print("RESULTS SUMMARY")
print(divider)
print(summary)

# Alpha with the highest mean accuracy, compared against the alpha = 0 baseline
mean_acc_by_alpha = df.groupby('alpha')['final_accuracy'].mean()
best_idx = mean_acc_by_alpha.idxmax()
baseline_acc = df.loc[df['alpha'] == 0, 'final_accuracy'].mean()
best_acc = df.loc[df['alpha'] == best_idx, 'final_accuracy'].mean()

print(f"\nBaseline (α=0): {baseline_acc:.4f}")
print(f"Best (α={best_idx}): {best_acc:.4f}")
print(f"Improvement: {best_acc - baseline_acc:+.4f}")

In [None]:
# Plot results
import matplotlib.pyplot as plt


def _draw_alpha_panel(ax, stats, ylabel, title, **errorbar_style):
    """Plot mean with std error bars per alpha on evenly spaced x positions."""
    positions = range(len(stats))
    ax.errorbar(positions, stats['mean'], yerr=stats['std'],
                markersize=10, linewidth=2, capsize=5, **errorbar_style)
    ax.set_xticks(positions)
    ax.set_xticklabels([f"{a}" for a in stats['alpha']])
    ax.set_xlabel('Alpha (F-regularization weight)')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

# Panel 1: accuracy per alpha, with the baseline mean as a dashed reference line
grouped = df.groupby('alpha')['final_accuracy'].agg(['mean', 'std']).reset_index()
_draw_alpha_panel(ax1, grouped, 'Accuracy', 'Alpha vs Accuracy', marker='o')
ax1.axhline(y=baseline_acc, color='gray', linestyle='--', alpha=0.7, label='Baseline')
ax1.legend()

# Panel 2: final geDIG F per alpha (runs without geDIG metrics are skipped)
f_data = [
    {'alpha': run['alpha'], 'f_mean': run['final_gedig']['f_mean']}
    for run in results
    if run.get('final_gedig')
]
f_df = pd.DataFrame(f_data)
f_grouped = f_df.groupby('alpha')['f_mean'].agg(['mean', 'std']).reset_index()
_draw_alpha_panel(ax2, f_grouped, 'Final F (geDIG)', 'Alpha vs Final F', marker='s', color='orange')

plt.tight_layout()
plt.savefig('results/fig_f_reg_summary.png', dpi=150)
plt.show()
print("Saved: results/fig_f_reg_summary.png")

In [None]:
# Download results
# NOTE(review): google.colab is presumably only importable inside a Colab runtime
from google.colab import files

# Zip results: recursively archive the results/ directory
!zip -r f_reg_results.zip results/
# Hand the archive to the Colab file helper for download
files.download('f_reg_results.zip')

## Interpretation

**Success criteria**:
1. α > 0 outperforms baseline (α=0) → F-regularization helps
2. Accuracy is non-monotonic in α, peaking at an intermediate value → there's a sweet spot
3. Final F is lower for regularized models → F is being minimized

**If successful**: this is evidence that geDIG F is not just correlated with good attention, but causally contributes to performance.