# F-Regularization Large-Scale Experiment

## Goal
Validate the causal hypothesis at scale: **Does minimizing geDIG F during training improve performance across multiple models and tasks?**

## Experiment Matrix
- **Models**: DistilBERT, BERT-base, RoBERTa-base
- **Tasks**: SST-2, MRPC, CoLA, QNLI (GLUE subset)
- **α sweep**: [0, 0.001, 0.01, 0.1]
- **Seeds**: [42, 123, 456, 789, 1024]

## Expected Runtime
- Full sweep (3 models × 4 tasks × 4 α values × 5 seeds = 240 runs): ~8-12 hours on T4/V100
- Single model/task pair (all α values and seeds): ~30-60 min

In [None]:
# Check GPU
!nvidia-smi

In [None]:
# Install dependencies
!pip install -q transformers datasets accelerate scipy

In [None]:
# Core imports
from __future__ import annotations

import json
import math
import os
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy import stats
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.modeling_outputs import SequenceClassifierOutput

# Report the runtime environment so logs record which PyTorch build and
# accelerator produced the results.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is reported.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# ============================================================================
# Differentiable geDIG Calculator
# ============================================================================

@dataclass
class DifferentiableGeDIG:
    """Compute the geDIG score F of attention maps with torch operations.

    For every (batch, head) attention map the score is::

        F = delta_epc - lambda_param * (delta_h + gamma * delta_sp)

    where ``delta_epc`` is a soft edge density, ``delta_h`` a normalized
    entropy and ``delta_sp`` a multi-hop path-efficiency term.
    """

    # Weight of the (entropy + gamma * path-efficiency) term in F.
    lambda_param: float = 1.0
    # Weight of the path-efficiency term relative to the entropy term.
    gamma: float = 0.5
    # Sigmoid temperature used to soften the edge threshold.
    temperature: float = 0.1
    # Fraction of entries that lie at or below the edge threshold.
    percentile: float = 0.9
    # Longest path length accumulated by the path-efficiency term.
    max_path_length: int = 4

    def compute_F(self, attention: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Return F and its three components for a stack of attention maps.

        Args:
            attention: 4-D tensor indexed as (batch, head, query, key).
            attention_mask: optional (batch, seq) tensor; entries whose query
                or key position is masked out (0) are zeroed before scoring.

        Returns:
            Dict with ``F``, ``delta_epc``, ``delta_h`` and ``delta_sp`` (each
            of shape (batch, head)) plus the scalar ``F_mean``.

        Raises:
            ValueError: if ``attention`` is not 4-dimensional.
        """
        if attention.dim() != 4:
            raise ValueError(
                f"attention must be 4-D (batch, heads, seq, seq), got shape {tuple(attention.shape)}"
            )
        # Score in float32: the 1e-10 stabilizers used below underflow to zero
        # in half precision, which would turn log(0) into NaN.
        attention = attention.float()
        if attention_mask is not None:
            # Outer product of the mask with itself -> (batch, 1, seq, seq).
            mask_2d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(3)
            attention = attention * mask_2d.float()
        delta_epc = self._compute_soft_density(attention)
        delta_h = self._compute_entropy(attention, attention_mask)
        delta_sp = self._compute_soft_path_efficiency(attention, attention_mask)
        F_values = delta_epc - self.lambda_param * (delta_h + self.gamma * delta_sp)
        return {"F": F_values, "F_mean": F_values.mean(), "delta_epc": delta_epc, "delta_h": delta_h, "delta_sp": delta_sp}

    def _soft_adjacency(self, attention: torch.Tensor) -> torch.Tensor:
        """Return sigmoid-softened edge probabilities for each attention map.

        The threshold is the k-th smallest entry of each (batch, head) map,
        with k = percentile * seq_len^2 clamped into the valid range so that
        very short sequences (e.g. a single token, where the product
        truncates to 0) do not make ``torch.kthvalue`` fail.
        """
        batch_size, num_heads, seq_len, _ = attention.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        attn_flat = attention.reshape(batch_size, num_heads, -1)
        num_entries = attn_flat.shape[-1]
        k = min(max(int(self.percentile * seq_len * seq_len), 1), num_entries)
        threshold = torch.kthvalue(attn_flat, k, dim=-1).values.unsqueeze(-1).unsqueeze(-1)
        return torch.sigmoid((attention - threshold) / self.temperature)

    def _compute_soft_density(self, attention: torch.Tensor) -> torch.Tensor:
        """Return the mean soft edge probability per (batch, head) map."""
        seq_len = attention.shape[-2]
        edge_probs = self._soft_adjacency(attention)
        return edge_probs.sum(dim=(-2, -1)) / (seq_len * seq_len)

    def _compute_entropy(self, attention: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the entropy of each flattened map, normalized by its maximum.

        The map is renormalized to sum to one over all entries. The maximum is
        log(valid_tokens^2) when a mask is given and log(seq_len^2) otherwise.
        """
        batch_size, num_heads, seq_len, _ = attention.shape
        attn_flat = attention.reshape(batch_size, num_heads, -1)
        attn_norm = attn_flat / (attn_flat.sum(dim=-1, keepdim=True) + 1e-10)
        entropy = -(attn_norm * torch.log(attn_norm + 1e-10)).sum(dim=-1)
        if attention_mask is not None:
            valid_count = attention_mask.sum(dim=-1).float()
            # (batch, 1) so it broadcasts over the head dimension.
            max_entropy = torch.log(valid_count * valid_count + 1e-10).unsqueeze(1)
        else:
            max_entropy = math.log(seq_len * seq_len)
        return entropy / (max_entropy + 1e-10)

    def _compute_soft_path_efficiency(self, attention: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return a path-length-weighted reachability score per (batch, head).

        Powers of the soft adjacency matrix (with self-loops added) are
        accumulated for path lengths 1..max_path_length; each length
        contributes the fraction of entries above 0.5, weighted by 1/length.
        ``attention_mask`` is accepted for signature symmetry but not used.
        """
        batch_size, num_heads, seq_len, _ = attention.shape
        adj = self._soft_adjacency(attention)
        eye = torch.eye(seq_len, device=attention.device).unsqueeze(0).unsqueeze(0)
        adj = adj + eye
        path_efficiency = torch.zeros(batch_size, num_heads, device=attention.device)
        adj_power = adj.clone()
        for path_len in range(1, self.max_path_length + 1):
            if path_len > 1:
                adj_power = torch.clamp(torch.matmul(adj_power, adj), 0, 1)
            # NOTE(review): the hard `> 0.5` comparison has zero gradient, so
            # this term does not contribute to backpropagation even though the
            # forward value enters F. Left unchanged so existing results stay
            # comparable; a straight-through or sigmoid relaxation would be
            # needed to actually train against it.
            path_efficiency = path_efficiency + (1.0 / path_len) * (adj_power > 0.5).float().mean(dim=(-2, -1))
        return path_efficiency / self.max_path_length

In [None]:
# ============================================================================
# F-Regularized Model and Trainer
# ============================================================================

class FRegularizedModel(nn.Module):
    """Wrap a sequence classifier and add ``alpha * F`` (geDIG) to its loss.

    The regularizer is applied whenever ``labels`` are supplied and
    ``alpha > 0``; the mean F over all layers is scaled by ``alpha`` and added
    to the wrapped model's loss. The scalar values of the most recent
    regularized forward pass are kept in ``_last_gedig_metrics``.

    NOTE(review): the term is added for every labelled forward pass, so a
    loss computed on labelled evaluation batches also includes ``alpha * F``
    and is not directly comparable with an unregularized baseline's loss.
    """

    def __init__(self, base_model: nn.Module, alpha: float = 0.1, gedig_config: Optional[Dict[str, Any]] = None):
        """Store the wrapped model and build the geDIG scorer.

        Args:
            base_model: model whose forward returns ``loss``, ``logits`` and
                (when asked) ``attentions``.
            alpha: weight of the F term in the total loss.
            gedig_config: optional keyword arguments for ``DifferentiableGeDIG``.
        """
        super().__init__()
        self.base_model = base_model
        self.alpha = alpha
        self.gedig = DifferentiableGeDIG(**(gedig_config or {}))
        self._last_gedig_metrics: Optional[Dict[str, float]] = None

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, **kwargs) -> SequenceClassifierOutput:
        """Run the wrapped model and, when active, add the geDIG penalty.

        Returns:
            ``SequenceClassifierOutput`` carrying the (possibly regularized)
            loss and the logits; hidden states and attentions are dropped.

        Raises:
            RuntimeError: if the regularizer is active but the wrapped model
                returned no attention weights.
        """
        # The wrapper decides on its own whether attention maps are needed; a
        # caller-supplied flag would otherwise collide with the explicit
        # keyword below ("got multiple values for keyword argument").
        kwargs.pop("output_attentions", None)
        regularize = labels is not None and self.alpha > 0
        # Only materialize per-layer attention maps when they will be scored.
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask,
                                  labels=labels, output_attentions=regularize, **kwargs)
        loss = outputs.loss
        if regularize:
            if outputs.attentions is None:
                raise RuntimeError(
                    "The wrapped model returned no attention weights; load it with "
                    "attn_implementation='eager' so that output_attentions is honoured."
                )
            f_values = [self.gedig.compute_F(layer_attn, attention_mask)["F_mean"]
                        for layer_attn in outputs.attentions]
            f_mean = torch.stack(f_values).mean()
            loss = outputs.loss + self.alpha * f_mean
            self._last_gedig_metrics = {
                "f_mean": f_mean.item(),
                "ce_loss": outputs.loss.item(),
                "total_loss": loss.item()
            }
        return SequenceClassifierOutput(loss=loss, logits=outputs.logits,
                                        hidden_states=None, attentions=None)


class FRegularizedTrainer(Trainer):
    """Trainer that forwards the model's geDIG metrics to the training log."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the model's loss and log its geDIG metrics on logging steps.

        Metrics are read from ``model._last_gedig_metrics`` when present. They
        are logged only while the model is in training mode and only every
        ``logging_steps`` optimizer steps, so evaluation batches do not re-log
        stale training values and the log is not written on every step.
        Extra keyword arguments passed by the base trainer are accepted and
        ignored.
        """
        outputs = model(**inputs)
        loss = outputs.loss
        gedig_metrics = getattr(model, "_last_gedig_metrics", None)
        if gedig_metrics and model.training and self._is_gedig_logging_step():
            self.log(gedig_metrics)
        return (loss, outputs) if return_outputs else loss

    def _is_gedig_logging_step(self) -> bool:
        """Return True when the current global step falls on the logging interval."""
        interval = getattr(self.state, "logging_steps", None) or self.args.logging_steps or 1
        # A fractional (ratio) setting truncates to 0; fall back to every step.
        interval = max(1, int(interval))
        return self.state.global_step % interval == 0

In [None]:
# ============================================================================
# Task Configurations
# ============================================================================

# One entry per GLUE task: where to load it from, which column(s) hold the
# text (a list means a sentence pair), the label count and the headline metric.
TASK_CONFIGS = {
    task: {
        "dataset": ("glue", task),
        "text_field": text_field,
        "num_labels": 2,
        "metric": metric,
    }
    for task, text_field, metric in (
        ("sst2", "sentence", "accuracy"),
        ("mrpc", ["sentence1", "sentence2"], "f1"),
        ("cola", "sentence", "matthews_correlation"),
        ("qnli", ["question", "sentence"], "accuracy"),
    )
}

# Short model name -> Hugging Face checkpoint identifier.
MODEL_CONFIGS = dict(
    distilbert="distilbert-base-uncased",
    bert="bert-base-uncased",
    roberta="roberta-base",
)

# Experiment settings: regularization strengths (0.0 is the baseline) and
# the random seeds each configuration is repeated with.
ALPHAS = [0.0, 0.001, 0.01, 0.1]
SEEDS = [42, 123, 456, 789, 1024]

In [None]:
# ============================================================================
# Metrics Computation
# ============================================================================

from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef

def compute_metrics(pred, metric_name="accuracy"):
    """Score predictions with the task's headline metric plus accuracy.

    ``pred`` must expose ``label_ids`` and ``predictions``; the predicted
    class is the argmax over axis 1 of ``predictions``. Accuracy is always
    reported; ``"f1"`` and ``"matthews_correlation"`` add their own entry
    ahead of it. Any other ``metric_name`` yields accuracy only.
    """
    y_true = pred.label_ids
    y_pred = np.argmax(pred.predictions, axis=1)

    scores = {}
    if metric_name == "f1":
        scores["f1"] = f1_score(y_true, y_pred)
    elif metric_name == "matthews_correlation":
        scores["matthews_correlation"] = matthews_corrcoef(y_true, y_pred)
    scores["accuracy"] = accuracy_score(y_true, y_pred)
    return scores


def compute_final_gedig_metrics(model, eval_dataset, tokenizer, data_collator):
    """Compute geDIG F statistics of a model's attention on an eval set.

    Batches of 32 examples are run without gradients; for every batch and
    every layer the mean F over (batch, head) is recorded.

    Args:
        model: model to score; if it has a ``base_model`` attribute, that
            object is called instead of the model itself.
        eval_dataset: dataset yielding items the collator can batch.
        tokenizer: unused; kept for call-site compatibility.
        data_collator: collate function passed to the ``DataLoader``.

    Returns:
        ``{"f_mean": float, "f_std": float}`` — mean and standard deviation
        over the per-batch, per-layer means (NaN when the dataset is empty).

    Raises:
        RuntimeError: if the model returns no attention weights.
    """
    from torch.utils.data import DataLoader

    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    dataloader = DataLoader(eval_dataset, batch_size=32, collate_fn=data_collator)
    gedig = DifferentiableGeDIG()
    # NOTE(review): Hugging Face PreTrainedModel presumably also exposes a
    # `base_model` attribute (the encoder without the task head), so for an
    # unwrapped classifier this likely resolves to its backbone — confirm.
    # Either way the object called here must return `.attentions`.
    base = model.base_model if hasattr(model, "base_model") else model
    all_f = []

    try:
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
                attention_mask = batch.get("attention_mask")
                forward_kwargs = {
                    "input_ids": batch["input_ids"],
                    "attention_mask": attention_mask,
                    "output_attentions": True,
                }
                # Forward segment ids when the batch carries them so that
                # sentence-pair inputs are scored as the model saw them.
                if "token_type_ids" in batch:
                    forward_kwargs["token_type_ids"] = batch["token_type_ids"]
                outputs = base(**forward_kwargs)
                if outputs.attentions is None:
                    raise RuntimeError(
                        "The model returned no attention weights; load it with "
                        "attn_implementation='eager' so that output_attentions is honoured."
                    )
                all_f.extend(
                    gedig.compute_F(layer_attn, attention_mask)["F"].mean().item()
                    for layer_attn in outputs.attentions
                )
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    if not all_f:
        return {"f_mean": float("nan"), "f_std": float("nan")}
    # Plain floats keep the result JSON-serializable without numpy types.
    return {"f_mean": float(np.mean(all_f)), "f_std": float(np.std(all_f))}

In [None]:
# ============================================================================
# Single Experiment Runner
# ============================================================================

def run_single_experiment(
    model_name: str,
    task_name: str,
    alpha: float,
    seed: int,
    max_train_samples: Optional[int] = None,
    max_eval_samples: Optional[int] = None,
    epochs: int = 3,
    batch_size: int = 16,
    learning_rate: float = 2e-5,
    output_dir: Optional[Path] = None,
) -> Dict[str, Any]:
    """Fine-tune one model on one task with one alpha/seed and report metrics.

    Args:
        model_name: key into ``MODEL_CONFIGS``.
        task_name: key into ``TASK_CONFIGS``.
        alpha: weight of the geDIG F term; ``0`` trains the plain model.
        seed: random seed passed to ``set_seed`` and ``TrainingArguments``.
        max_train_samples: if given, only the first N training examples are used.
        max_eval_samples: if given, only the first N validation examples are used.
        epochs: number of training epochs.
        batch_size: per-device batch size for training and evaluation.
        learning_rate: optimizer learning rate.
        output_dir: where ``result.json`` is written; defaults to
            ``results/<model>/<task>/alpha_<alpha>_seed_<seed>``.

    Returns:
        Dict with the configuration, evaluation metrics, final F statistics
        and wall-clock runtime; the same dict is saved as ``result.json``.

    Raises:
        KeyError: if ``model_name`` or ``task_name`` is not configured.
    """
    set_seed(seed)

    task_config = TASK_CONFIGS[task_name]
    model_path = MODEL_CONFIGS[model_name]

    if output_dir is None:
        output_dir = Path(f"results/{model_name}/{task_name}/alpha_{alpha}_seed_{seed}")
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*70}")
    print(f"Model: {model_name} | Task: {task_name} | α: {alpha} | Seed: {seed}")
    print(f"{'='*70}")

    start_time = time.time()

    # Load dataset (optionally only the first N examples of each split).
    ds_name, ds_config = task_config["dataset"]
    train_split = "train" if max_train_samples is None else f"train[:{max_train_samples}]"
    eval_split = "validation" if max_eval_samples is None else f"validation[:{max_eval_samples}]"

    ds_train = load_dataset(ds_name, ds_config, split=train_split)
    ds_eval = load_dataset(ds_name, ds_config, split=eval_split)

    print(f"Train: {len(ds_train)} samples | Eval: {len(ds_eval)} samples")

    # Tokenize
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    text_field = task_config["text_field"]

    def tokenize_fn(examples):
        # A list/tuple of column names denotes a sentence-pair task.
        if isinstance(text_field, (list, tuple)):
            first, second = text_field
            return tokenizer(examples[first], examples[second],
                             truncation=True, max_length=128)
        return tokenizer(examples[text_field], truncation=True, max_length=128)

    train_ds = ds_train.map(tokenize_fn, batched=True)
    eval_ds = ds_eval.map(tokenize_fn, batched=True)

    # Keep only model inputs and the label. token_type_ids is retained when
    # the tokenizer produces it: it marks which tokens belong to which
    # sentence, and without it pair tasks feed both sentences as segment 0.
    keep_cols = {"input_ids", "attention_mask", "token_type_ids", "label"}
    train_ds = train_ds.remove_columns([c for c in train_ds.column_names if c not in keep_cols])
    eval_ds = eval_ds.remove_columns([c for c in eval_ds.column_names if c not in keep_cols])
    train_ds = train_ds.with_format("torch")
    eval_ds = eval_ds.with_format("torch")

    # Load model. Eager attention is requested for every arm so that
    # (a) attention weights are actually returned when asked for and
    # (b) baseline and regularized runs use the same attention kernel.
    base_model = AutoModelForSequenceClassification.from_pretrained(
        model_path, num_labels=task_config["num_labels"], attn_implementation="eager"
    )
    model = FRegularizedModel(base_model, alpha=alpha) if alpha > 0 else base_model

    # Training args
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        eval_strategy="epoch",
        logging_steps=50,
        save_strategy="no",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        learning_rate=learning_rate,
        weight_decay=0.01,
        report_to=[],
        seed=seed,
        fp16=torch.cuda.is_available(),
        # Columns were pruned by hand above. Automatic pruning goes by the
        # forward signature and would drop inputs that the wrapper only
        # accepts through **kwargs (e.g. token_type_ids).
        remove_unused_columns=False,
    )

    # Trainer
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    metric_name = task_config["metric"]

    trainer = FRegularizedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=lambda p: compute_metrics(p, metric_name),
    )

    # Train, then evaluate once more on the validation split.
    trainer.train()
    eval_result = trainer.evaluate()

    # Final geDIG metrics
    final_f = compute_final_gedig_metrics(model, eval_ds, tokenizer, data_collator)

    elapsed = time.time() - start_time

    # Compile result
    result = {
        "model": model_name,
        "task": task_name,
        "alpha": alpha,
        "seed": seed,
        "train_samples": len(ds_train),
        "eval_samples": len(ds_eval),
        "epochs": epochs,
        "metric_name": metric_name,
        "eval_metric": eval_result.get(f"eval_{metric_name}"),
        "eval_accuracy": eval_result.get("eval_accuracy"),
        "eval_loss": eval_result.get("eval_loss"),
        "final_f_mean": float(final_f["f_mean"]),
        "final_f_std": float(final_f["f_std"]),
        "runtime_seconds": elapsed,
    }

    # Save
    (output_dir / "result.json").write_text(json.dumps(result, indent=2))

    # The metric may be absent from the evaluation output; do not let the
    # summary line crash a run whose results are already saved.
    metric_value = result["eval_metric"]
    metric_text = f"{metric_value:.4f}" if metric_value is not None else "n/a"
    print(f"Result: {metric_name}={metric_text}, F={result['final_f_mean']:.4f}, time={elapsed:.1f}s")

    return result

In [None]:
# ============================================================================
# Large-Scale Experiment Runner
# ============================================================================

def run_large_scale_experiment(
    models: Optional[List[str]] = None,
    tasks: Optional[List[str]] = None,
    alphas: Optional[List[float]] = None,
    seeds: Optional[List[int]] = None,
    max_train_samples: Optional[int] = None,  # None = full dataset
    max_eval_samples: Optional[int] = None,
    epochs: int = 3,
    output_dir: Path = Path("results"),
) -> List[Dict[str, Any]]:
    """Run every (model, task, alpha, seed) combination and collect results.

    Args:
        models: model keys; defaults to distilbert, bert and roberta.
        tasks: task keys; defaults to sst2, mrpc, cola and qnli.
        alphas: regularization strengths; defaults to ``ALPHAS``.
        seeds: random seeds; defaults to ``SEEDS``.
        max_train_samples: cap on training examples (None = full dataset).
        max_eval_samples: cap on validation examples (None = full split).
        epochs: training epochs per run.
        output_dir: root directory for per-run folders and summary files.

    Returns:
        One dict per run, in grid order. Failed runs are represented by a
        dict holding the configuration and an ``"error"`` message.

    Side effects:
        Rewrites ``all_results_partial.json`` after every run (successful or
        not) and writes ``all_results.json`` at the end.
    """
    import itertools
    import traceback

    # Defaults are resolved here rather than in the signature so that no
    # mutable list is shared between calls.
    models = ["distilbert", "bert", "roberta"] if models is None else list(models)
    tasks = ["sst2", "mrpc", "cola", "qnli"] if tasks is None else list(tasks)
    alphas = list(ALPHAS) if alphas is None else list(alphas)
    seeds = list(SEEDS) if seeds is None else list(seeds)

    # Create the root up front: the summary files are written even when every
    # run fails before creating its own sub-directory.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    partial_path = output_dir / "all_results_partial.json"

    grid = list(itertools.product(models, tasks, alphas, seeds))
    total_experiments = len(grid)
    print(f"\n{'#'*70}")
    print(f"# LARGE-SCALE F-REGULARIZATION EXPERIMENT")
    print(f"# Models: {models}")
    print(f"# Tasks: {tasks}")
    print(f"# Alphas: {alphas}")
    print(f"# Seeds: {seeds}")
    print(f"# Total experiments: {total_experiments}")
    print(f"{'#'*70}\n")

    all_results = []

    for experiment_idx, (model_name, task_name, alpha, seed) in enumerate(grid, start=1):
        print(f"\n[{experiment_idx}/{total_experiments}]")

        try:
            result = run_single_experiment(
                model_name=model_name,
                task_name=task_name,
                alpha=alpha,
                seed=seed,
                max_train_samples=max_train_samples,
                max_eval_samples=max_eval_samples,
                epochs=epochs,
                output_dir=output_dir / model_name / task_name / f"alpha_{alpha}_seed_{seed}",
            )
        except Exception as e:
            # Best effort: one failing configuration must not abort a sweep
            # that runs for hours. Keep the traceback for later diagnosis.
            print(f"ERROR: {e}")
            traceback.print_exc()
            result = {
                "model": model_name, "task": task_name,
                "alpha": alpha, "seed": seed, "error": str(e)
            }

        all_results.append(result)
        # Save intermediate results after every run, including failures, so
        # an interrupted sweep leaves an accurate record behind.
        partial_path.write_text(json.dumps(all_results, indent=2))

    # Save final results
    (output_dir / "all_results.json").write_text(json.dumps(all_results, indent=2))
    print(f"\nSaved {len(all_results)} results to {output_dir / 'all_results.json'}")

    return all_results

In [None]:
# ============================================================================
# Statistical Analysis
# ============================================================================

def analyze_results(results: List[Dict], output_dir: Path = Path("results")):
    """Print summary statistics for experiment results and save an overview.

    Args:
        results: result dicts as produced by the experiment runner; entries
            containing an ``"error"`` key are ignored.
        output_dir: directory that receives ``analysis.json`` (created if
            missing).

    Returns:
        DataFrame of the successful runs (empty if there are none, in which
        case nothing is written).

    Notes:
        The significance test pools all models and tasks into one independent
        two-sample t-test on accuracy per alpha; it does not pair runs by
        model/task/seed.
    """
    df = pd.DataFrame([r for r in results if "error" not in r])

    print("\n" + "="*70)
    print("STATISTICAL ANALYSIS")
    print("="*70)

    if df.empty:
        print("\nNo successful experiments to analyze.")
        return df

    # 1. Overall summary by alpha
    print("\n### Overall Summary by Alpha ###")
    overall = df.groupby("alpha").agg({
        "eval_accuracy": ["mean", "std", "count"],
        "final_f_mean": ["mean", "std"],
    }).round(4)
    print(overall)

    # 2. Per-task analysis
    print("\n### Per-Task Summary ###")
    for task in df["task"].unique():
        print(f"\n--- {task.upper()} ---")
        task_df = df[df["task"] == task]
        task_summary = task_df.groupby("alpha").agg({
            "eval_metric": ["mean", "std"],
        }).round(4)
        print(task_summary)

    # 3. Per-model analysis
    print("\n### Per-Model Summary ###")
    for model in df["model"].unique():
        print(f"\n--- {model.upper()} ---")
        model_df = df[df["model"] == model]
        model_summary = model_df.groupby("alpha").agg({
            "eval_accuracy": ["mean", "std"],
        }).round(4)
        print(model_summary)

    # 4. Statistical tests (t-test: each alpha > 0 vs baseline)
    print("\n### Statistical Significance Tests ###")
    baseline_acc = df.loc[df["alpha"] == 0.0, "eval_accuracy"].dropna().to_numpy(dtype=float)

    for alpha in sorted(a for a in df["alpha"].unique() if a > 0):
        treatment_acc = df.loc[df["alpha"] == alpha, "eval_accuracy"].dropna().to_numpy(dtype=float)

        # Both groups need at least two runs for a variance estimate.
        if len(baseline_acc) < 2 or len(treatment_acc) < 2:
            continue

        t_stat, p_value = stats.ttest_ind(treatment_acc, baseline_acc)

        base_mean, treat_mean = baseline_acc.mean(), treatment_acc.mean()
        # Sample (ddof=1) standard deviations, matching the pandas summaries
        # above, pooled with degrees-of-freedom weights for Cohen's d.
        base_sd, treat_sd = baseline_acc.std(ddof=1), treatment_acc.std(ddof=1)
        n_base, n_treat = len(baseline_acc), len(treatment_acc)
        pooled_sd = np.sqrt(
            ((n_base - 1) * base_sd**2 + (n_treat - 1) * treat_sd**2) / (n_base + n_treat - 2)
        )
        # Identical scores in both groups give zero spread; report NaN
        # instead of dividing by zero.
        effect_size = (treat_mean - base_mean) / pooled_sd if pooled_sd > 0 else float("nan")

        print(f"\nα={alpha} vs α=0 (baseline):")
        print(f"  Baseline: {base_mean:.4f} ± {base_sd:.4f}")
        print(f"  Treatment: {treat_mean:.4f} ± {treat_sd:.4f}")
        print(f"  Improvement: {(treat_mean - base_mean)*100:+.2f}%")
        print(f"  t-statistic: {t_stat:.3f}")
        print(f"  p-value: {p_value:.4f} {'***' if p_value < 0.001 else '**' if p_value < 0.01 else '*' if p_value < 0.05 else ''}")
        print(f"  Cohen's d: {effect_size:.3f}")

    # 5. Find best configuration
    print("\n### Best Configurations ###")
    per_config = df.groupby(["model", "task", "alpha"])["eval_metric"].mean().reset_index()
    for task in df["task"].unique():
        task_rows = per_config[per_config["task"] == task].dropna(subset=["eval_metric"])
        if task_rows.empty:
            continue
        best_row = task_rows.loc[task_rows["eval_metric"].idxmax()]
        # Compare against the baseline of the SAME model; averaging baselines
        # over all models would mix model quality into the improvement.
        baseline_rows = task_rows[
            (task_rows["alpha"] == 0.0) & (task_rows["model"] == best_row["model"])
        ]
        if baseline_rows.empty:
            continue
        baseline_val = baseline_rows["eval_metric"].iloc[0]
        improvement = (best_row["eval_metric"] - baseline_val) * 100
        print(f"{task}: Best α={best_row['alpha']} ({best_row['model']}), "
              f"metric={best_row['eval_metric']:.4f}, improvement={improvement:+.2f}%")

    # Save analysis. The summary has two-level column labels, which to_dict()
    # would turn into tuple keys that json cannot encode (default= applies to
    # values only), so flatten them to "<column>_<stat>" and use string alphas.
    def _to_builtin(value):
        if pd.isna(value):
            return None
        return value.item() if hasattr(value, "item") else value

    overall_summary = {
        f"{column}_{stat}": {str(float(a)): _to_builtin(v) for a, v in series.items()}
        for (column, stat), series in overall.items()
    }
    analysis = {
        "overall_summary": overall_summary,
        "timestamp": datetime.now().isoformat(),
        "total_experiments": len(df),
    }
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "analysis.json").write_text(json.dumps(analysis, indent=2, default=str))

    return df

In [None]:
# ============================================================================
# Visualization
# ============================================================================

import matplotlib.pyplot as plt

def plot_results(df: pd.DataFrame, output_dir: Path = Path("results")):
    """Draw a 2x3 overview figure of the experiment results and save it.

    Panels: overall accuracy per alpha, task metric per alpha, per-model
    accuracy per alpha, final F per alpha, accuracy-vs-F scatter with a
    linear trend, and a model x task heatmap of the best improvement over
    the alpha=0 baseline. The alpha axis and colour coding are derived from
    the alphas present in ``df``.

    Args:
        df: one row per successful run with at least the columns ``model``,
            ``task``, ``alpha``, ``eval_metric``, ``eval_accuracy`` and
            ``final_f_mean``.
        output_dir: directory receiving ``fig_large_scale_results.png``
            (created if missing).
    """
    if df.empty:
        print("No results to plot.")
        return

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Shared alpha axis built from the data, so that runs using a subset of
    # alphas (or custom values) are labelled and coloured correctly.
    alpha_values = sorted(float(a) for a in df["alpha"].unique())
    alpha_pos = {a: i for i, a in enumerate(alpha_values)}
    alpha_labels = [f"{a}" for a in alpha_values]
    tasks = list(df["task"].unique())
    models = list(df["model"].unique())

    def positions(alphas):
        return [alpha_pos[float(a)] for a in alphas]

    def label_alpha_axis(axis):
        axis.set_xticks(range(len(alpha_values)))
        axis.set_xticklabels(alpha_labels)
        axis.set_xlabel("Alpha")

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # 1. Overall Alpha vs Accuracy
    ax = axes[0, 0]
    overall = df.groupby("alpha")["eval_accuracy"].agg(["mean", "std"]).reset_index()
    ax.errorbar(positions(overall["alpha"]), overall["mean"], yerr=overall["std"],
                marker="o", markersize=10, linewidth=2, capsize=5)
    label_alpha_axis(ax)
    ax.set_ylabel("Accuracy")
    ax.set_title("Overall: Alpha vs Accuracy")
    ax.grid(True, alpha=0.3)
    baseline_rows = overall[overall["alpha"] == 0]
    if not baseline_rows.empty:
        # The reference line is only meaningful when a baseline was run.
        ax.axhline(y=baseline_rows["mean"].values[0], color="gray", linestyle="--",
                   alpha=0.7, label="Baseline")
        ax.legend()

    # 2. Per-Task Alpha vs Metric
    ax = axes[0, 1]
    for task in tasks:
        task_summary = df[df["task"] == task].groupby("alpha")["eval_metric"].mean().reset_index()
        ax.plot(positions(task_summary["alpha"]), task_summary["eval_metric"],
                marker="o", label=task, linewidth=2)
    label_alpha_axis(ax)
    ax.set_ylabel("Task Metric")
    ax.set_title("Per-Task: Alpha vs Metric")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 3. Per-Model Alpha vs Accuracy
    ax = axes[0, 2]
    for model in models:
        model_summary = df[df["model"] == model].groupby("alpha")["eval_accuracy"].mean().reset_index()
        ax.plot(positions(model_summary["alpha"]), model_summary["eval_accuracy"],
                marker="s", label=model, linewidth=2)
    label_alpha_axis(ax)
    ax.set_ylabel("Accuracy")
    ax.set_title("Per-Model: Alpha vs Accuracy")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 4. Alpha vs Final F
    ax = axes[1, 0]
    f_summary = df.groupby("alpha")["final_f_mean"].agg(["mean", "std"]).reset_index()
    ax.errorbar(positions(f_summary["alpha"]), f_summary["mean"], yerr=f_summary["std"],
                marker="s", markersize=10, linewidth=2, capsize=5, color="orange")
    label_alpha_axis(ax)
    ax.set_ylabel("Final F (geDIG)")
    ax.set_title("Alpha vs Final F")
    ax.grid(True, alpha=0.3)

    # 5. Accuracy vs F scatter (correlation)
    ax = axes[1, 1]
    scatter = ax.scatter(df["final_f_mean"], df["eval_accuracy"],
                         c=positions(df["alpha"]),
                         cmap="viridis", alpha=0.6, s=50)
    # Trend line: needs at least two finite points with differing F values,
    # otherwise the linear fit and the correlation are undefined.
    finite = df[["final_f_mean", "eval_accuracy"]].dropna()
    if len(finite) >= 2 and finite["final_f_mean"].nunique() > 1:
        trend = np.poly1d(np.polyfit(finite["final_f_mean"], finite["eval_accuracy"], 1))
        x_range = np.linspace(finite["final_f_mean"].min(), finite["final_f_mean"].max(), 100)
        corr = np.corrcoef(finite["final_f_mean"], finite["eval_accuracy"])[0, 1]
        ax.plot(x_range, trend(x_range), "r--", alpha=0.5, label=f"r={corr:.3f}")
        ax.set_title(f"Accuracy vs F Correlation (r={corr:.3f})")
        ax.legend()
    else:
        ax.set_title("Accuracy vs F Correlation")
    ax.set_xlabel("Final F (geDIG)")
    ax.set_ylabel("Accuracy")
    ax.grid(True, alpha=0.3)
    plt.colorbar(scatter, ax=ax, label="Alpha index")

    # 6. Improvement heatmap (model x task): best alpha vs baseline, in
    # hundredths of the metric. NaN where a cell has no baseline runs.
    ax = axes[1, 2]
    improvements = []
    for model in models:
        row = []
        for task in tasks:
            subset = df[(df["model"] == model) & (df["task"] == task)]
            baseline = subset[subset["alpha"] == 0]["eval_metric"].mean()
            best = subset.groupby("alpha")["eval_metric"].mean().max()
            row.append((best - baseline) * 100)
        improvements.append(row)

    im = ax.imshow(improvements, cmap="RdYlGn", aspect="auto", vmin=-2, vmax=2)
    ax.set_xticks(range(len(tasks)))
    ax.set_xticklabels(tasks)
    ax.set_yticks(range(len(models)))
    ax.set_yticklabels(models)
    ax.set_title("Improvement (%) vs Baseline")
    plt.colorbar(im, ax=ax, label="Improvement %")

    # Add text annotations
    for i, row in enumerate(improvements):
        for j, value in enumerate(row):
            ax.text(j, i, f"{value:.2f}", ha="center", va="center", fontsize=10)

    plt.tight_layout()
    plt.savefig(output_dir / "fig_large_scale_results.png", dpi=150)
    plt.show()
    print(f"Saved: {output_dir / 'fig_large_scale_results.png'}")

---
# EXPERIMENT EXECUTION
---

In [None]:
# ============================================================================
# OPTION 1: Quick Test (single model, single task)
# Runtime: ~10-15 min
# ============================================================================

# Smoke test of the whole pipeline: 1 model x 1 task x 3 alphas x 2 seeds
# = 6 runs on truncated splits (2000 train / 500 eval examples, 2 epochs).
# Results are written under results_quick/ and kept in `results`.
QUICK_TEST = True  # Set to False for full experiment

if QUICK_TEST:
    results = run_large_scale_experiment(
        models=["distilbert"],
        tasks=["sst2"],
        alphas=[0.0, 0.001, 0.01],
        seeds=[42, 123],
        max_train_samples=2000,
        max_eval_samples=500,
        epochs=2,
        output_dir=Path("results_quick"),
    )

In [None]:
# ============================================================================
# OPTION 2: Medium Scale (all tasks, one model)
# Runtime: ~2-3 hours
# ============================================================================

# One model on all four tasks with every alpha in ALPHAS and three seeds.
# Training data is capped at 5000 examples; max_eval_samples is not passed,
# so the full validation split is used. Results go to results_medium/ and
# overwrite the `results` variable of any earlier cell.
MEDIUM_SCALE = False  # Set to True to run

if MEDIUM_SCALE:
    results = run_large_scale_experiment(
        models=["distilbert"],
        tasks=["sst2", "mrpc", "cola", "qnli"],
        alphas=ALPHAS,
        seeds=[42, 123, 456],
        max_train_samples=5000,
        epochs=3,
        output_dir=Path("results_medium"),
    )

In [None]:
# ============================================================================
# OPTION 3: Full Scale Experiment
# Runtime: ~8-12 hours (recommend A100/V100)
# ============================================================================

# Complete experiment matrix: 3 models x 4 tasks x all ALPHAS x all SEEDS on
# the untruncated train and validation splits. Results go to results_full/
# and overwrite the `results` variable of any earlier cell.
FULL_SCALE = False  # Set to True to run

if FULL_SCALE:
    results = run_large_scale_experiment(
        models=["distilbert", "bert", "roberta"],
        tasks=["sst2", "mrpc", "cola", "qnli"],
        alphas=ALPHAS,
        seeds=SEEDS,
        max_train_samples=None,  # Full dataset
        max_eval_samples=None,
        epochs=3,
        output_dir=Path("results_full"),
    )

In [None]:
# ============================================================================
# Analyze and Visualize Results
# ============================================================================

# Load results (adjust path based on which experiment you ran)
# Directory of the run to analyze; must match the output_dir of the
# experiment cell that was executed.
result_dir = Path("results_quick")  # or results_medium, results_full

# Results are re-read from disk rather than taken from the in-memory
# `results` variable, so this cell also works in a fresh session.
if (result_dir / "all_results.json").exists():
    with open(result_dir / "all_results.json") as f:
        results = json.load(f)
    
    # Analysis output and the figure are written next to the results.
    df = analyze_results(results, result_dir)
    plot_results(df, result_dir)
else:
    print(f"No results found in {result_dir}")

In [None]:
# ============================================================================
# Download Results
# ============================================================================

import shutil

# Adjust based on which experiment you ran
result_dir = "results_quick"  # or results_medium, results_full

if Path(result_dir).is_dir():
    # Portable replacement for the `zip -r` shell call: creates
    # f_reg_large_scale_results.zip in the working directory with all
    # entries stored under "<result_dir>/".
    archive_path = shutil.make_archive(
        "f_reg_large_scale_results", "zip", root_dir=".", base_dir=result_dir
    )
    try:
        # Only available on Google Colab; elsewhere just report the path.
        from google.colab import files
    except ImportError:
        print(f"Not running in Colab; archive written to {archive_path}")
    else:
        files.download(archive_path)
else:
    print(f"No results found in {result_dir}")

---
## Interpretation Guide

### Success Criteria for "やばい" (Breakthrough) Level

| Criterion | Threshold | Status |
|-----------|-----------|--------|
| Consistent improvement | α>0 beats baseline in >75% of settings | ? |
| Statistical significance | p < 0.01 for best α vs baseline | ? |
| Effect size | Cohen's d > 0.3 (small-to-medium effect; Cohen's conventional "medium" is 0.5) | ? |
| Cross-model generalization | Works on BERT, RoBERTa, DistilBERT | ? |
| Cross-task generalization | Works on SST-2, MRPC, CoLA, QNLI | ? |

### If Successful
- geDIG F is a **trainable objective** for Transformer optimization
- Opens a path to **attention-free architectures** based on graph principles
- Publishable at ACL/EMNLP/NeurIPS level