In [None]:
!pip uninstall -y datasets
!pip install datasets==2.18.0
!pip install evaluate


Found existing installation: datasets 4.0.0
Uninstalling datasets-4.0.0:
  Successfully uninstalled datasets-4.0.0
Collecting datasets==2.18.0
  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets==2.18.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Collecting fsspec<=2024.2.0,>=2023.1.0 (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets==2.18.0)
  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)
Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.2.0-py3-none-any.whl (170 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.9/170.9 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow_hotfix-0.7-py3-none-any.whl (7.9 kB)
Installing collected packages: pyarrow-hotfix, fsspec, datasets
  Attempting uninstall: fsspe

In [None]:
# Notebook-level setting read by main(), which passes it as k_self:
# the number of decoder self-attention sublayers to prune.
num = 5

In [None]:
# Mount Google Drive when running on Colab. Outside Colab the import fails,
# so the mount is skipped instead of aborting the notebook.
try:
    from google.colab import drive
except ImportError:
    print("google.colab not available; skipping Drive mount.")
else:
    drive.mount('/content/drive')


In [None]:
# Finite-Difference JVP — Prune ONLY decoder attention layers (self & cross)
# NaN-hardened version

import inspect
import warnings
import math
from collections import defaultdict

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# =========================
# 0) Helpers for FD-JVP and kwargs/mask handling
# =========================

def _filter_kwargs_for_module(module, kwargs):
    """Keep only the entries of *kwargs* whose key is a named parameter of ``module.forward``.

    The relative order of the surviving entries is preserved.
    """
    accepted = {name for name in inspect.signature(module.forward).parameters if name != "self"}
    kept = {}
    for key, value in kwargs.items():
        if key in accepted:
            kept[key] = value
    return kept

def _map_args_from_hook(module, args, kwargs):
    """Turn the (args, kwargs) seen by a forward pre-hook into one name -> value dict.

    Positional arguments are matched to the positional parameters of
    ``module.forward`` in order, keyword arguments are overlaid on top, and
    parameters that were not supplied fall back to their declared default.
    Parameters that were not supplied and have no default are left out, so the
    ``inspect.Parameter.empty`` sentinel never ends up in the result.
    ``hidden_states`` is always removed because callers pass it separately.

    Args:
        module: object exposing a ``forward`` callable.
        args: positional arguments of the intercepted call.
        kwargs: keyword arguments of the intercepted call (may be empty or None).

    Returns:
        dict mapping parameter names to values, without ``hidden_states``.
        Filtering to what the module accepts happens at call time.
    """
    sig = inspect.signature(module.forward)
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    d = {}
    next_positional = 0
    for name, param in sig.parameters.items():
        if name == "self" or param.kind in variadic_kinds:
            continue
        if param.kind in positional_kinds and next_positional < len(args):
            # Mixed calls such as layer(hidden_states, mask, use_cache=False)
            # must not lose the positional part.
            d[name] = args[next_positional]
            next_positional += 1
        elif param.default is not inspect.Parameter.empty:
            d[name] = param.default

    # Explicit keyword arguments always win over defaults.
    d.update(kwargs or {})
    d.pop("hidden_states", None)
    return d  # filtering happens at call-time

def _build_causal_mask(B, q_len, k_len, device, dtype):
    """Build an additive causal mask of shape (B, 1, q_len, k_len).

    Entries strictly above the main diagonal (key index > query index) hold a
    large negative value; all other entries are 0. The negative value is -1e9,
    clamped to the most negative finite value of *dtype* so that low-precision
    float dtypes do not overflow to -inf.
    """
    # stable large negative instead of -inf to avoid NaNs in softmax
    fill = -1e9
    if dtype.is_floating_point:
        fill = max(fill, torch.finfo(dtype).min)
    mask = torch.zeros((B, 1, q_len, k_len), dtype=dtype, device=device)
    tri = torch.triu(torch.ones((q_len, k_len), dtype=torch.bool, device=device), diagonal=1)
    mask[:, :, tri] = fill
    return mask

def _build_zero_mask(B, q_len, k_len, device, dtype):
    """Return an all-zero tensor of shape (B, 1, q_len, k_len) on *device* with *dtype*."""
    shape = (B, 1, q_len, k_len)
    return torch.zeros(shape, device=device, dtype=dtype)

def _ensure_mask_shape(mask, q_len, k_len, device, dtype, is_self, B_fallback=1):
    """
    Return an additive attention mask usable for a (q_len x k_len) attention.

    Handling by input:
      * None or unsupported rank: rebuild a default (causal for self-attention,
        zeros for cross-attention) with batch size ``B_fallback``.
      * 2-D (B, K): sliced to ``k_len`` and broadcast over queries. A bool/int
        mask is treated as a keep-mask (non-zero = attend) and converted to
        additive form; a floating mask is used as-is. If K < k_len the default
        is rebuilt.
      * 4-D (B, H, Q, K): sliced to (q_len, k_len) when large enough. For
        cross-attention a query dimension of 1 is broadcast to ``q_len`` so
        that a key-padding mask is kept instead of being replaced by zeros.
        Anything else rebuilds the default.
    """
    def _rebuild(batch):
        # Resolved lazily so that the slicing paths do not depend on the builders.
        builder = _build_causal_mask if is_self else _build_zero_mask
        return builder(batch, q_len, k_len, device, dtype)

    if mask is None:
        return _rebuild(B_fallback)

    if mask.dim() == 2:
        B, K = mask.shape
        if K < k_len:
            # cannot slice up to desired size → rebuild
            return _rebuild(B)
        keys = mask[:, :k_len]
        if not keys.is_floating_point():
            # Keep-mask (1 = attend, 0 = ignore) → additive mask.
            fill = -1e9
            if dtype.is_floating_point:
                fill = max(fill, torch.finfo(dtype).min)
            keep = keys.to(device=device) != 0
            additive = torch.zeros(keep.shape, dtype=dtype, device=device)
            additive[~keep] = fill
            keys = additive
        return keys[:, None, None, :].expand(B, 1, q_len, k_len).contiguous().to(device=device, dtype=dtype)

    if mask.dim() == 4:
        B, _, Q, K = mask.shape
        if K >= k_len:
            if Q >= q_len:
                return mask[:, :, :q_len, :k_len].to(device=device, dtype=dtype)
            if Q == 1 and not is_self:
                # Key-padding mask shared by all queries: broadcast, do not drop.
                return mask[:, :, :, :k_len].expand(-1, -1, q_len, -1).to(device=device, dtype=dtype)
        # cannot slice up to desired size → rebuild
        return _rebuild(B)

    # unknown shape → rebuild
    return _rebuild(B_fallback)

@torch.no_grad()
def _t5sublayer_forward_only(module, X, argdict):
    """
    Re-run one attention sublayer on X and return only its hidden states.

    The module is switched to eval mode for the call and restored afterwards;
    the input is cast to fp32 and autocast is disabled. Masks are normalised
    through _ensure_mask_shape, and cache_position / query_length are supplied
    when the module's forward signature asks for them.

    Returns None when X is None, when a cross-attention module has no captured
    key/value states, or when the output contains non-finite values.
    """
    if X is None:
        return None

    restore_training = module.training
    module.eval()
    hidden = X.float()
    try:
        captured = {} if argdict is None else dict(argdict)
        captured.pop("position_bias", None)  # force recompute inside

        accepted = set(inspect.signature(module.forward).parameters.keys())
        call_kwargs = {"hidden_states": hidden}

        # A forward() that takes key_value_states is treated as cross-attention.
        cross = "key_value_states" in accepted
        enc_states = None
        if cross:
            enc_states = captured.get("key_value_states", captured.get("encoder_hidden_states", None))
            if enc_states is None:
                # no encoder states captured → skip this probe (avoid wrong shapes)
                return None
            call_kwargs["key_value_states"] = enc_states.float()

        q_len = hidden.size(1)
        k_len = enc_states.size(1) if (cross and enc_states is not None) else q_len
        batch = hidden.size(0)

        # The mask parameter name differs across wrappers.
        if "attention_mask" in accepted:
            mask_name = "attention_mask"
        elif "mask" in accepted:
            mask_name = "mask"
        else:
            mask_name = None
        if mask_name is not None:
            captured_mask = captured.get(mask_name, captured.get("mask", None))
            call_kwargs[mask_name] = _ensure_mask_shape(
                captured_mask, q_len, k_len,
                device=hidden.device, dtype=hidden.dtype,
                is_self=not cross, B_fallback=batch,
            )

        if "cache_position" in accepted:
            positions = captured.get("cache_position", None)
            if positions is None:
                positions = torch.arange(q_len, dtype=torch.long, device=hidden.device)
            call_kwargs["cache_position"] = positions
        elif "query_length" in accepted:
            call_kwargs["query_length"] = q_len

        # Deterministic flags, only when the module knows them.
        for flag in ("use_cache", "output_attentions"):
            if flag in accepted:
                call_kwargs[flag] = False

        call_kwargs = _filter_kwargs_for_module(module, call_kwargs)

        # Disable autocast for stable finite differences
        with torch.cuda.amp.autocast(enabled=False):
            result = module(**call_kwargs)

        # Tuple/list outputs carry the hidden states in slot 0.
        hs = result[0] if isinstance(result, (tuple, list)) else result
        if hs is not None and not torch.isfinite(hs).all():
            return None
        return hs
    finally:
        module.train(restore_training)

@torch.no_grad()
def _fd_jdev_score(mod, X, args, eps, k_probes):
    """Finite-difference estimate of E_v ||(J - I) v||^2 for one sublayer, or None if no probe was usable."""
    X = X.float()
    y0 = _t5sublayer_forward_only(mod, X, args)
    if y0 is None:
        return None
    acc, used = 0.0, 0
    for _ in range(max(1, k_probes)):
        v = torch.randn_like(X)
        y_eps = _t5sublayer_forward_only(mod, X + eps * v, args)
        if y_eps is None:
            continue
        jd_vec = (y_eps - y0) / eps - v
        if not torch.isfinite(jd_vec).all():
            continue
        acc += float(jd_vec.pow(2).mean().item())
        used += 1
    return acc / used if used > 0 else None


@torch.no_grad()
def compute_decoder_attention_jdev(model, self_bufs, cross_bufs, eps=1e-3, k_probes=1):
    """
    For each decoder block sublayer ℓ (self-attn, cross-attn), estimate:
      E_v ||(Jℓ - I)v||^2 ≈ E_v || (f(x + eps v) - f(x))/eps - v ||^2

    Args:
        model: object exposing ``decoder.block[idx].layer``; index 0 is used as
            the self-attention sublayer and index 1 as the cross-attention one.
        self_bufs / cross_bufs: dicts ``idx -> {"X": tensor|None, "args": dict|None}``
            filled by the capture hooks.
        eps: finite-difference step.
        k_probes: number of Gaussian probe vectors per sublayer (at least 1 is used).

    Returns:
        ``(self_scores, cross_scores)``, each ``{idx: score}``. Sublayers with
        no captured input, no cross-attention module, or only non-finite
        results are omitted.

    Every visited buffer entry is reset to None before scoring, so captured
    tensors are released on every path, including early skips.
    """
    self_scores, cross_scores = {}, {}

    for layer_pos, bufs, scores in ((0, self_bufs, self_scores), (1, cross_bufs, cross_scores)):
        for idx, buf in bufs.items():
            X, args = buf["X"], buf["args"]
            buf["X"], buf["args"] = None, None
            if X is None:
                continue
            sublayers = model.decoder.block[idx].layer
            if layer_pos == 1 and (len(sublayers) <= 1 or sublayers[1] is None):
                continue
            score = _fd_jdev_score(sublayers[layer_pos], X, args, eps, k_probes)
            if score is not None:
                scores[idx] = score

    return self_scores, cross_scores

# =========================
# 1) Hook Utilities — capture ONLY decoder attention sublayer inputs/args
# =========================

def register_decoder_attention_hooks(model):
    """
    Attach forward pre-hooks to layer[0] (self-attention) of every decoder block
    and to layer[1] (cross-attention) of every block that has one.

    Each hook stores the detached ``hidden_states`` input and the remaining
    call arguments (via _map_args_from_hook) in a per-block buffer.

    Returns:
        (hooks, self_bufs, cross_bufs): the hook handles (all self-attention
        handles first, then cross-attention ones) and the two buffer dicts
        ``idx -> {"X": ..., "args": ...}``, initialised to None.
    """
    blocks = model.decoder.block
    n_blocks = len(blocks)
    self_bufs = {i: {"X": None, "args": None} for i in range(n_blocks)}
    cross_bufs = {i: {"X": None, "args": None} for i in range(n_blocks)}

    def _capture_into(store, idx):
        # Factory so that every hook is bound to its own block index.
        def _pre_hook(module, args, kwargs):
            if "hidden_states" in kwargs:
                X = kwargs["hidden_states"]
            else:
                X = args[0] if len(args) > 0 else None
            store[idx]["X"] = None if X is None else X.detach()
            store[idx]["args"] = _map_args_from_hook(module, args, kwargs)
        return _pre_hook

    # Self-attention pre-hooks
    hooks = [
        block.layer[0].register_forward_pre_hook(_capture_into(self_bufs, i), with_kwargs=True)
        for i, block in enumerate(blocks)
    ]

    # Cross-attention pre-hooks
    for i, block in enumerate(blocks):
        if len(block.layer) > 1 and block.layer[1] is not None:
            handle = block.layer[1].register_forward_pre_hook(
                _capture_into(cross_bufs, i), with_kwargs=True
            )
            hooks.append(handle)

    return hooks, self_bufs, cross_bufs

def remove_hooks(hooks):
    """Call ``remove()`` on every handle in *hooks*, in order."""
    for handle in hooks:
        handle.remove()

# =========================
# 2) Attention-only pruning utilities
# =========================

class SkipSelfAttention(nn.Module):
    """
    Pass-through stand-in for a pruned self-attention sublayer.

    forward() ignores every argument except ``hidden_states`` and returns the
    3-tuple ``(hidden_states, None, None)``. The two trailing Nones are
    intentional: per the original author, this avoids propagating a stale
    position bias to later layers.
    """
    def forward(self, hidden_states, *args, **kwargs):
        passthrough = hidden_states
        return passthrough, None, None

class SkipCrossAttention(nn.Module):
    """
    Pass-through stand-in for a pruned cross-attention sublayer.

    forward() ignores every argument except ``hidden_states`` and returns
    ``(hidden_states, None, None)``; the None in the position-bias slot is
    meant to force recomputation downstream.
    """
    def forward(self, hidden_states, *args, **kwargs):
        passthrough = hidden_states
        return passthrough, None, None

def prune_decoder_attention(model, self_scores, cross_scores,
                            k_self=2, k_cross=2,
                            protect_first=True, protect_last=False, verbose=True):
    """
    Replace the lowest-scoring decoder attention sublayers with pass-through modules.

    Args:
        model: object exposing ``decoder.block[i].layer``.
        self_scores / cross_scores: ``{block_index: J-dev score}``.
        k_self / k_cross: how many self / cross sublayers to replace
            (values below 0 are treated as 0).
        protect_first / protect_last: exclude block 0 / the last block from pruning.
        verbose: print the pruned indices.

    Returns:
        (pruned_self, pruned_cross): lists of block indices, lowest score first.
        FFN sublayers are not touched.
    """
    blocks = model.decoder.block
    protected = set()
    if protect_first:
        protected.add(0)
    if protect_last:
        protected.add(len(blocks) - 1)

    def _ranked(scores):
        # Ascending by score; the sort is stable, so ties keep dict order.
        candidates = [(i, s) for i, s in scores.items() if i not in protected]
        return sorted(candidates, key=lambda pair: pair[1])

    self_ranked = _ranked(self_scores)
    cross_ranked = _ranked(cross_scores)

    pruned_self = []
    for i, _ in self_ranked[:max(0, k_self)]:
        blocks[i].layer[0] = SkipSelfAttention()
        pruned_self.append(i)

    pruned_cross = []
    for i, _ in cross_ranked[:max(0, k_cross)]:
        sublayers = blocks[i].layer
        if len(sublayers) > 1 and sublayers[1] is not None:
            sublayers[1] = SkipCrossAttention()
            pruned_cross.append(i)

    if verbose:
        print(f"Pruned decoder SELF-attention at layers:  {pruned_self}")
        print(f"Pruned decoder CROSS-attention at layers: {pruned_cross}")

    return pruned_self, pruned_cross

# =========================
# 3) Data & eval helpers (e-SNLI)
# =========================

def make_t5_nli_prompt(premise, hypothesis):
    """Format one premise/hypothesis pair as the text prompt fed to the model."""
    return f"nli premise: {premise} hypothesis: {hypothesis}"

def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenise a batch of NLI examples into model inputs and label ids.

    Args:
        batch: mapping with parallel lists under 'premise', 'hypothesis' and 'label'.
            Integer labels 0/1/2 are mapped to "entailment"/"neutral"/"contradiction";
            non-integer labels are passed to the tokenizer unchanged.
        tokenizer: callable tokenizer exposing ``pad_token_id``.
        max_input_length / max_target_length: fixed padded lengths.

    Returns:
        The tokenizer output for the prompts with an added "labels" entry.
        Padding positions in the labels are set to -100 so they are ignored by
        the loss; evaluate_model maps -100 back to the pad id before decoding.

    Raises:
        ValueError: if an integer label is outside 0..2. A negative value would
            otherwise silently index from the end of the label list.
    """
    label_list = ["entailment", "neutral", "contradiction"]

    def _label_text(x):
        if isinstance(x, int):
            if 0 <= x < len(label_list):
                return label_list[x]
            raise ValueError(f"label id {x} is outside the range 0..{len(label_list) - 1}")
        return x

    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    labels = [_label_text(x) for x in batch['label']]
    target = tokenizer(labels, padding="max_length", truncation=True, max_length=max_target_length)

    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in target["input_ids"]
    ]
    return model_inputs

def compute_accuracy(preds, refs):
    """Return the fraction of predictions equal to their reference.

    Pairs are formed with zip(), the denominator is len(preds), and an empty
    *preds* yields 0.0.
    """
    hits = sum(1 for predicted, expected in zip(preds, refs) if predicted == expected)
    if not preds:
        return 0.0
    return hits / len(preds)

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device, label_texts):
    """Generate predictions for every batch in *dl* and return exact-match accuracy.

    Side effects: sets ``model.config.use_cache = False`` and puts the model in
    eval mode. Predictions and references are decoded without special tokens,
    stripped and lower-cased before comparison. ``label_texts`` is accepted but
    not used.
    """
    # Disable cache to avoid KV/pos-bias drift after pruning
    model.config.use_cache = False
    model.eval()

    def _normalise(texts):
        return [t.strip().lower() for t in texts]

    preds, refs = [], []
    for batch in dl:
        generated = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_new_tokens=2,
            use_cache=False,  # <— critical for stability after pruning
        )
        # -100 marks ignored label positions; map it back to the pad id so it decodes.
        gold_ids = batch["labels"].clone()
        gold_ids[gold_ids == -100] = tokenizer.pad_token_id
        preds += _normalise(tokenizer.batch_decode(generated, skip_special_tokens=True))
        refs += _normalise(tokenizer.batch_decode(gold_ids, skip_special_tokens=True))
    return compute_accuracy(preds, refs)

# =========================
# 4) Train + collect attention J-dev, then prune attention-only
# =========================

def full_finetuning_collect_attn_jdev(train_loader, dev_loader, device, tokenizer, label_texts,
                                      jvp_eps=1e-3, jvp_k=1, jvp_every=1, epochs=6):
    """
    Stage 1: Full fine-tuning while periodically collecting FD-JVP J-dev
    for decoder self-attention and cross-attention sublayers.

    Args:
        train_loader / dev_loader: batch iterables providing "input_ids",
            "attention_mask" and "labels" tensors.
        device: device the model and batches are moved to.
        tokenizer, label_texts: forwarded to evaluate_model.
        jvp_eps: finite-difference step.
        jvp_k: probes per sublayer per measurement.
        jvp_every: measure every this many optimizer steps (minimum 1).
        epochs: number of training epochs; the linear LR decay spans all of them.

    Returns:
        (model, final_self, final_cross): the fine-tuned model and the running
        mean J-dev per decoder block for self- and cross-attention.
    """
    print("=== Stage 1: Full FT & FD-JVP (Decoder Attention only) ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    model.config.use_cache = False  # <— avoid cache while collecting FD-JVP
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scaler = GradScaler()
    # Decay over the whole run. A horizon shorter than the run drives the LR to
    # zero early and leaves the remaining epochs without any weight updates.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)

    hooks, self_bufs, cross_bufs = register_decoder_attention_hooks(model)
    self_sum, self_cnt   = defaultdict(float), defaultdict(int)
    cross_sum, cross_cnt = defaultdict(float), defaultdict(int)

    def _running_mean(total, count):
        return {i: total[i] / count[i] for i in total if count[i] > 0}

    step = 0
    try:
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                step += 1
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=batch["input_ids"].to(device),
                                attention_mask=batch["attention_mask"].to(device),
                                labels=batch["labels"].to(device))
                    loss = out.loss
                # Backward runs outside the autocast region.
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                if step % max(1, jvp_every) == 0:
                    self_scores, cross_scores = compute_decoder_attention_jdev(
                        model, self_bufs, cross_bufs, eps=jvp_eps, k_probes=jvp_k
                    )
                    for i, v in self_scores.items():
                        if math.isfinite(v):
                            self_sum[i] += v
                            self_cnt[i] += 1
                    for i, v in cross_scores.items():
                        if math.isfinite(v):
                            cross_sum[i] += v
                            cross_cnt[i] += 1

            epoch_self = _running_mean(self_sum, self_cnt)
            epoch_cross = _running_mean(cross_sum, cross_cnt)
            print(f"[Epoch {epoch+1}] Decoder SELF-attn J-dev:  {epoch_self}")
            print(f"[Epoch {epoch+1}] Decoder CROSS-attn J-dev: {epoch_cross}")
            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
    finally:
        # Detach the capture hooks even if training is interrupted.
        remove_hooks(hooks)

    final_self = _running_mean(self_sum, self_cnt)
    final_cross = _running_mean(cross_sum, cross_cnt)
    return model, final_self, final_cross

def prune_attention_and_finetune(model, train_loader, dev_loader, device,
                                 self_scores, cross_scores, tokenizer, label_texts,
                                 k_self=2, k_cross=2, epochs=5):
    """
    Stage 2: replace the lowest-J-dev decoder attention sublayers with
    pass-through modules, then fine-tune the pruned model.

    Block 0 is always protected from pruning; the last block is not.
    The linear LR decay spans all ``epochs``.

    Returns:
        (model, pruned_self, pruned_cross): the fine-tuned model and the lists
        of pruned block indices.
    """
    print("=== Stage 2: Prune LOW-Jdev Attention (self/cross) & FT ===")
    pruned_self, pruned_cross = prune_decoder_attention(
        model, self_scores, cross_scores, k_self=k_self, k_cross=k_cross,
        protect_first=True, protect_last=False, verbose=True
    )

    model.config.use_cache = False  # <— keep disabled after pruning
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # Decay over the whole run. A shorter horizon would zero the LR early and
    # turn the remaining epochs into no-ops.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            out = model(input_ids=batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        labels=batch["labels"].to(device))
            loss = out.loss
            loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
        print(f"[Prune FT Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}")
    return model, pruned_self, pruned_cross

# =========================
# 5) Main
# =========================

def main():
    """Run the two-stage pipeline: fine-tune and score, then prune and fine-tune again.

    Reads the e-SNLI JSON files from Google Drive, trains on a shuffled subset,
    and prunes ``num`` decoder self-attention sublayers (``num`` is the
    notebook-level setting defined in an earlier cell). Cross-attention is
    left intact (k_cross=0).
    """
    # e-SNLI file paths (adjust if different)
    data_files = {
        "train": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json",
        "validation": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json",
        "test": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json"
    }
    raw = load_dataset("json", data_files=data_files)
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")
    label_texts = ["entailment", "neutral", "contradiction"]

    # Smaller subsets for demo; scale up as you wish.
    # min() keeps select() valid when a split has fewer rows than requested.
    train_ds = raw["train"].shuffle(seed=42)
    train_ds = train_ds.select(range(min(10000, len(train_ds))))
    dev_ds = raw["validation"].shuffle(seed=42)
    dev_ds = dev_ds.select(range(min(2000, len(dev_ds))))

    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True, remove_columns=train_ds.column_names)
    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                       batched=True, remove_columns=dev_ds.column_names)

    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=16, shuffle=True,  collate_fn=collator)
    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stage 1: Train + collect attention-only J-dev
    model, self_scores, cross_scores = full_finetuning_collect_attn_jdev(
        train_loader, dev_loader, device, tokenizer, label_texts,
        jvp_eps=1e-3, jvp_k=1, jvp_every=1, epochs=6
    )

    # Stage 2: Prune attention ONLY (no FFN pruning) and FT
    # Prune the `num` lowest-scoring self-attn sublayers, keep cross-attn intact
    model, pruned_self, pruned_cross = prune_attention_and_finetune(
        model, train_loader, dev_loader, device,
        self_scores, cross_scores, tokenizer, label_texts,
        k_self=num, k_cross=0, epochs=5
    )

# Run the pipeline only when this code is executed directly (script or notebook cell), not when imported.
if __name__ == "__main__":
    main()


In [None]:
# Autograd-JVP pruning — ONLY decoder attention layers (self & cross)
# NaN-hardened and mask/bias-safe

# ====== Colab Drive (no-op off Colab) ======
try:
    from google.colab import drive
except ImportError:
    # Not running on Colab: nothing to mount.
    pass
else:
    # On Colab a failed mount should surface here rather than as a
    # missing-file error later on.
    drive.mount('/content/drive')

# ====== Imports ======
import inspect
import warnings
import math
from collections import defaultdict
from contextlib import contextmanager

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# =====================================================================================
# 0) Helpers for Autograd-JVP and kwargs/mask handling
# =====================================================================================

def _filter_kwargs_for_module(module, kwargs):
    """Return the subset of *kwargs* that ``module.forward`` declares as named parameters."""
    signature = inspect.signature(module.forward)
    accepted = set(signature.parameters.keys())
    accepted.discard("self")
    return dict((key, value) for key, value in kwargs.items() if key in accepted)

def _map_args_from_hook(module, args, kwargs):
    """Turn the (args, kwargs) seen by a forward pre-hook into one name -> value dict.

    Positional arguments are matched to the positional parameters of
    ``module.forward`` in order, keyword arguments are overlaid on top, and
    parameters that were not supplied fall back to their declared default.
    Parameters that were not supplied and have no default are left out, so the
    ``inspect.Parameter.empty`` sentinel never ends up in the result.

    ``hidden_states``, ``position_bias`` and ``encoder_decoder_position_bias``
    are always removed. The input is passed separately, and biases captured at
    another step or length must not be reused.
    """
    sig = inspect.signature(module.forward)
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    d = {}
    next_positional = 0
    for name, param in sig.parameters.items():
        if name == "self" or param.kind in variadic_kinds:
            continue
        if param.kind in positional_kinds and next_positional < len(args):
            # Mixed positional/keyword calls must not lose the positional part.
            d[name] = args[next_positional]
            next_positional += 1
        elif param.default is not inspect.Parameter.empty:
            d[name] = param.default

    # Explicit keyword arguments always win over defaults.
    d.update(kwargs or {})

    for dropped in ("hidden_states", "position_bias", "encoder_decoder_position_bias"):
        d.pop(dropped, None)
    return d  # filtering happens at call-time

def _build_causal_mask(B, q_len, k_len, device, dtype):
    """Build an additive causal mask of shape (B, 1, q_len, k_len).

    Entries strictly above the main diagonal (key index > query index) hold a
    large negative value; all other entries are 0. The negative value is -1e9,
    clamped to the most negative finite value of *dtype* so that low-precision
    float dtypes do not overflow to -inf.
    """
    # stable large negative instead of -inf to avoid NaNs in softmax
    fill = -1e9
    if dtype.is_floating_point:
        fill = max(fill, torch.finfo(dtype).min)
    mask = torch.zeros((B, 1, q_len, k_len), dtype=dtype, device=device)
    tri = torch.triu(torch.ones((q_len, k_len), dtype=torch.bool, device=device), diagonal=1)
    mask[:, :, tri] = fill
    return mask

def _build_zero_mask(B, q_len, k_len, device, dtype):
    """Return an all-zero tensor of shape (B, 1, q_len, k_len) on *device* with *dtype*."""
    shape = (B, 1, q_len, k_len)
    return torch.zeros(shape, device=device, dtype=dtype)

def _ensure_mask_shape(mask, q_len, k_len, device, dtype, is_self, B_fallback=1):
    """
    Return an additive attention mask usable for a (q_len x k_len) attention.

    Handling by input:
      * None or unsupported rank: rebuild a default (causal for self-attention,
        zeros for cross-attention) with batch size ``B_fallback``.
      * 2-D (B, K): sliced to ``k_len`` and broadcast over queries. A bool/int
        mask is treated as a keep-mask (non-zero = attend) and converted to
        additive form; a floating mask is used as-is. If K < k_len the default
        is rebuilt.
      * 4-D (B, H, Q, K): sliced to (q_len, k_len) when large enough. For
        cross-attention a query dimension of 1 is broadcast to ``q_len`` so
        that a key-padding mask is kept instead of being replaced by zeros.
        Anything else rebuilds the default.
    """
    def _rebuild(batch):
        # Resolved lazily so that the slicing paths do not depend on the builders.
        builder = _build_causal_mask if is_self else _build_zero_mask
        return builder(batch, q_len, k_len, device, dtype)

    if mask is None:
        return _rebuild(B_fallback)

    if mask.dim() == 2:
        B, K = mask.shape
        if K < k_len:
            # cannot slice up to desired size → rebuild
            return _rebuild(B)
        keys = mask[:, :k_len]
        if not keys.is_floating_point():
            # Keep-mask (1 = attend, 0 = ignore) → additive mask.
            fill = -1e9
            if dtype.is_floating_point:
                fill = max(fill, torch.finfo(dtype).min)
            keep = keys.to(device=device) != 0
            additive = torch.zeros(keep.shape, dtype=dtype, device=device)
            additive[~keep] = fill
            keys = additive
        return keys[:, None, None, :].expand(B, 1, q_len, k_len).contiguous().to(device=device, dtype=dtype)

    if mask.dim() == 4:
        B, _, Q, K = mask.shape
        if K >= k_len:
            if Q >= q_len:
                return mask[:, :, :q_len, :k_len].to(device=device, dtype=dtype)
            if Q == 1 and not is_self:
                # Key-padding mask shared by all queries: broadcast, do not drop.
                return mask[:, :, :, :k_len].expand(-1, -1, q_len, -1).to(device=device, dtype=dtype)
        # cannot slice up to desired size → rebuild
        return _rebuild(B)

    # unknown shape → rebuild
    return _rebuild(B_fallback)

@contextmanager
def force_math_sdp():
    """
    Force 'math' scaled-dot-product attention kernels so higher-order grads exist.
    Falls back to a no-op context if backend is unavailable (e.g., CPU).

    Only the backend selection is best-effort. Exceptions raised by the code
    running inside the ``with`` block propagate unchanged. The generator yields
    exactly once; catching the body's exception and yielding again would make
    contextlib replace it with a RuntimeError.
    """
    from contextlib import ExitStack

    with ExitStack() as stack:
        try:
            stack.enter_context(
                torch.backends.cuda.sdp_kernel(enable_flash=False,
                                               enable_mem_efficient=False,
                                               enable_math=True)
            )
        except Exception:
            # Kernel selection is not available here: continue with defaults.
            pass
        yield

def _make_attn_callable(module, argdict):
    """
    Build a function f(x) -> hidden_states for a single attention sublayer (self/cross).

    Masks are rebuilt for the length of x on every call and captured position
    biases are never passed. A module whose forward() takes
    ``key_value_states`` is treated as cross-attention.

    For cross-attention without captured key/value states, f raises ValueError.
    Returning the input unchanged would score the sublayer as a perfect
    identity (J-dev of 0) and make it the first pruning candidate without any
    measurement.
    """
    raw = dict(argdict) if argdict is not None else {}

    sig = inspect.signature(module.forward)
    params = set(sig.parameters.keys())

    is_cross = "key_value_states" in params

    # Constant kwargs, packed once (hidden_states is supplied per call).
    const_kwargs = {}

    # KV states (for cross attention)
    captured_kv = None
    if is_cross:
        kv = raw.get("key_value_states", raw.get("encoder_hidden_states", None))
        captured_kv = None if kv is None else kv.float()
        const_kwargs["key_value_states"] = captured_kv

    # Deterministic flags
    for flag in ("use_cache", "output_attentions", "return_dict"):
        if flag in params:
            const_kwargs[flag] = False

    # The attention mask is rebuilt at call time (depends on x length)
    attn_mask_key = "attention_mask" if "attention_mask" in params else ("mask" if "mask" in params else None)
    captured_mask = raw.get(attn_mask_key, raw.get("mask", None)) if attn_mask_key is not None else None

    def _f(x):
        if is_cross and captured_kv is None:
            raise ValueError("cross-attention probe has no captured key/value states")

        x = x.float()
        q_len = x.size(1)
        k_len = captured_kv.size(1) if is_cross else q_len
        B = x.size(0)

        kwargs = {"hidden_states": x}

        if attn_mask_key is not None:
            kwargs[attn_mask_key] = _ensure_mask_shape(
                captured_mask, q_len, k_len, x.device, x.dtype, is_self=(not is_cross), B_fallback=B
            )

        # cache_position / query_length
        if "cache_position" in params:
            kwargs["cache_position"] = torch.arange(q_len, dtype=torch.long, device=x.device)
        elif "query_length" in params:
            kwargs["query_length"] = q_len

        kwargs.update(const_kwargs)
        kwargs = _filter_kwargs_for_module(module, kwargs)

        with torch.cuda.amp.autocast(enabled=False):
            out = module(**kwargs)

        # Tuple/list outputs carry the hidden states in slot 0.
        return out[0] if isinstance(out, (tuple, list)) else out

    return _f

def compute_decoder_attention_jdev_autograd(model, self_bufs, cross_bufs, k_probes=1, rademacher=True):
    """
    For each decoder attention sublayer ℓ, estimate E_v ||(Jℓ - I)v||^2 with
    torch.autograd.functional.jvp.

    Probe vectors are ±1 (Rademacher) when ``rademacher`` is true, otherwise
    standard normal. Probes that raise or produce non-finite values are
    skipped, and a sublayer whose sanity forward pass raises is skipped
    entirely. Each buffer entry is reset to None as soon as it is read.

    Returns:
        (self_scores, cross_scores): ``{block_index: mean squared deviation}``
        for sublayers with at least one usable probe.
    """
    from torch.autograd.functional import jvp as autograd_jvp

    def _draw_probe(like):
        if rademacher:
            return torch.empty_like(like).bernoulli_(0.5).mul_(2).sub_(1)
        return torch.randn_like(like)

    def _score_sublayer(mod, X, args):
        # Score one sublayer in eval mode; the previous mode is always restored.
        prev_mode = mod.training
        mod.eval()
        try:
            fn = _make_attn_callable(mod, args)
            x0 = X.detach().requires_grad_(True).float()

            # Sanity forward: give up on this sublayer if it cannot run at all.
            try:
                with force_math_sdp():
                    fn(x0)
            except Exception:
                return None

            total, n_ok = 0.0, 0
            with force_math_sdp():
                for _ in range(max(1, k_probes)):
                    v = _draw_probe(x0)
                    try:
                        _, Jv = autograd_jvp(fn, (x0,), (v,), create_graph=False, strict=True)
                    except Exception:
                        continue
                    if Jv is None or not torch.isfinite(Jv).all():
                        continue
                    jd_vec = Jv - v
                    if not torch.isfinite(jd_vec).all():
                        continue
                    total += float(jd_vec.pow(2).mean().item())
                    n_ok += 1
            return total / n_ok if n_ok > 0 else None
        finally:
            mod.train(prev_mode)

    self_scores, cross_scores = {}, {}

    # Self-attention J-dev (layer[0] of each decoder block)
    for idx, buf in self_bufs.items():
        X, args = buf["X"], buf["args"]
        buf["X"], buf["args"] = None, None  # clear ASAP
        if X is None:
            continue
        score = _score_sublayer(model.decoder.block[idx].layer[0], X, args)
        if score is not None:
            self_scores[idx] = score

    # Cross-attention J-dev (layer[1], when the block has one)
    for idx, buf in cross_bufs.items():
        X, args = buf["X"], buf["args"]
        buf["X"], buf["args"] = None, None
        if X is None:
            continue
        sublayers = model.decoder.block[idx].layer
        if len(sublayers) <= 1 or sublayers[1] is None:
            continue
        score = _score_sublayer(sublayers[1], X, args)
        if score is not None:
            cross_scores[idx] = score

    return self_scores, cross_scores

# =====================================================================================
# 1) Hook Utilities — capture ONLY decoder attention sublayer inputs/args
# =====================================================================================

def register_decoder_attention_hooks(model):
    """
    Attach forward pre-hooks to the attention sublayers of every decoder block.

    layer[0] (self-attention) of each block is always hooked. layer[1]
    (cross-attention) is hooked only when the block has more than one sublayer
    and that slot is not None. Each hook stores the detached ``hidden_states``
    input under "X" and the output of ``_map_args_from_hook`` under "args".

    Returns:
        (hooks, self_bufs, cross_bufs): the hook handles (all self-attention
        handles first, then cross-attention ones) and two dicts keyed by decoder
        block index whose values are ``{"X": ..., "args": ...}`` buffers.
    """
    blocks = model.decoder.block
    self_bufs, cross_bufs = {}, {}
    for layer_idx in range(len(blocks)):
        self_bufs[layer_idx] = {"X": None, "args": None}
        cross_bufs[layer_idx] = {"X": None, "args": None}

    def _make_capture_hook(store, layer_idx):
        # Build a pre-hook that writes into store[layer_idx] on every call.
        def _capture(module, args, kwargs):
            if "hidden_states" in kwargs:
                hidden = kwargs["hidden_states"]
            elif len(args) > 0:
                hidden = args[0]
            else:
                hidden = None
            store[layer_idx]["X"] = hidden.detach() if hidden is not None else None
            store[layer_idx]["args"] = _map_args_from_hook(module, args, kwargs)
        return _capture

    # Self-attention handles are registered first for all blocks ...
    hooks = [
        block.layer[0].register_forward_pre_hook(
            _make_capture_hook(self_bufs, layer_idx), with_kwargs=True
        )
        for layer_idx, block in enumerate(blocks)
    ]

    # ... followed by cross-attention handles where that sublayer exists.
    for layer_idx, block in enumerate(blocks):
        if len(block.layer) <= 1 or block.layer[1] is None:
            continue
        hooks.append(
            block.layer[1].register_forward_pre_hook(
                _make_capture_hook(cross_bufs, layer_idx), with_kwargs=True
            )
        )

    return hooks, self_bufs, cross_bufs

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` by calling its ``remove()`` method, in order."""
    for handle in hooks:
        handle.remove()

# =====================================================================================
# 2) Attention-only pruning utilities
# =====================================================================================

class SkipSelfAttention(nn.Module):
    """
    Pass-through stand-in for a pruned T5LayerSelfAttention.

    The incoming hidden states are handed back unchanged as the first element
    of a 3-tuple. The remaining two slots are None so that no stale position
    bias is propagated.
    """
    def forward(self, hidden_states, *args, **kwargs):
        # Any extra positional/keyword arguments are accepted and ignored.
        passthrough = (hidden_states,) + (None,) * 2
        return passthrough

class SkipCrossAttention(nn.Module):
    """
    Pass-through stand-in for a pruned T5LayerCrossAttention.

    Returns the hidden states untouched, followed by two None entries; the
    position-bias slot is None so that it is recomputed downstream.
    """
    def forward(self, hidden_states, *args, **kwargs):
        # Any extra positional/keyword arguments are accepted and ignored.
        passthrough = (hidden_states,) + (None,) * 2
        return passthrough

def prune_decoder_attention(model, self_scores, cross_scores,
                            k_self=2, k_cross=2,
                            protect_first=True, protect_last=False, verbose=True):
    """
    Replace the lowest-scoring decoder attention sublayers with identity skips.

    Candidates are ranked by ascending score (lowest J-dev first). Up to
    ``k_self`` self-attention sublayers (layer[0]) and up to ``k_cross``
    cross-attention sublayers (layer[1]) are swapped for ``SkipSelfAttention`` /
    ``SkipCrossAttention``. Feed-forward sublayers are never touched.

    Args:
        model: object exposing ``decoder.block`` with per-block ``layer`` lists.
        self_scores / cross_scores: dicts mapping decoder layer index -> score.
        k_self / k_cross: how many sublayers of each kind to prune (<= 0 prunes none).
        protect_first: never prune layer index 0.
        protect_last: never prune the last decoder layer.
        verbose: print the pruned layer indices.

    Returns:
        (pruned_self, pruned_cross): lists of pruned layer indices.
    """
    blocks = model.decoder.block

    protected = set()
    if protect_first:
        protected.add(0)
    if protect_last:
        protected.add(len(blocks) - 1)

    def _ranked(scores):
        # Drop protected layers, then order by ascending score (stable sort).
        candidates = [(idx, score) for idx, score in scores.items() if idx not in protected]
        return sorted(candidates, key=lambda pair: pair[1])

    self_ranked = _ranked(self_scores)
    cross_ranked = _ranked(cross_scores)

    pruned_self = []
    for idx, _score in self_ranked[:max(0, k_self)]:
        blocks[idx].layer[0] = SkipSelfAttention()
        pruned_self.append(idx)

    pruned_cross = []
    for idx, _score in cross_ranked[:max(0, k_cross)]:
        sublayers = blocks[idx].layer
        # Skip blocks that have no cross-attention slot to replace.
        if len(sublayers) <= 1 or sublayers[1] is None:
            continue
        sublayers[1] = SkipCrossAttention()
        pruned_cross.append(idx)

    if verbose:
        print(f"Pruned decoder SELF-attention at layers:  {pruned_self}")
        print(f"Pruned decoder CROSS-attention at layers: {pruned_cross}")

    return pruned_self, pruned_cross

# =====================================================================================
# 3) Data & eval helpers (e-SNLI)
# =====================================================================================

def make_t5_nli_prompt(premise, hypothesis):
    """Build the text-to-text NLI input string for one premise/hypothesis pair."""
    template = "nli premise: {} hypothesis: {}"
    return template.format(premise, hypothesis)

def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):
    """
    Tokenize one batch of NLI examples into model inputs and label ids.

    Args:
        batch: mapping with 'premise', 'hypothesis' and 'label' columns.
        tokenizer: callable tokenizer exposing ``pad_token_id``.
        max_input_length: fixed length the prompts are padded/truncated to.
        max_target_length: fixed length the label texts are padded/truncated to.

    Returns:
        The tokenizer output for the prompts with an added "labels" entry.
        Padded label positions are set to -100 so the loss ignores them.
    """
    label_list = ["entailment", "neutral", "contradiction"]

    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)

    labels = []
    for x in batch['label']:
        if isinstance(x, int) and 0 <= x < len(label_list):
            labels.append(label_list[x])
        else:
            # Out-of-range ids (e.g. -1) must not wrap around to the last class
            # through negative indexing; keep them as text instead.
            labels.append(x if isinstance(x, str) else str(x))

    target = tokenizer(labels, padding="max_length", truncation=True, max_length=max_target_length)

    # Mask padded target positions so cross-entropy does not train on padding.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok_id if tok_id != pad_id else -100 for tok_id in seq]
        for seq in target["input_ids"]
    ]
    return model_inputs

def compute_accuracy(preds, refs):
    """
    Exact-match accuracy: matching positions of ``preds``/``refs`` (paired with
    zip) divided by ``len(preds)``. Returns 0.0 when ``preds`` is empty.
    """
    if not preds:
        return 0.0
    hits = sum(1 for pred, ref in zip(preds, refs) if pred == ref)
    return hits / len(preds)

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device, label_texts):
    """
    Generate predictions for every batch in ``dl`` and return exact-match accuracy.

    Args:
        model: seq2seq model exposing ``config``, ``eval()`` and ``generate()``.
        dl: iterable of batches with "input_ids", "attention_mask" and "labels".
        tokenizer: tokenizer used to size the generation budget and decode ids.
        device: device the input tensors are moved to.
        label_texts: the possible label strings; the longest one (in token ids,
            as returned by ``tokenizer``) sets ``max_new_tokens``.

    Returns:
        Accuracy over stripped, lower-cased decoded predictions vs. references.

    Side effects:
        Sets ``model.config.use_cache = False`` and leaves the model in eval mode.
    """
    # Disable cache to avoid KV/pos-bias drift after pruning
    model.config.use_cache = False
    model.eval()

    # A fixed budget of 2 new tokens truncates any label that tokenizes to more
    # ids than that. Size the budget from the label set, never below 2.
    label_lengths = [len(tokenizer(text)["input_ids"]) for text in (label_texts or [])]
    max_new_tokens = max([2] + label_lengths)

    preds, refs = [], []
    for batch in dl:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            use_cache=False
        )
        pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # -100 marks ignored label positions; map it back to pad before decoding.
        label_ids = batch["labels"].clone()
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        preds.extend([p.strip().lower() for p in pred_texts])
        refs.extend([l.strip().lower() for l in ref_texts])
    return compute_accuracy(preds, refs)

# =====================================================================================
# 4) Train + collect attention J-dev (Autograd JVP), then prune attention-only
# =====================================================================================

def full_finetuning_collect_attn_jdev(train_loader, dev_loader, device, tokenizer, label_texts,
                                      jvp_k=1, jvp_every=1, epochs=6, rademacher=True):
    """
    Stage 1: Full fine-tuning while periodically collecting Autograd-JVP J-dev
    for decoder self-attention and cross-attention sublayers.

    Args:
        train_loader: iterable of batches with "input_ids", "attention_mask", "labels".
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the model and batch tensors are moved to.
        tokenizer: forwarded to ``evaluate_model``.
        label_texts: forwarded to ``evaluate_model``.
        jvp_k: number of JVP probes per sublayer per collection step.
        jvp_every: collect J-dev every this many steps (values < 1 act as 1).
        epochs: number of training epochs; also sizes the LR schedule.
        rademacher: forwarded to ``compute_decoder_attention_jdev_autograd``.

    Returns:
        (model, final_self, final_cross): the trained model and, per decoder
        layer index, the mean of all finite J-dev scores collected.
    """
    print("=== Stage 1: Full FT & Autograd-JVP (Decoder Attention only) ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    model.config.use_cache = False  # avoid KV cache during JVP collection
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scaler = GradScaler()
    # The linear decay must span every epoch that is actually run; a schedule
    # sized for fewer epochs drives the LR to 0 and the rest train nothing.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)

    hooks, self_bufs, cross_bufs = register_decoder_attention_hooks(model)
    self_sum, self_cnt   = defaultdict(float), defaultdict(int)
    cross_sum, cross_cnt = defaultdict(float), defaultdict(int)

    def _running_mean(total, count):
        # Mean score per layer over all finite samples gathered so far.
        return {i: total[i] / count[i] for i in total if count[i] > 0}

    interval = max(1, jvp_every)
    step = 0
    try:
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                step += 1
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=batch["input_ids"].to(device),
                                attention_mask=batch["attention_mask"].to(device),
                                labels=batch["labels"].to(device))
                    loss = out.loss
                # Backward runs outside the autocast region, as PyTorch AMP recommends.
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                if step % interval == 0:
                    self_scores, cross_scores = compute_decoder_attention_jdev_autograd(
                        model, self_bufs, cross_bufs, k_probes=jvp_k, rademacher=rademacher
                    )
                    for i, v in self_scores.items():
                        if math.isfinite(v):
                            self_sum[i] += v
                            self_cnt[i] += 1
                    for i, v in cross_scores.items():
                        if math.isfinite(v):
                            cross_sum[i] += v
                            cross_cnt[i] += 1

            print(f"[Epoch {epoch+1}] Decoder SELF-attn J-dev:  {_running_mean(self_sum, self_cnt)}")
            print(f"[Epoch {epoch+1}] Decoder CROSS-attn J-dev: {_running_mean(cross_sum, cross_cnt)}")
            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
    finally:
        # Always detach the capture hooks, even if training is interrupted.
        remove_hooks(hooks)

    final_self  = _running_mean(self_sum, self_cnt)
    final_cross = _running_mean(cross_sum, cross_cnt)
    return model, final_self, final_cross

def prune_attention_and_finetune(model, train_loader, dev_loader, device,
                                 self_scores, cross_scores, tokenizer, label_texts,
                                 k_self=2, k_cross=2, epochs=5):
    """
    Stage 2: prune the lowest-J-dev decoder attention sublayers, then fine-tune.

    Args:
        model: the Stage 1 model; it is modified in place.
        train_loader: iterable of batches with "input_ids", "attention_mask", "labels".
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the batch tensors are moved to.
        self_scores / cross_scores: per-layer J-dev scores used for ranking.
        tokenizer, label_texts: forwarded to ``evaluate_model``.
        k_self / k_cross: number of self-/cross-attention sublayers to prune.
        epochs: number of fine-tuning epochs; also sizes the LR schedule.

    Returns:
        (model, pruned_self, pruned_cross)
    """
    print("=== Stage 2: Prune LOW-Jdev Attention (self/cross) & FT ===")
    pruned_self, pruned_cross = prune_decoder_attention(
        model, self_scores, cross_scores, k_self=k_self, k_cross=k_cross,
        protect_first=True, protect_last=False, verbose=True
    )

    model.config.use_cache = False  # keep disabled after pruning
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # Size the decay for all epochs that run; a shorter schedule reaches LR 0
    # early and the remaining epochs would not update the model.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            out = model(input_ids=batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        labels=batch["labels"].to(device))
            loss = out.loss
            loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
        print(f"[Prune FT Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}")
    return model, pruned_self, pruned_cross

# =====================================================================================
# 5) Main
# =====================================================================================

def main():
    """
    Run the e-SNLI pipeline: load data, fine-tune while collecting attention
    J-dev scores (Stage 1), then prune decoder self-attention and fine-tune
    again (Stage 2).
    """
    # e-SNLI file paths (adjust if different)
    data_files = {
        "train": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json",
        "validation": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json",
        "test": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json"
    }
    raw = load_dataset("json", data_files=data_files)
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")
    label_texts = ["entailment", "neutral", "contradiction"]

    # Smaller subsets for demo; scale up as you wish. The subset size is capped
    # at the split size so select() cannot index past the end of a small split.
    train_split = raw["train"].shuffle(seed=42)
    dev_split   = raw["validation"].shuffle(seed=42)
    train_ds = train_split.select(range(min(10000, len(train_split))))
    dev_ds   = dev_split.select(range(min(2000, len(dev_split))))

    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True, remove_columns=train_ds.column_names)
    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                       batched=True, remove_columns=dev_ds.column_names)

    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=16, shuffle=True,  collate_fn=collator)
    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stage 1: Train + collect attention-only J-dev (Autograd JVP)
    model, self_scores, cross_scores = full_finetuning_collect_attn_jdev(
        train_loader, dev_loader, device, tokenizer, label_texts,
        jvp_k=2, jvp_every=1, epochs=6, rademacher=True
    )

    # Stage 2: Prune attention ONLY (no FFN pruning) and FT
    # Prunes the `num` lowest-J-dev self-attention sublayers (`num` is a
    # module-level variable set outside this function); k_cross=0 keeps
    # cross-attn intact.
    model, pruned_self, pruned_cross = prune_attention_and_finetune(
        model, train_loader, dev_loader, device,
        self_scores, cross_scores, tokenizer, label_texts,
        k_self=num, k_cross=0, epochs=5
    )

# Entry point: run the e-SNLI pipeline when this module is the main program.
if __name__ == "__main__":
    main()


In [None]:
# Prune the decoder (ATTENTION-ONLY) with Finite-Difference JVP — CQA

# ====== Colab Drive (no-op off Colab) ======
# Best-effort Google Drive mount: any failure (including the import failing
# outside Colab) is deliberately ignored so the rest of the cell still runs.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    pass

# ====== Imports ======
import inspect
import warnings
import math
from collections import defaultdict

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler

# Suppress all UserWarning and FutureWarning messages from here on.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# =========================
# 0) CQA data
# =========================
# JSON files for the CQA train and test splits (paths on the mounted Drive).
data_files = {
    "train": "/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json",
    "test":  "/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json"
}
# Load both splits; the result is indexed below by "train" and "test".
dataset = load_dataset("json", data_files=data_files)

def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):
    """
    Tokenize one batch of CQA examples into model inputs and label ids.

    Args:
        batch: mapping with 'question', 'choices' and 'answer' columns, and
            optionally 'abstractive_explanation'.
        tokenizer: callable tokenizer exposing ``pad_token_id``.
        max_input_length: fixed length the prompts are padded/truncated to.
        max_target_length: fixed length the answers are padded/truncated to.
        use_cot: when True and the batch has 'abstractive_explanation', append
            it to the prompt as a rationale.

    Returns:
        The tokenizer output for the prompts with an added "labels" entry.
        Padded label positions are set to -100 so the loss ignores them.
    """
    if use_cot and 'abstractive_explanation' in batch:
        inputs = [
            f"question: {q} choices: {', '.join(choices)} rationale: {exp}"
            for q, choices, exp in zip(batch['question'], batch['choices'], batch['abstractive_explanation'])
        ]
    else:
        inputs = [
            f"question: {q} choices: {', '.join(choices)}"
            for q, choices in zip(batch['question'], batch['choices'])
        ]
    targets = [str(ans).strip() for ans in batch['answer']]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(targets, padding="max_length", truncation=True, max_length=max_target_length)

    # Mask padded target positions so cross-entropy does not train on padding.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok_id if tok_id != pad_id else -100 for tok_id in seq]
        for seq in target["input_ids"]
    ]
    return model_inputs

# Tokenizer shared by preprocessing, collation and evaluation below.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
# When True, training prompts also include the 'abstractive_explanation' text
# (only if that column is present in the batch).
USE_COT = False

# Tokenize both splits and drop the raw columns. Note that `dev` is built from
# the "test" split and never includes rationales (use_cot=False).
train = dataset["train"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=USE_COT),
                             batched=True, remove_columns=dataset["train"].column_names)
dev   = dataset["test"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=False),
                             batched=True, remove_columns=dataset["test"].column_names)

# Collator configured for max_length padding at 128; only the training loader shuffles.
collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=32, shuffle=True,  collate_fn=collator)
dev_loader   = DataLoader(dev,   batch_size=32, shuffle=False, collate_fn=collator)

# =========================
# 1) Eval helpers
# =========================
def compute_accuracy(preds, refs):
    """
    Exact-match accuracy: matching positions of ``preds``/``refs`` (paired with
    zip) divided by ``len(preds)``. Returns 0.0 when ``preds`` is empty.
    """
    total = len(preds)
    if total <= 0:
        return 0.0
    matches = [pred == ref for pred, ref in zip(preds, refs)]
    return matches.count(True) / total

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device, max_new_tokens=4):
    """
    Generate predictions for every batch in ``dl`` and return exact-match accuracy.

    Args:
        model: seq2seq model exposing ``config``, ``eval()`` and ``generate()``.
        dl: iterable of batches with "input_ids", "attention_mask" and "labels".
        tokenizer: tokenizer used to decode generated ids and label ids.
        device: device the input tensors are moved to.
        max_new_tokens: generation budget per example. The default (4) matches
            the previously hard-coded value; raise it when answers can
            tokenize to more ids than that, otherwise they are truncated and
            counted as wrong.

    Returns:
        Accuracy over stripped, lower-cased decoded predictions vs. references.

    Side effects:
        Sets ``model.config.use_cache = False`` and leaves the model in eval mode.
    """
    # disable cache to avoid KV/pos-bias drift after pruning
    model.config.use_cache = False
    model.eval()
    preds, refs = [], []
    for batch in dl:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            use_cache=False,          # critical after pruning
        )
        pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # -100 marks ignored label positions; map it back to pad before decoding.
        label_ids = batch["labels"].clone()
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        preds.extend([p.strip().lower() for p in pred_texts])
        refs.extend([l.strip().lower() for l in ref_texts])
    return compute_accuracy(preds, refs)

# =========================
# 2) FD-JVP helpers (robust)
# =========================
def _filter_kwargs_for_module(module, kwargs):
    sig = inspect.signature(module.forward)
    allowed = set(sig.parameters.keys()) - {"self"}
    return {k: v for k, v in kwargs.items() if k in allowed}

def _map_args_from_hook(module, args, kwargs):
    if kwargs and len(kwargs) > 0:
        d = dict(kwargs)
    else:
        sig = inspect.signature(module.forward)
        names = [n for n in sig.parameters.keys() if n != "self"]
        d = {}
        for i, name in enumerate(names):
            if i < len(args):
                d[name] = args[i]
            else:
                d[name] = sig.parameters[name].default
    d.pop("hidden_states", None)
    return d  # filter at call time

def _build_causal_mask(B, q_len, k_len, device, dtype):
    # use stable large negative (not -inf) to avoid NaNs in softmax
    mask = torch.zeros((B, 1, q_len, k_len), dtype=dtype, device=device)
    tri = torch.triu(torch.ones((q_len, k_len), dtype=torch.bool, device=device), diagonal=1)
    mask[:, :, tri] = -1e9
    return mask

def _ensure_mask_shape(mask, q_len, k_len, device, dtype, is_self):
    """
    Return a (B,1,q_len,k_len) mask; if input mask is missing or wrong-sized,
    rebuild a safe default (causal for self, zeros for cross).
    """
    if mask is None:
        B = 1
        return _build_causal_mask(B, q_len, k_len, device, dtype) if is_self \
               else torch.zeros((B,1,q_len,k_len), dtype=dtype, device=device)
    if mask.dim() == 2:
        B = mask.size(0)
        return mask[:, None, None, :k_len].expand(B,1,q_len,k_len).contiguous().to(device=device, dtype=dtype)
    if mask.dim() == 4:
        B, _, Q, K = mask.shape
        if Q >= q_len and K >= k_len:
            return mask[:, :, :q_len, :k_len].to(device=device, dtype=dtype)
        return _build_causal_mask(B, q_len, k_len, device, dtype) if is_self \
               else torch.zeros((B,1,q_len,k_len), dtype=dtype, device=device)
    B = 1
    return _build_causal_mask(B, q_len, k_len, device, dtype) if is_self \
           else torch.zeros((B,1,q_len,k_len), dtype=dtype, device=device)

@torch.no_grad()
def _t5sublayer_forward_only(module, X, argdict):
    """
    Deterministic forward for a single T5 sublayer (Self/Cross Attention).
    Returns hidden_states (tensor). fp32, autocast disabled. Provides safe masks
    and cache_position/query_length when missing.

    Args:
        module: the attention sublayer to run; its ``forward`` signature decides
            which keyword arguments are built and passed.
        X: input hidden states; sequence length is read from ``X.size(1)``.
        argdict: captured call arguments (name -> value) or None.

    Returns:
        The first element of the module output (or the output itself when it is
        not a tuple/list). Returns None when ``X`` is None, when a
        cross-attention module has no key/value states in ``argdict``, or when
        the result is None or contains non-finite values.

    The module is switched to eval mode for the call and its previous training
    flag is restored in the ``finally`` clause.
    """
    if X is None:
        return None
    was_training = module.training
    module.eval()
    Xf = X.float()
    try:
        # Work on a copy so the caller's captured arguments are not mutated.
        raw = dict(argdict) if argdict is not None else {}
        raw.pop("position_bias", None)  # force recompute

        sig = inspect.signature(module.forward)
        params = set(sig.parameters.keys())

        kwargs = {"hidden_states": Xf}

        # Cross-attention?
        # Detected by the presence of a `key_value_states` parameter on forward.
        is_cross = "key_value_states" in params
        kv = None
        if is_cross:
            kv = raw.get("key_value_states", raw.get("encoder_hidden_states", None))
            if kv is None:
                return None
            kwargs["key_value_states"] = kv.float()

        # Key length follows the key/value states for cross-attention, else the query length.
        q_len = Xf.size(1)
        k_len = kv.size(1) if (is_cross and kv is not None) else q_len

        # The mask parameter is called `attention_mask` or `mask` depending on the signature.
        mask_key = "attention_mask" if "attention_mask" in params else ("mask" if "mask" in params else None)
        src_mask = raw.get(mask_key, raw.get("mask", None)) if mask_key is not None else None
        # NOTE(review): when mask_key is None the mask is stored under the key None;
        # the _filter_kwargs_for_module call below then drops it because None is
        # not a parameter name.
        kwargs[mask_key] = _ensure_mask_shape(src_mask, q_len, k_len, Xf.device, Xf.dtype, is_self=not is_cross)

        # cache_position / query_length
        if "cache_position" in params:
            cp = raw.get("cache_position", None)
            if cp is None:
                cp = torch.arange(q_len, dtype=torch.long, device=Xf.device)
            kwargs["cache_position"] = cp
        if "query_length" in params and "cache_position" not in params:
            kwargs["query_length"] = q_len

        if "use_cache" in params:
            kwargs["use_cache"] = False
        if "output_attentions" in params:
            kwargs["output_attentions"] = False

        # Keep only the keywords that module.forward actually declares.
        kwargs = _filter_kwargs_for_module(module, kwargs)

        with torch.cuda.amp.autocast(enabled=False):
            out = module(**kwargs)

        hs = out[0] if isinstance(out, (tuple, list)) else out
        # Reject missing or non-finite outputs so callers can skip this probe.
        if hs is None or not torch.isfinite(hs).all():
            return None
        return hs
    finally:
        module.train(was_training)

def _fd_jdev_score(mod, X, args, eps, k_probes):
    """Finite-difference J-dev for one sublayer, or None when no probe is usable."""
    y0 = _t5sublayer_forward_only(mod, X, args)
    if y0 is None:
        return None
    total, n_ok = 0.0, 0
    for _ in range(max(1, k_probes)):
        v = torch.randn_like(X)
        y_eps = _t5sublayer_forward_only(mod, X + eps * v, args)
        if y_eps is None:
            continue
        deviation = (y_eps - y0) / eps - v
        if not torch.isfinite(deviation).all():
            continue
        total += float(deviation.pow(2).mean().item())
        n_ok += 1
    return total / n_ok if n_ok > 0 else None


@torch.no_grad()
def compute_decoder_attention_jdev(model, self_bufs, cross_bufs, eps=1e-3, k_probes=1):
    """
    Estimate, for each decoder attention sublayer with a captured input,
      E_v ||(J - I)v||^2  ~  E_v || (f(x + eps v) - f(x))/eps - v ||^2
    using ``k_probes`` Gaussian probe directions (at least one).

    Non-finite or failed probes are skipped; a layer with no usable probe gets
    no score. Buffers that were evaluated are reset to None afterwards.

    Returns:
        (self_scores, cross_scores): dicts mapping decoder layer index -> score.
    """
    self_scores, cross_scores = {}, {}

    # Self-attention sublayers (layer[0])
    for idx, buf in self_bufs.items():
        captured, call_args = buf["X"], buf["args"]
        if captured is None:
            continue
        sublayer = model.decoder.block[idx].layer[0]
        score = _fd_jdev_score(sublayer, captured.float(), call_args, eps, k_probes)
        buf["X"], buf["args"] = None, None
        if score is not None:
            self_scores[idx] = score

    # Cross-attention sublayers (layer[1]), where present
    for idx, buf in cross_bufs.items():
        captured, call_args = buf["X"], buf["args"]
        if captured is None:
            continue
        sublayers = model.decoder.block[idx].layer
        if not (len(sublayers) > 1 and sublayers[1] is not None):
            continue
        score = _fd_jdev_score(sublayers[1], captured.float(), call_args, eps, k_probes)
        buf["X"], buf["args"] = None, None
        if score is not None:
            cross_scores[idx] = score

    return self_scores, cross_scores

# =========================
# 3) Hooks to capture ONLY decoder attention inputs
# =========================
def register_decoder_attention_hooks(model):
    """
    Attach forward pre-hooks to the attention sublayers of every decoder block.

    layer[0] (self-attention) is always hooked; layer[1] (cross-attention) only
    when the block has more than one sublayer and that slot is not None. Each
    hook stores the detached ``hidden_states`` input under "X" and the output
    of ``_map_args_from_hook`` under "args".

    Returns:
        (hooks, self_bufs, cross_bufs): hook handles (self-attention handles
        first, then cross-attention ones) and two dicts keyed by decoder block
        index with ``{"X": ..., "args": ...}`` buffers.
    """
    blocks = model.decoder.block
    layer_ids = range(len(blocks))
    self_bufs = {layer_idx: {"X": None, "args": None} for layer_idx in layer_ids}
    cross_bufs = {layer_idx: {"X": None, "args": None} for layer_idx in layer_ids}

    def _make_capture_hook(store, layer_idx):
        # Build a pre-hook that writes into store[layer_idx] on every call.
        def _capture(module, args, kwargs):
            if "hidden_states" in kwargs:
                hidden = kwargs["hidden_states"]
            elif len(args) > 0:
                hidden = args[0]
            else:
                hidden = None
            store[layer_idx]["X"] = hidden.detach() if hidden is not None else None
            store[layer_idx]["args"] = _map_args_from_hook(module, args, kwargs)
        return _capture

    hooks = []
    for layer_idx, block in enumerate(blocks):
        capture = _make_capture_hook(self_bufs, layer_idx)
        hooks.append(block.layer[0].register_forward_pre_hook(capture, with_kwargs=True))

    for layer_idx, block in enumerate(blocks):
        has_cross = len(block.layer) > 1 and block.layer[1] is not None
        if not has_cross:
            continue
        capture = _make_capture_hook(cross_bufs, layer_idx)
        hooks.append(block.layer[1].register_forward_pre_hook(capture, with_kwargs=True))

    return hooks, self_bufs, cross_bufs

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` by calling its ``remove()`` method, in order."""
    for handle in hooks:
        handle.remove()

# =========================
# 4) Attention-only pruning utilities
# =========================
class SkipSelfAttention(nn.Module):
    """
    Pass-through stand-in for a pruned T5LayerSelfAttention.

    Hands the hidden states back unchanged as the first element of a 3-tuple;
    the other two slots are None so no stale position_bias is reused.
    """
    def forward(self, hidden_states, *args, **kwargs):
        # Any extra positional/keyword arguments are accepted and ignored.
        passthrough = (hidden_states,) + (None,) * 2
        return passthrough

class SkipCrossAttention(nn.Module):
    """
    Pass-through stand-in for a pruned T5LayerCrossAttention.

    Returns the hidden states untouched, followed by two None entries; the
    position-bias slot is None so that it is recomputed downstream.
    """
    def forward(self, hidden_states, *args, **kwargs):
        # Any extra positional/keyword arguments are accepted and ignored.
        passthrough = (hidden_states,) + (None,) * 2
        return passthrough

def prune_decoder_attention(model, self_scores, cross_scores,
                            k_self=2, k_cross=2,
                            protect_first=True, protect_last=False, verbose=True):
    """
    Replace the lowest-scoring decoder attention sublayers with identity skips.

    Candidates are ranked by ascending score (lowest J-dev first). Up to
    ``k_self`` self-attention sublayers (layer[0]) and up to ``k_cross``
    cross-attention sublayers (layer[1]) are swapped for ``SkipSelfAttention`` /
    ``SkipCrossAttention``. Layer 0 is excluded when ``protect_first`` is set
    and the last layer when ``protect_last`` is set.

    Returns:
        (pruned_self, pruned_cross): lists of pruned layer indices.
    """
    blocks = model.decoder.block

    protected = set()
    if protect_first:
        protected.add(0)
    if protect_last:
        protected.add(len(blocks) - 1)

    def _ranked(scores):
        # Drop protected layers, then order by ascending score (stable sort).
        candidates = [(idx, score) for idx, score in scores.items() if idx not in protected]
        return sorted(candidates, key=lambda pair: pair[1])

    self_ranked = _ranked(self_scores)
    cross_ranked = _ranked(cross_scores)

    pruned_self = []
    for idx, _score in self_ranked[:max(0, k_self)]:
        blocks[idx].layer[0] = SkipSelfAttention()
        pruned_self.append(idx)

    pruned_cross = []
    for idx, _score in cross_ranked[:max(0, k_cross)]:
        sublayers = blocks[idx].layer
        # Skip blocks that have no cross-attention slot to replace.
        if len(sublayers) <= 1 or sublayers[1] is None:
            continue
        sublayers[1] = SkipCrossAttention()
        pruned_cross.append(idx)

    if verbose:
        print(f"Pruned decoder SELF-attention at layers:  {pruned_self}")
        print(f"Pruned decoder CROSS-attention at layers: {pruned_cross}")

    return pruned_self, pruned_cross

# =========================
# 5) Train → collect J-dev → prune attention → FT
# =========================
def full_finetuning_collect_attn_jdev(train_loader, dev_loader, device, tokenizer,
                                      jvp_eps=1e-3, jvp_k=1, jvp_every=1, epochs=6):
    """
    Stage 1: fully fine-tune t5-base while periodically collecting
    finite-difference J-dev scores for decoder self-/cross-attention sublayers.

    Args:
        train_loader: iterable of batches with "input_ids", "attention_mask", "labels".
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the model and batch tensors are moved to.
        tokenizer: forwarded to ``evaluate_model``.
        jvp_eps: finite-difference step size.
        jvp_k: number of probes per sublayer per collection step.
        jvp_every: collect J-dev every this many steps (values < 1 act as 1).
        epochs: number of training epochs; also sizes the LR schedule.

    Returns:
        (model, final_self, final_cross): the trained model and, per decoder
        layer index, the mean of all finite J-dev scores collected.
    """
    print("=== Stage 1: Full FT & FD-JVP (Decoder Attention only) ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    model.config.use_cache = False  # keep cache off while we collect FD-JVP
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scaler = GradScaler()
    # The linear decay must span every epoch that is actually run; a schedule
    # sized for fewer epochs drives the LR to 0 and the rest train nothing.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)

    hooks, self_bufs, cross_bufs = register_decoder_attention_hooks(model)
    self_sum, self_cnt   = defaultdict(float), defaultdict(int)
    cross_sum, cross_cnt = defaultdict(float), defaultdict(int)

    def _running_mean(total, count):
        # Mean score per layer over all finite samples gathered so far.
        return {i: total[i] / count[i] for i in total if count[i] > 0}

    interval = max(1, jvp_every)
    step = 0
    try:
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                step += 1
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=batch["input_ids"].to(device),
                                attention_mask=batch["attention_mask"].to(device),
                                labels=batch["labels"].to(device))
                    loss = out.loss
                # Backward runs outside the autocast region, as PyTorch AMP recommends.
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                if step % interval == 0:
                    self_scores, cross_scores = compute_decoder_attention_jdev(
                        model, self_bufs, cross_bufs, eps=jvp_eps, k_probes=jvp_k
                    )
                    for i, v in self_scores.items():
                        if math.isfinite(v):
                            self_sum[i] += v
                            self_cnt[i] += 1
                    for i, v in cross_scores.items():
                        if math.isfinite(v):
                            cross_sum[i] += v
                            cross_cnt[i] += 1

            print(f"[Epoch {epoch+1}] Decoder SELF-attn J-dev:  {_running_mean(self_sum, self_cnt)}")
            print(f"[Epoch {epoch+1}] Decoder CROSS-attn J-dev: {_running_mean(cross_sum, cross_cnt)}")
            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] CQA Dev Acc: {acc:.4f}")
    finally:
        # Always detach the capture hooks, even if training is interrupted.
        remove_hooks(hooks)

    final_self  = _running_mean(self_sum, self_cnt)
    final_cross = _running_mean(cross_sum, cross_cnt)
    return model, final_self, final_cross

def prune_attention_and_finetune(model, train_loader, dev_loader, device,
                                 self_scores, cross_scores, tokenizer,
                                 k_self=2, k_cross=2, epochs=5):
    """
    Stage 2: prune the lowest-J-dev decoder attention sublayers, then fine-tune.

    Args:
        model: the Stage 1 model; it is modified in place.
        train_loader: iterable of batches with "input_ids", "attention_mask", "labels".
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the batch tensors are moved to.
        self_scores / cross_scores: per-layer J-dev scores used for ranking.
        tokenizer: forwarded to ``evaluate_model``.
        k_self / k_cross: number of self-/cross-attention sublayers to prune.
        epochs: number of fine-tuning epochs; also sizes the LR schedule.

    Returns:
        (model, pruned_self, pruned_cross)
    """
    print("=== Stage 2: Prune LOW-Jdev Attention (self/cross) & Fine-tune ===")
    pruned_self, pruned_cross = prune_decoder_attention(
        model, self_scores, cross_scores,
        k_self=k_self, k_cross=k_cross,
        protect_first=True, protect_last=False, verbose=True
    )

    # keep cache disabled after structural changes
    model.config.use_cache = False
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # Size the decay for all epochs that run; a shorter schedule reaches LR 0
    # early and the remaining epochs would not update the model.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            out = model(input_ids=batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        labels=batch["labels"].to(device))
            loss = out.loss
            loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] CQA Acc: {acc:.4f}")
    return model, pruned_self, pruned_cross

# =========================
# 6) Main
# =========================
def main():
    """
    Run the CQA pipeline on the module-level loaders and tokenizer:
    Stage 1 fine-tunes while collecting attention J-dev scores, Stage 2 prunes
    decoder self-attention (``k_self=num``, cross-attention kept) and fine-tunes.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Stage 1: train and gather attention-only J-dev statistics.
    stage1_options = dict(jvp_eps=1e-3, jvp_k=1, jvp_every=1, epochs=6)
    model, self_scores, cross_scores = full_finetuning_collect_attn_jdev(
        train_loader, dev_loader, device, tokenizer, **stage1_options
    )

    # Stage 2: prune attention sublayers only, then fine-tune.
    # k_cross=0 leaves every cross-attention sublayer in place.
    stage2_options = dict(k_self=num, k_cross=0, epochs=5)
    model, pruned_self, pruned_cross = prune_attention_and_finetune(
        model, train_loader, dev_loader, device,
        self_scores, cross_scores, tokenizer, **stage2_options
    )

# Entry point: run the CQA pipeline when this module is the main program.
if __name__ == "__main__":
    main()


In [None]:
# Autograd-JVP pruning — ONLY decoder attention layers (self & cross) on CQA
# NaN-hardened and mask/bias-safe

# ====== Colab Drive (no-op off Colab) ======
# Best-effort Google Drive mount: any failure (including the import failing
# outside Colab) is deliberately ignored so the rest of the cell still runs.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    pass

# ====== Imports ======
import inspect
import warnings
import math
from collections import defaultdict
from contextlib import contextmanager

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler

# Suppress all UserWarning and FutureWarning messages from here on.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# =========================
# 0) CQA data
# =========================
# JSON files for the CQA train and test splits (paths on the mounted Drive).
data_files = {
    "train": "/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json",
    "test":  "/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json"
}
# Load both splits; the result is indexed below by "train" and "test".
dataset = load_dataset("json", data_files=data_files)

def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):
    """Tokenize a batch of CQA examples into model inputs and loss-ready labels.

    Args:
        batch: Column-oriented mapping with ``question``, ``choices`` and
            ``answer`` lists; ``abstractive_explanation`` is read only when
            ``use_cot`` is true and the column is present.
        tokenizer: Callable tokenizer that exposes ``pad_token_id``.
        max_input_length: Fixed length the prompts are padded/truncated to.
        max_target_length: Fixed length the answers are padded/truncated to.
        use_cot: Append the rationale to the prompt when available.

    Returns:
        The tokenizer output for the prompts, with an added ``labels`` entry in
        which every padding position is -100.
    """
    if use_cot and 'abstractive_explanation' in batch:
        inputs = [
            f"question: {q} choices: {', '.join(choices)} rationale: {exp}"
            for q, choices, exp in zip(batch['question'], batch['choices'], batch['abstractive_explanation'])
        ]
    else:
        inputs = [
            f"question: {q} choices: {', '.join(choices)}"
            for q, choices in zip(batch['question'], batch['choices'])
        ]
    targets = [str(ans).strip() for ans in batch['answer']]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(targets, padding="max_length", truncation=True, max_length=max_target_length)

    # The targets are padded to a fixed length right here, so the padding has to
    # be masked here too: a collator's label_pad_token_id only applies to the
    # padding the collator adds itself. Without this the loss is also computed
    # on the pad positions.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in target["input_ids"]
    ]
    return model_inputs

# Tokenizer shared by preprocessing, the collator and evaluation.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
# Whether training prompts include the rationale column (see preprocess_cqa).
USE_COT = False

# Tokenize both splits in batches and drop the raw text columns. The dev split
# never includes the rationale.
train = dataset["train"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=USE_COT),
                             batched=True, remove_columns=dataset["train"].column_names)
dev   = dataset["test"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=False),
                             batched=True, remove_columns=dataset["test"].column_names)

# label_pad_token_id=-100 is the value the collator uses for label padding so
# that CE ignores it.
# NOTE(review): this presumably only covers padding the collator adds itself;
# pad ids already present in the stored labels must be masked upstream — confirm.
collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128, label_pad_token_id=-100)
train_loader = DataLoader(train, batch_size=32, shuffle=True,  collate_fn=collator)
dev_loader   = DataLoader(dev,   batch_size=32, shuffle=False, collate_fn=collator)

# =========================
# 1) Eval helpers
# =========================
def compute_accuracy(preds, refs):
    """Return the share of predictions equal to their paired reference.

    Pairs are formed with ``zip``; the denominator is ``len(preds)``. An empty
    prediction list yields 0.0.
    """
    total = len(preds)
    if total <= 0:
        return 0.0
    hits = sum(1 for pred, ref in zip(preds, refs) if pred == ref)
    return hits / total

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device):
    """Generate answers for every batch in ``dl`` and return exact-match accuracy.

    Predictions and references are decoded without special tokens, stripped and
    lower-cased before comparison. The KV cache is switched off both on the
    config and in ``generate``. The model is left in eval mode.
    """
    # disable cache to avoid KV/pos-bias drift after pruning
    model.config.use_cache = False
    model.eval()

    predictions, references = [], []
    for batch in dl:
        generated = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_new_tokens=4,
            use_cache=False,          # critical after pruning
        )
        decoded_preds = tokenizer.batch_decode(generated, skip_special_tokens=True)

        # Turn the ignore index back into pad ids so the labels can be decoded.
        gold_ids = batch["labels"].clone()
        gold_ids[gold_ids == -100] = tokenizer.pad_token_id
        decoded_refs = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)

        predictions += [text.strip().lower() for text in decoded_preds]
        references += [text.strip().lower() for text in decoded_refs]

    return compute_accuracy(predictions, references)

# =========================
# 2) Autograd-JVP helpers (robust) + mask/bias utilities
# =========================
def _filter_kwargs_for_module(module, kwargs):
    """Keep only the entries of ``kwargs`` named in ``module.forward``'s signature.

    A key called ``self`` is never kept.
    """
    accepted = inspect.signature(module.forward).parameters
    kept = {}
    for key, value in kwargs.items():
        if key != "self" and key in accepted:
            kept[key] = value
    return kept

def _map_args_from_hook(module, args, kwargs):
    """Normalize a forward-pre-hook's ``(args, kwargs)`` into one keyword dict.

    Positional arguments are bound to the parameter names of
    ``module.forward``, parameters that were not supplied fall back to their
    declared default, and explicit keyword arguments override both. Mixed calls
    such as ``layer(x, mask, use_cache=False)`` therefore keep ``mask``; taking
    only ``kwargs`` in that case would silently drop it.

    Parameters that were not supplied and have no default are omitted rather
    than stored as ``inspect.Parameter.empty``, and ``*args`` / ``**kwargs``
    catch-alls are never given an entry of their own.

    Args:
        module: Object whose ``forward`` signature provides the names.
        args: Positional arguments received by the hook.
        kwargs: Keyword arguments received by the hook (may be empty or None).

    Returns:
        dict of argument name -> value, without ``hidden_states``,
        ``position_bias`` and ``encoder_decoder_position_bias``.
    """
    sig = inspect.signature(module.forward)
    bindable = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    positional = [p for n, p in sig.parameters.items() if n != "self" and p.kind in bindable]
    d = {p.name: value for p, value in zip(positional, args)}

    for name, param in sig.parameters.items():
        if name == "self" or name in d or param.kind in variadic:
            continue
        if param.default is not inspect.Parameter.empty:
            d[name] = param.default

    d.update(kwargs or {})

    # we will set hidden_states; never reuse cached biases
    d.pop("hidden_states", None)
    d.pop("position_bias", None)
    d.pop("encoder_decoder_position_bias", None)
    return d

def _build_causal_mask(B, q_len, k_len, device, dtype):
    """Return an additive (B, 1, q_len, k_len) mask that blocks future keys.

    Entries where the key index exceeds the query index are -1e9, all others
    are 0. A large finite negative is used instead of -inf to avoid NaNs in
    softmax.
    """
    query_pos = torch.arange(q_len, device=device)[:, None]
    key_pos = torch.arange(k_len, device=device)[None, :]
    future = key_pos > query_pos  # strictly above the main diagonal

    mask = torch.zeros((B, 1, q_len, k_len), dtype=dtype, device=device)
    mask[:, :, future] = -1e9
    return mask

def _ensure_mask_shape(mask, q_len, k_len, device, dtype, is_self):
    """
    Return a (B,1,q_len,k_len) mask; if input mask is missing or wrong-sized,
    rebuild a safe default (causal for self, zeros for cross).

    Handling by input shape:
      * ``None`` or any rank other than 2/4: default mask with batch size 1.
      * 2-D ``(B, K)``: used as-is (values are not converted), cropped to
        ``k_len`` keys and repeated over the query axis.
      * 4-D ``(B, H, Q, K)`` that covers the request: cropped to
        ``(q_len, k_len)``; the head axis is left untouched.
      * 4-D with ``Q == 1`` for cross-attention: the single query row is
        broadcast over ``q_len``. Falling back to zeros here would throw away
        the key mask and let the layer attend to masked keys.
      * any other 4-D mask: default mask with the mask's batch size.
    """
    def _default(batch):
        # Built lazily so the helper is only needed when a fallback is.
        if is_self:
            return _build_causal_mask(batch, q_len, k_len, device, dtype)
        return torch.zeros((batch, 1, q_len, k_len), dtype=dtype, device=device)

    if mask is None:
        return _default(1)

    if mask.dim() == 2:  # (B, K)
        B = mask.size(0)
        return mask[:, None, None, :k_len].expand(B, 1, q_len, k_len).contiguous().to(device=device, dtype=dtype)

    if mask.dim() == 4:
        B, _, Q, K = mask.shape
        if K >= k_len:
            if Q >= q_len:
                return mask[:, :, :q_len, :k_len].to(device=device, dtype=dtype)
            if Q == 1 and not is_self:
                # Self-attention keeps the causal fallback: a single-row mask
                # carries no causal structure to broadcast.
                row = mask[:, :, :, :k_len].to(device=device, dtype=dtype)
                return row.expand(-1, -1, q_len, -1)
        return _default(B)

    return _default(1)

@contextmanager
def force_math_sdp():
    """
    Force 'math' scaled-dot-product attention kernels so higher-order grads exist.
    Falls back to a no-op context if backend is unavailable (e.g., CPU).

    Only the failure to *enter* the kernel-selection context is swallowed.
    Exceptions raised by the body of the ``with`` block propagate unchanged.
    (Catching them around the ``yield`` and yielding a second time makes
    ``contextlib`` replace every body error with
    ``RuntimeError("generator didn't stop after throw()")``.)
    """
    # Local import keeps the block self-contained; the file-level import only
    # brings in `contextmanager`.
    from contextlib import ExitStack

    with ExitStack() as stack:
        try:
            stack.enter_context(
                torch.backends.cuda.sdp_kernel(enable_flash=False,
                                               enable_mem_efficient=False,
                                               enable_math=True)
            )
        except Exception:
            # Best-effort: run the body with whatever kernels are the default.
            pass
        yield

def _make_attn_callable(module, argdict):
    """
    Build pure fn f(x)->hs for a single T5 sublayer (self/cross).
    Rebuild masks and never reuse captured position biases.

    Args:
        module: The attention sublayer to replay. It is treated as
            cross-attention when its ``forward`` has a ``key_value_states``
            parameter, otherwise as self-attention.
        argdict: Captured keyword arguments for that sublayer (may be None).

    Returns:
        A function of one tensor ``x`` that calls ``module`` in float32 with
        autocast disabled and returns the first element of its output. For a
        cross-attention module with no captured key/value states the function
        is the identity.
    """
    raw = dict(argdict) if argdict is not None else {}

    # Which optional arguments exist is decided from the forward signature.
    sig = inspect.signature(module.forward)
    params = set(sig.parameters.keys())
    is_cross = "key_value_states" in params

    # Arguments that stay the same on every call of the returned function.
    const_kwargs = {}

    # KV states for cross attention
    if is_cross:
        kv = raw.get("key_value_states", raw.get("encoder_hidden_states", None))
        const_kwargs["key_value_states"] = None if kv is None else kv.float()

    # deterministic flags
    if "use_cache" in params:
        const_kwargs["use_cache"] = False
    if "output_attentions" in params:
        const_kwargs["output_attentions"] = False
    if "return_dict" in params:
        const_kwargs["return_dict"] = False

    # mask source (we'll reshape at call time)
    mask_key = "attention_mask" if "attention_mask" in params else ("mask" if "mask" in params else None)
    captured_mask = raw.get(mask_key, raw.get("mask", None)) if mask_key is not None else None
    captured_kv = const_kwargs.get("key_value_states", None)

    def _f(x):
        x = x.float()
        q_len = x.size(1)
        # Keys come from the encoder states for cross-attention, otherwise
        # from x itself.
        k_len = captured_kv.size(1) if (is_cross and captured_kv is not None) else q_len
        B     = x.size(0)

        kwargs = {"hidden_states": x}
        if is_cross:
            if captured_kv is None:
                return x  # identity if no encoder states present
            kwargs["key_value_states"] = captured_kv

        if mask_key is not None:
            # Crop or rebuild the mask so it matches this call's lengths.
            kwargs[mask_key] = _ensure_mask_shape(
                captured_mask, q_len, k_len, x.device, x.dtype, is_self=(not is_cross)
            )

        # cache_position / query_length
        if "cache_position" in params:
            kwargs["cache_position"] = torch.arange(q_len, dtype=torch.long, device=x.device)
        if "query_length" in params and "cache_position" not in params:
            kwargs["query_length"] = q_len

        # do NOT pass any position_bias
        kwargs.update(const_kwargs)
        # Drop anything this module's forward does not accept.
        kwargs = _filter_kwargs_for_module(module, kwargs)

        with torch.cuda.amp.autocast(enabled=False):
            out = module(**kwargs)

        # Tuple/list outputs carry the hidden states first.
        hs = out[0] if isinstance(out, (tuple, list)) else out
        return hs

    return _f

def _probe_attn_jdev(mod, X, args, k_probes, rademacher):
    """Estimate mean ||(J - I)v||^2 for one attention sublayer; None if no probe succeeded."""
    from torch.autograd.functional import jvp as autograd_jvp

    was_training = mod.training
    mod.eval()
    try:
        fn = _make_attn_callable(mod, args)
        x0 = X.detach().requires_grad_(True).float()

        # Sanity forward: skip the layer entirely if it cannot be replayed.
        try:
            with force_math_sdp():
                fn(x0)
        except Exception:
            return None

        acc, used = 0.0, 0
        with force_math_sdp():
            for _ in range(max(1, k_probes)):
                if rademacher:
                    v = torch.empty_like(x0).bernoulli_(0.5).mul_(2).sub_(1)
                else:
                    v = torch.randn_like(x0)
                try:
                    _, Jv = autograd_jvp(fn, (x0,), (v,), create_graph=False, strict=True)
                except Exception:
                    continue  # best-effort: a failed probe is simply not counted
                if Jv is None or not torch.isfinite(Jv).all():
                    continue
                jd_vec = Jv - v
                if not torch.isfinite(jd_vec).all():
                    continue
                acc += float(jd_vec.pow(2).mean().item())
                used += 1
        return acc / used if used > 0 else None
    finally:
        # Restore train/eval mode no matter how the probe ended.
        mod.train(was_training)


def compute_decoder_attention_jdev_autograd(model, self_bufs, cross_bufs, k_probes=1, rademacher=True):
    """
    For each decoder attention sublayer ℓ, estimate:
      E_v ||(Jℓ - I)v||^2 using autograd JVP.
    Non-finite probes are skipped. Buffers are cleared.

    Args:
        model: Model exposing ``model.decoder.block[i].layer``; index 0 is used
            as the self-attention sublayer and index 1 as the cross-attention
            sublayer.
        self_bufs: ``{layer_idx: {"X": tensor | None, "args": dict | None}}``
            captured for self-attention. Each entry is reset to None.
        cross_bufs: Same structure for cross-attention.
        k_probes: Number of random probe vectors per layer (at least 1 is used).
        rademacher: Draw ±1 probes when true, standard-normal probes otherwise.

    Returns:
        ``(self_scores, cross_scores)``: dicts from layer index to the mean
        squared deviation, containing only layers with at least one usable probe.
    """
    self_scores, cross_scores = {}, {}

    # Self-attention J-dev
    for idx, buf in self_bufs.items():
        X, args = buf["X"], buf["args"]
        buf["X"], buf["args"] = None, None  # clear ASAP
        if X is None:
            continue
        mod = model.decoder.block[idx].layer[0]  # T5LayerSelfAttention
        score = _probe_attn_jdev(mod, X, args, k_probes, rademacher)
        if score is not None:
            self_scores[idx] = score

    # Cross-attention J-dev
    for idx, buf in cross_bufs.items():
        X, args = buf["X"], buf["args"]
        buf["X"], buf["args"] = None, None
        if X is None:
            continue
        layers = model.decoder.block[idx].layer
        if len(layers) <= 1 or layers[1] is None:
            continue
        mod = layers[1]  # T5LayerCrossAttention
        score = _probe_attn_jdev(mod, X, args, k_probes, rademacher)
        if score is not None:
            cross_scores[idx] = score

    return self_scores, cross_scores

# =========================
# 3) Hooks to capture ONLY decoder attention inputs
# =========================
def register_decoder_attention_hooks(model):
    """Attach forward-pre-hooks that record the inputs of decoder attention sublayers.

    For every decoder block, ``layer[0]`` (self-attention) gets a hook; so does
    ``layer[1]`` (cross-attention) when the block has one. Each hook stores the
    detached incoming hidden states under ``"X"`` and the normalized remaining
    arguments under ``"args"``.

    Returns:
        ``(hooks, self_bufs, cross_bufs)``: the hook handles and the two
        per-layer buffer dicts the hooks write into.
    """
    dec_blocks = model.decoder.block
    n_blocks = len(dec_blocks)
    self_bufs  = {i: {"X": None, "args": None} for i in range(n_blocks)}
    cross_bufs = {i: {"X": None, "args": None} for i in range(n_blocks)}

    def _capture_into(bufs, idx):
        # One closure per (buffer dict, layer index) pair.
        def _pre_hook(module, args, kwargs):
            if "hidden_states" in kwargs:
                X = kwargs["hidden_states"]
            else:
                X = args[0] if len(args) > 0 else None
            slot = bufs[idx]
            slot["X"] = None if X is None else X.detach()
            slot["args"] = _map_args_from_hook(module, args, kwargs)
        return _pre_hook

    # Self-attention hooks first, then cross-attention hooks.
    hooks = [
        block.layer[0].register_forward_pre_hook(_capture_into(self_bufs, i), with_kwargs=True)
        for i, block in enumerate(dec_blocks)
    ]
    hooks += [
        block.layer[1].register_forward_pre_hook(_capture_into(cross_bufs, i), with_kwargs=True)
        for i, block in enumerate(dec_blocks)
        if len(block.layer) > 1 and block.layer[1] is not None
    ]

    return hooks, self_bufs, cross_bufs

def remove_hooks(hooks):
    """Call ``remove()`` on every handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

# =========================
# 4) Attention-only pruning utilities
# =========================
class SkipSelfAttention(nn.Module):
    """
    Pass-through stand-in for T5LayerSelfAttention.

    The hidden states are returned untouched and the two remaining tuple slots
    are None, so no stale position_bias is handed on.
    """
    def forward(self, hidden_states, *args, **kwargs):
        # Every extra argument is accepted and ignored.
        passthrough = (hidden_states,)
        return passthrough + (None, None)

class SkipCrossAttention(nn.Module):
    """
    Pass-through stand-in for T5LayerCrossAttention.

    Returns the hidden states unchanged plus two None slots; the None in the
    position_bias slot forces recomputation downstream.
    """
    def forward(self, hidden_states, *args, **kwargs):
        # Every extra argument is accepted and ignored.
        passthrough = (hidden_states,)
        return passthrough + (None, None)

def prune_decoder_attention(model, self_scores, cross_scores,
                            k_self=2, k_cross=2,
                            protect_first=True, protect_last=False, verbose=True):
    """Swap the lowest-scoring decoder attention sublayers for identity modules.

    Args:
        model: Model exposing ``model.decoder.block[i].layer``.
        self_scores: ``{layer_idx: score}`` for self-attention sublayers.
        cross_scores: ``{layer_idx: score}`` for cross-attention sublayers.
        k_self: How many self-attention sublayers to replace (negative = 0).
        k_cross: How many cross-attention sublayers to replace (negative = 0).
        protect_first: Never replace anything in block 0.
        protect_last: Never replace anything in the final block.
        verbose: Print the replaced layer indices.

    Returns:
        ``(pruned_self, pruned_cross)``: lists of replaced layer indices,
        ordered from lowest score upwards.
    """
    blocks = model.decoder.block
    protected = set()
    if protect_first:
        protected.add(0)
    if protect_last:
        protected.add(len(blocks) - 1)

    def _ranked(scores):
        # Unprotected layers, lowest J-dev first (stable for ties).
        candidates = [(i, s) for i, s in scores.items() if i not in protected]
        return sorted(candidates, key=lambda item: item[1])

    self_ranked = _ranked(self_scores)
    cross_ranked = _ranked(cross_scores)

    pruned_self = []
    for i, _ in self_ranked[:max(0, k_self)]:
        blocks[i].layer[0] = SkipSelfAttention()
        pruned_self.append(i)

    pruned_cross = []
    for i, _ in cross_ranked[:max(0, k_cross)]:
        layers = blocks[i].layer
        if len(layers) > 1 and layers[1] is not None:
            layers[1] = SkipCrossAttention()
            pruned_cross.append(i)

    if verbose:
        print(f"Pruned decoder SELF-attention at layers:  {pruned_self}")
        print(f"Pruned decoder CROSS-attention at layers: {pruned_cross}")

    return pruned_self, pruned_cross

# =========================
# 5) Train → collect Autograd-JVP → prune attention → FT
# =========================
def full_finetuning_collect_attn_jdev(train_loader, dev_loader, device, tokenizer,
                                      jvp_k=1, jvp_every=1, epochs=6):
    """Stage 1: fine-tune t5-base and accumulate attention J-dev scores per decoder layer.

    Args:
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the per-epoch accuracy printout.
        device: Device the model and batches are moved to.
        tokenizer: Tokenizer passed on to ``evaluate_model``.
        jvp_k: Number of JVP probes per layer per scored step.
        jvp_every: Score every this-many optimizer steps (values < 1 act as 1).
        epochs: Number of training epochs.

    Returns:
        ``(model, final_self, final_cross)``: the fine-tuned model and the
        running-mean J-dev per layer index for self- and cross-attention.
    """
    print("=== Stage 1: Full FT & Autograd-JVP (Decoder Attention only) ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    model.config.use_cache = False  # keep cache off while we collect JVP
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scaler = GradScaler()
    # Spread the linear decay over every epoch that actually runs. A step count
    # tied to a fixed number of epochs drives the learning rate to zero early
    # and turns the remaining epochs into no-ops.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)

    hooks, self_bufs, cross_bufs = register_decoder_attention_hooks(model)
    self_sum, self_cnt   = defaultdict(float), defaultdict(int)
    cross_sum, cross_cnt = defaultdict(float), defaultdict(int)

    def _running_mean(sums, counts):
        return {i: sums[i] / counts[i] for i in sums if counts[i] > 0}

    step = 0
    try:
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                step += 1
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=batch["input_ids"].to(device),
                                attention_mask=batch["attention_mask"].to(device),
                                labels=batch["labels"].to(device))
                    loss = out.loss
                # Only the forward pass belongs under autocast.
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                if step % max(1, jvp_every) == 0:
                    # Uses (and clears) the inputs the hooks captured during
                    # the forward pass above.
                    self_scores, cross_scores = compute_decoder_attention_jdev_autograd(
                        model, self_bufs, cross_bufs, k_probes=jvp_k
                    )
                    for i, v in self_scores.items():
                        if math.isfinite(v):
                            self_sum[i] += v
                            self_cnt[i] += 1
                    for i, v in cross_scores.items():
                        if math.isfinite(v):
                            cross_sum[i] += v
                            cross_cnt[i] += 1

            epoch_self = _running_mean(self_sum, self_cnt)
            epoch_cross = _running_mean(cross_sum, cross_cnt)
            print(f"[Epoch {epoch+1}] Decoder SELF-attn J-dev:  {epoch_self}")
            print(f"[Epoch {epoch+1}] Decoder CROSS-attn J-dev: {epoch_cross}")
            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] CQA Dev Acc: {acc:.4f}")
    finally:
        # Never leave capture hooks on the model, even if training fails.
        remove_hooks(hooks)

    final_self = _running_mean(self_sum, self_cnt)
    final_cross = _running_mean(cross_sum, cross_cnt)
    return model, final_self, final_cross

def prune_attention_and_finetune(model, train_loader, dev_loader, device,
                                 self_scores, cross_scores, tokenizer,
                                 k_self=2, k_cross=2, epochs=5):
    """Stage 2: replace low-J-dev attention sublayers with identities, then fine-tune.

    Block 0 is always protected from pruning; the last block is not.

    Args:
        model: Model from stage 1; modified in place.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the per-epoch accuracy printout.
        device: Device the batches are moved to.
        self_scores: ``{layer_idx: score}`` for self-attention sublayers.
        cross_scores: ``{layer_idx: score}`` for cross-attention sublayers.
        tokenizer: Tokenizer passed on to ``evaluate_model``.
        k_self: Number of self-attention sublayers to prune.
        k_cross: Number of cross-attention sublayers to prune.
        epochs: Number of fine-tuning epochs.

    Returns:
        ``(model, pruned_self, pruned_cross)``.
    """
    print("=== Stage 2: Prune LOW-Jdev Attention (self/cross) & Fine-tune ===")
    pruned_self, pruned_cross = prune_decoder_attention(
        model, self_scores, cross_scores,
        k_self=k_self, k_cross=k_cross,
        protect_first=True, protect_last=False, verbose=True
    )

    # keep cache disabled after structural changes
    model.config.use_cache = False
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # Decay over all `epochs`. A fixed step budget shorter than the run leaves
    # the later epochs training with a learning rate of zero.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            out = model(input_ids=batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        labels=batch["labels"].to(device))
            loss = out.loss
            loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] CQA Acc: {acc:.4f}")
    return model, pruned_self, pruned_cross

# =========================
# 6) Main
# =========================
def main():
    """Run the Autograd-JVP experiment: fine-tune with scoring, then prune and fine-tune.

    Relies on the notebook-level ``train_loader``, ``dev_loader``, ``tokenizer``
    and ``num`` (number of self-attention sublayers to prune).
    """
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")

    # Stage 1: train and accumulate the attention-only J-dev scores (Autograd JVP).
    stage1_options = dict(jvp_k=1, jvp_every=1, epochs=6)
    model, self_scores, cross_scores = full_finetuning_collect_attn_jdev(
        train_loader, dev_loader, device, tokenizer, **stage1_options
    )

    # Stage 2: prune attention sublayers only, then fine-tune.
    # k_cross=0 keeps every cross-attention sublayer.
    stage2_options = dict(k_self=num, k_cross=0, epochs=5)
    model, pruned_self, pruned_cross = prune_attention_and_finetune(
        model, train_loader, dev_loader, device,
        self_scores, cross_scores, tokenizer,
        **stage2_options
    )


if __name__ == "__main__":
    main()


In [None]:
# Only prune decoder — FD-JVP (NaN-safe, bias-safe)

# --- Mount Google Drive if using Colab ---
try:
    # Mount Google Drive at /content/drive; the dataset paths below point there.
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    # Deliberately best-effort: any failure (import or mount) is ignored so the
    # cell can still run outside Colab.
    pass

# --- Standard Imports ---
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, Adafactor
)
from torch.cuda.amp import autocast
from collections import defaultdict
import warnings
import math
import random
import numpy as np
import inspect

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ------------- Repro -------------
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (plus all CUDA devices when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(1234)

# --- 1. Load ANLI1 Dataset ---
# JSON files for the ANLI round-1 splits on the mounted Drive. Only "train" and
# "validation" are mapped into loaders further down in this cell.
data_files = {
    "train":      "/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json",
    "validation": "/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json",
    "test":       "/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json"
}
# Loaded with the "json" builder; indexed by split name below.
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Function ---
def make_t5_nli_prompt(premise, hypothesis):
    """Format one premise/hypothesis pair as the text prompt fed to T5."""
    template = "nli premise: {} hypothesis: {}"
    return template.format(premise, hypothesis)

def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batch of ANLI examples into model inputs and loss-ready labels.

    Labels that are digit strings below 3 are mapped to
    ``entailment`` / ``neutral`` / ``contradiction`` (indices 0/1/2); anything
    else is used as the stripped, lower-cased string it already is.

    Args:
        batch: Column-oriented mapping with ``premise``, ``hypothesis`` and
            ``label`` lists.
        tokenizer: Callable tokenizer that accepts ``text_target=`` and exposes
            ``pad_token_id``.
        max_input_length: Fixed length the prompts are padded/truncated to.
        max_target_length: Fixed length the label texts are padded/truncated to.

    Returns:
        The tokenizer output for the prompts, with an added ``labels`` entry in
        which every padding position is -100.
    """
    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    label_list = ["entailment", "neutral", "contradiction"]

    labels_str = []
    for x in batch['label']:
        sx = str(x).strip().lower()
        if sx.isdigit() and int(sx) < 3:
            labels_str.append(label_list[int(sx)])
        else:
            labels_str.append(sx)

    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(text_target=labels_str, padding="max_length", truncation=True, max_length=max_target_length)

    # The targets are padded to a fixed length right here, so the padding has to
    # be masked here too: a collator's label_pad_token_id only applies to the
    # padding the collator adds itself. Without this the loss is also computed
    # on the pad positions.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in target["input_ids"]
    ]
    return model_inputs

# Tokenizer
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# Map datasets: tokenize in batches and drop the raw text columns.
train = dataset["train"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True, remove_columns=dataset["train"].column_names
)
dev = dataset["validation"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True, remove_columns=dataset["validation"].column_names
)

# --- Load model before creating the collator (the collator is given the model) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
# Avoid kv-cache so biases are always recomputed (also matches SkipBlock behavior)
model.config.use_cache = False

# Collator that uses -100 as its label padding value.
# NOTE(review): this presumably only covers padding the collator adds itself;
# pad ids already present in the stored labels must be masked upstream — confirm.
collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)
train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)

# =========================
# 3) FD-JVP helpers (NaN-safe) — capture + deterministic replay for T5Block
# =========================

def _filter_kwargs_for_module(module, kwargs):
    """Keep only the entries of ``kwargs`` named in ``module.forward``'s signature.

    A key called ``self`` is never kept.
    """
    accepted = inspect.signature(module.forward).parameters
    kept = {}
    for key, value in kwargs.items():
        if key != "self" and key in accepted:
            kept[key] = value
    return kept

def _map_args_from_hook(module, args, kwargs):
    """Normalize to kwargs for re-calling module.forward.

    Positional arguments are bound to the parameter names of
    ``module.forward``, parameters that were not supplied fall back to their
    declared default, and explicit keyword arguments override both. Mixed calls
    such as ``block(x, mask, bias, use_cache=False)`` therefore keep ``mask``
    and ``bias``; taking only ``kwargs`` in that case would silently drop them.

    Parameters that were not supplied and have no default are omitted rather
    than stored as ``inspect.Parameter.empty``, and ``*args`` / ``**kwargs``
    catch-alls are never given an entry of their own.

    Args:
        module: Object whose ``forward`` signature provides the names.
        args: Positional arguments received by the hook.
        kwargs: Keyword arguments received by the hook (may be empty or None).

    Returns:
        dict of argument name -> value, without ``hidden_states``.
    """
    sig = inspect.signature(module.forward)
    bindable = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    positional = [p for n, p in sig.parameters.items() if n != "self" and p.kind in bindable]
    d = {p.name: value for p, value in zip(positional, args)}

    for name, param in sig.parameters.items():
        if name == "self" or name in d or param.kind in variadic:
            continue
        if param.default is not inspect.Parameter.empty:
            d[name] = param.default

    d.update(kwargs or {})
    d.pop("hidden_states", None)  # we will set it explicitly
    return d

def _build_causal_mask(B, q_len, k_len, device, dtype):
    """Return an additive (B, 1, q_len, k_len) mask that blocks future keys.

    Entries where the key index exceeds the query index are -1e9, all others
    are 0. A large finite negative is used instead of -inf to avoid NaNs in
    softmax.
    """
    query_pos = torch.arange(q_len, device=device)[:, None]
    key_pos = torch.arange(k_len, device=device)[None, :]
    future = key_pos > query_pos  # strictly above the main diagonal

    mask = torch.zeros((B, 1, q_len, k_len), dtype=dtype, device=device)
    mask[:, :, future] = -1e9
    return mask

def _build_zero_mask(B, q_len, k_len, device, dtype):
    """Return an all-zero additive mask of shape (B, 1, q_len, k_len)."""
    shape = (B, 1, q_len, k_len)
    return torch.zeros(shape, device=device, dtype=dtype)

def _ensure_mask_shape(mask, q_len, k_len, device, dtype, is_self, B_fallback=1):
    """
    Produce a valid (B,1,q_len,k_len) mask; rebuild default if missing/invalid.
    Self-attn -> causal mask; Cross-attn -> zero mask.

    Handling by input shape:
      * ``None`` or any rank other than 2/4: default mask with batch size
        ``B_fallback``.
      * 2-D ``(B, K)``: used as-is (values are not converted), cropped to
        ``k_len`` keys and repeated over the query axis.
      * 4-D ``(B, H, Q, K)`` that covers the request: cropped to
        ``(q_len, k_len)``; the head axis is left untouched.
      * 4-D with ``Q == 1`` for cross-attention: the single query row is
        broadcast over ``q_len``. Falling back to zeros here would throw away
        the key mask and let the block attend to masked keys.
      * any other 4-D mask: default mask with the mask's batch size.
    """
    def _default(batch):
        # Resolved lazily so a helper is only needed when a fallback is.
        builder = _build_causal_mask if is_self else _build_zero_mask
        return builder(batch, q_len, k_len, device, dtype)

    if mask is None:
        return _default(B_fallback)

    if mask.dim() == 2:  # (B, K)
        B = mask.size(0)
        return mask[:, None, None, :k_len].expand(B, 1, q_len, k_len).contiguous().to(device=device, dtype=dtype)

    if mask.dim() == 4:
        B, _, Q, K = mask.shape
        if K >= k_len:
            if Q >= q_len:
                return mask[:, :, :q_len, :k_len].to(device=device, dtype=dtype)
            if Q == 1 and not is_self:
                # Self-attention keeps the causal fallback: a single-row mask
                # carries no causal structure to broadcast.
                row = mask[:, :, :, :k_len].to(device=device, dtype=dtype)
                return row.expand(-1, -1, q_len, -1)
        return _default(B)

    return _default(B_fallback)

@torch.no_grad()
def _t5block_forward_only(block, X, argdict):
    """
    Deterministic single T5Block forward in fp32 (autocast disabled).
    Provides safe masks & cache_position/query_length when missing.
    Returns only hidden_states tensor.

    Args:
        block: The decoder block to run. It is switched to eval mode for the
            call and restored to its previous mode afterwards.
        X: Incoming hidden states; ``None`` short-circuits to ``None``.
        argdict: Captured keyword arguments for the block (may be None). Any
            captured position biases are ignored.

    Returns:
        The block's first output, or ``None`` when ``X`` is None or the output
        contains non-finite values.
    """
    if X is None:
        return None
    was_training = block.training
    block.eval()
    Xf = X.float()
    try:
        # Work on a copy so the caller's captured dict is not modified.
        raw = dict(argdict) if argdict is not None else {}

        # Never reuse captured biases (length-mismatch risk)
        raw.pop("position_bias", None)
        raw.pop("encoder_decoder_position_bias", None)

        kwargs = {"hidden_states": Xf}

        q_len = Xf.size(1)
        B = Xf.size(0)

        # Self-attention mask (called "attention_mask" at block level)
        attn_mask = raw.get("attention_mask", raw.get("mask", None))
        kwargs["attention_mask"] = _ensure_mask_shape(
            attn_mask, q_len, q_len, Xf.device, Xf.dtype, is_self=True, B_fallback=B
        )

        # Cross-attention (if encoder states available)
        enc_states = raw.get("encoder_hidden_states", None)
        if enc_states is not None:
            enc_states = enc_states.float()
            kwargs["encoder_hidden_states"] = enc_states
            k_len = enc_states.size(1)
            enc_mask = raw.get("encoder_attention_mask", None)
            kwargs["encoder_attention_mask"] = _ensure_mask_shape(
                enc_mask, q_len, k_len, Xf.device, Xf.dtype, is_self=False, B_fallback=B
            )

        # cache_position / query_length if present in this HF version
        params = set(inspect.signature(block.forward).parameters.keys())
        if "cache_position" in params:
            cp = raw.get("cache_position", None)
            if cp is None:
                cp = torch.arange(q_len, dtype=torch.long, device=Xf.device)
            kwargs["cache_position"] = cp
        if "query_length" in params:
            kwargs["query_length"] = q_len

        # Deterministic flags
        if "use_cache" in params:
            kwargs["use_cache"] = False
        if "output_attentions" in params:
            kwargs["output_attentions"] = False
        if "return_dict" in params:
            kwargs["return_dict"] = False

        # Drop anything this block's forward does not accept.
        kwargs = _filter_kwargs_for_module(block, kwargs)

        with torch.cuda.amp.autocast(enabled=False):
            out = block(**kwargs)

        # Tuple/list outputs carry the hidden states first.
        hs = out[0] if isinstance(out, (tuple, list)) else out
        # Non-finite results are reported as None so callers can skip them.
        return hs if hs is None or torch.isfinite(hs).all() else None
    finally:
        block.train(was_training)

def register_decoder_block_jvp_hooks(model):
    """
    Attach a forward_pre_hook to every decoder block that records the block's
    incoming hidden_states (detached) and its remaining arguments.

    Returns:
        ``(dec_hooks, dec_bufs)``: the hook handles and the per-block buffer
        dict ``{idx: {"X": ..., "args": ...}}`` the hooks write into.
    """
    dec_blocks = model.decoder.block
    dec_bufs = {i: {"X": None, "args": None} for i in range(len(dec_blocks))}

    def _capture_for(idx):
        # One closure per block index.
        def _pre_hook(module, args, kwargs):
            if "hidden_states" in kwargs:
                X = kwargs["hidden_states"]
            else:
                X = args[0] if len(args) > 0 else None
            slot = dec_bufs[idx]
            slot["X"] = None if X is None else X.detach()
            slot["args"] = _map_args_from_hook(module, args, kwargs)
        return _pre_hook

    dec_hooks = [
        block.register_forward_pre_hook(_capture_for(i), with_kwargs=True)
        for i, block in enumerate(dec_blocks)
    ]
    return dec_hooks, dec_bufs

def remove_hooks(hooks):
    """Call ``remove()`` on every handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

@torch.no_grad()
def compute_decoder_block_jdev(model, dec_bufs, eps=1e-3, k_probes=1):
    """
    Finite-difference estimate of E_v ||(Jℓ - I)v||^2 for each decoder block ℓ,
    using (J - I)v ≈ (f(x + eps v) - f(x))/eps - v with Gaussian probes v.

    Blocks whose baseline forward is unusable get no score; probes with an
    unusable forward or a non-finite deviation are not counted. Every buffer
    that held a captured input is reset to None.

    Returns:
        ``{block_idx: mean squared deviation}`` for blocks with at least one
        usable probe.
    """
    n_probes = max(1, k_probes)
    scores = {}
    for idx, buf in dec_bufs.items():
        captured, args = buf["X"], buf["args"]
        if captured is None:
            continue
        block = model.decoder.block[idx]
        x = captured.float()

        baseline = _t5block_forward_only(block, x, args)
        if baseline is not None:
            sq_sum, n_ok = 0.0, 0
            for _ in range(n_probes):
                direction = torch.randn_like(x)
                perturbed = _t5block_forward_only(block, x + eps * direction, args)
                if perturbed is None:
                    continue
                deviation = (perturbed - baseline) / eps - direction
                if torch.isfinite(deviation).all():
                    sq_sum += float(deviation.pow(2).mean().item())
                    n_ok += 1
            if n_ok > 0:
                scores[idx] = sq_sum / n_ok

        # Release the captured tensors whether or not a score was produced.
        buf["X"], buf["args"] = None, None
    return scores

# =========================
# 4) Pruning Utilities (decoder blocks)
# =========================

class SkipBlock(nn.Module):
    """
    Pass-through stand-in for a T5 decoder block.

    The hidden states are returned untouched. Every other slot of the output
    tuple is None, in particular both bias slots, so later layers recompute
    them with the correct (Q,K) sizes.

    Output layout:
    (hidden_states, present_key_value, self_attn_weights, cross_attn_weights,
     position_bias, encoder_decoder_position_bias)
    """
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=False,
        cache_position=None,
        **kwargs,
    ):
        # All arguments besides hidden_states are accepted and ignored.
        empty_slots = (None,) * 5
        return (hidden_states,) + empty_slots

def prune_decoder_blocks_by_jdev(blocks, jdev_scores, num_prune=4, protect_first=True, protect_last=False, verbose=True):
    """
    Replace the decoder blocks with the LOWEST J-dev (closest to identity) by
    SkipBlock instances, in place.

    Args:
        blocks: Indexable, assignable container of decoder blocks.
        jdev_scores: ``{block_idx: score}``; empty/None skips pruning.
        num_prune: Number of blocks to replace (negative = 0).
        protect_first: Never replace block 0.
        protect_last: Never replace the final block.
        verbose: Print what was (or was not) pruned.

    Returns:
        List of replaced block indices, lowest score first.
    """
    if not jdev_scores:
        if verbose:
            print("No J-dev scores available; skipping pruning.")
        return []

    last_idx = len(blocks) - 1

    def _eligible(i):
        if protect_first and i == 0:
            return False
        if protect_last and i == last_idx:
            return False
        return True

    ranked = sorted(
        (item for item in jdev_scores.items() if _eligible(item[0])),
        key=lambda item: item[1],  # lowest first
    )
    prune_idxs = [i for i, _ in ranked[:max(0, num_prune)]]
    for i in prune_idxs:
        blocks[i] = SkipBlock()

    if verbose:
        print(f"Pruned decoder blocks (lowest J-dev): {prune_idxs}")
    return prune_idxs

# =========================
# 5) Training / Eval
# =========================

# Surface forms accepted for each NLI class, mapped to the canonical class name.
CANON = {
    alias: canonical
    for canonical, aliases in (
        ("entailment", ("entailment", "entailed")),
        ("neutral", ("neutral",)),
        ("contradiction", ("contradiction", "contradict", "contradictory", "contradicted")),
    )
    for alias in aliases
}

def canonicalize_label(s: str):
    """Map free-form label text to a canonical NLI class name.

    Only the first whitespace-separated word (lower-cased) is considered. Words
    without a ``CANON`` entry are returned as they are; ``None`` or blank input
    gives ``""``.
    """
    text = (s or "").strip().lower()
    words = text.split()
    head = words[0] if words else text
    return CANON.get(head, head)

def compute_accuracy(preds, refs):
    """
    Fraction of predictions whose canonical label equals the reference's.

    Pairs are formed with ``zip``; the denominator is ``len(preds)``.
    Returns 0.0 when ``preds`` is empty.
    """
    hits = sum(
        1
        for pred, ref in zip(preds, refs)
        if canonicalize_label(pred) == canonicalize_label(ref)
    )
    total = len(preds)
    if total > 0:
        return hits / total
    return 0.0

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device):
    """
    Generate an answer for every batch of ``dl`` and score it against the labels.

    The model is switched to eval mode and left there. Predictions and
    references are decoded to text, stripped and lower-cased, then passed to
    ``compute_accuracy``.
    """
    model.eval()
    preds, refs = [], []

    def _clean(texts):
        return [text.strip().lower() for text in texts]

    for batch in dl:
        # Gold labels: put the pad id back where the loss-ignore value (-100) sits
        # so the ids can be decoded.
        gold_ids = batch["labels"].clone()
        gold_ids[gold_ids == -100] = tokenizer.pad_token_id
        gold_texts = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)

        # generate predictions; keep cache off for consistency (biases recompute each step)
        generated = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_new_tokens=4,
            use_cache=False,
        )
        generated_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)

        preds += _clean(generated_texts)
        refs += _clean(gold_texts)

    return compute_accuracy(preds, refs)

def build_optimizer(model):
    """Create an Adafactor optimizer over all of ``model``'s parameters."""
    # Stable default for T5
    adafactor_options = dict(relative_step=True, scale_parameter=True, warmup_init=True)
    return Adafactor(model.parameters(), **adafactor_options)

def full_finetuning_with_jdev(model, train_loader, dev_loader, device, tokenizer,
                              jvp_eps=1e-3, jvp_k=1, jvp_every=1, epochs=6):
    """
    Stage 1: Full FT while periodically collecting FD-JVP J-dev for decoder blocks.

    Args:
        model: model to train (updated in place).
        train_loader / dev_loader: batches with "input_ids", "attention_mask", "labels".
        device: device the batch tensors are moved to.
        tokenizer: passed through to ``evaluate_model``.
        jvp_eps: finite-difference step handed to ``compute_decoder_block_jdev``.
        jvp_k: number of probes per measurement.
        jvp_every: measure J-dev every this many steps (values < 1 mean every step).
        epochs: number of passes over ``train_loader``.

    Returns:
        ``(model, scores)`` where ``scores`` maps block index -> mean of the
        finite J-dev measurements accumulated over the whole run.
    """
    print("=== Stage 1: Full FT + FD-JVP (Decoder Blocks) ===")
    opt = build_optimizer(model)
    dec_hooks, dec_bufs = register_decoder_block_jvp_hooks(model)

    dec_sum, dec_cnt = defaultdict(float), defaultdict(int)
    probe_period = max(1, jvp_every)
    step = 0

    try:
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                step += 1
                opt.zero_grad()

                # Disable AMP for numerical stability
                with autocast(enabled=False):
                    out = model(
                        input_ids=batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        labels=batch["labels"].to(device),
                    )
                    loss = out.loss

                if not torch.isfinite(loss):
                    print("Loss is NaN/Inf — skipping batch.")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()

                # Collect FD-JVP every N steps
                if step % probe_period == 0:
                    dec_scores = compute_decoder_block_jdev(model, dec_bufs, eps=jvp_eps, k_probes=jvp_k)
                    for i, v in dec_scores.items():
                        if math.isfinite(v):
                            dec_sum[i] += v
                            dec_cnt[i] += 1

            epoch_dec = {i: dec_sum[i] / dec_cnt[i] for i in dec_sum if dec_cnt[i] > 0}
            print(f"[Epoch {epoch+1}] Decoder Block J-dev (mean): {epoch_dec}")
            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
    finally:
        # Detach the capture hooks even when training is interrupted or fails;
        # otherwise they stay on the (shared, global) model and a re-run of the
        # cell would stack a second set on top of them.
        remove_hooks(dec_hooks)

    final_dec = {i: dec_sum[i] / dec_cnt[i] for i in dec_sum if dec_cnt[i] > 0}
    return model, final_dec

def prune_and_finetune(model, train_loader, dev_loader, device, tokenizer,
                       dec_jdev_scores, num_prune=num, epochs=5):
    """
    Stage 2: replace the lowest-J-dev decoder blocks, then fine-tune the rest.

    Index 0 is always protected from pruning; the last block is not.
    Dev accuracy is printed after every epoch. Returns the trained model.
    """
    print("=== Stage 2: Prune LOW-Jdev Decoder Blocks & Fine-tuning ===")
    prune_decoder_blocks_by_jdev(
        model.decoder.block,
        dec_jdev_scores,
        num_prune=num_prune,
        protect_first=True,
        protect_last=False,
        verbose=True,
    )

    # Keep cache off post-pruning to always recompute biases freshly
    model.config.use_cache = False

    optimizer = build_optimizer(model)
    batch_keys = ("input_ids", "attention_mask", "labels")

    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs = {key: batch[key].to(device) for key in batch_keys}
            with autocast(enabled=False):
                loss = model(**inputs).loss

            if torch.isfinite(loss):
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                # Nothing is back-propagated for a non-finite loss.
                print("Loss is NaN/Inf — skipping batch.")

        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] ANLI1 Acc: {acc:.4f}")
    return model

# --- 6. Entrypoint ---
def main():
    """Run stage 1 (fine-tune + J-dev collection) and then stage 2 (prune + fine-tune)."""
    global model  # reuse the earlier-loaded model instance
    shared = (train_loader, dev_loader, device, tokenizer)

    stage1 = full_finetuning_with_jdev(
        model, *shared,
        jvp_eps=1e-3, jvp_k=1, jvp_every=1, epochs=6
    )
    model, dec_jdev_scores = stage1

    model = prune_and_finetune(
        model, *shared, dec_jdev_scores,
        num_prune=num, epochs=5
    )


if __name__ == "__main__":
    main()


In [None]:
# Only prune decoder — Autograd JVP (NaN-safe, bias-safe)
# ATTENTION-ONLY pruning (self & cross) for T5 on ANLI1

# --- Mount Google Drive if using Colab ---
try:
    # The import itself fails outside Colab, which is handled below.
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    # Best effort: any failure (import or mount) is ignored and the cell continues
    # without a mounted drive.
    pass

# --- Standard Imports ---
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, Adafactor
)
from collections import defaultdict
import warnings, math, random, numpy as np, inspect
from contextlib import contextmanager

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ------------- Repro -------------
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and torch (plus all CUDA devices when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(1234)

# --- 1. Load ANLI1 Dataset ---
# Split name -> JSON file on the mounted Google Drive.
data_files = {
    "train":      "/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json",
    "validation": "/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json",
    "test":       "/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json"
}
# Loaded with the generic "json" builder; the result is indexed by the split names above.
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Function ---
def make_t5_nli_prompt(premise, hypothesis):
    """Format one premise/hypothesis pair as the text prompt fed to the model."""
    return f"nli premise: {premise} hypothesis: {hypothesis}"


def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):
    """
    Tokenize a batch of ANLI examples for seq2seq training.

    Args:
        batch: dict of columns with keys 'premise', 'hypothesis' and 'label'.
        tokenizer: callable tokenizer exposing ``pad_token_id``.
        max_input_length: fixed length the prompts are padded/truncated to.
        max_target_length: fixed length the label texts are padded/truncated to.

    Returns:
        The tokenizer output for the prompts with an extra "labels" entry.
        Padding positions inside "labels" are set to -100 so the loss ignores
        them; real tokens are left untouched.

    Labels that are the digits 0/1/2 are mapped to
    "entailment"/"neutral"/"contradiction"; anything else is used as its
    stripped, lower-cased text.
    """
    label_names = ("entailment", "neutral", "contradiction")

    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]

    labels_str = []
    for raw_label in batch['label']:
        text = str(raw_label).strip().lower()
        if text.isdigit() and int(text) < len(label_names):
            labels_str.append(label_names[int(text)])
        else:
            labels_str.append(text)

    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(text_target=labels_str, padding="max_length", truncation=True, max_length=max_target_length)

    # The targets were padded to a fixed length above, so they contain pad ids.
    # Replace those with -100 (the value the loss skips); otherwise every
    # padding position is trained on as if it were a real target token.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in target["input_ids"]
    ]
    return model_inputs

# Tokenizer
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# Tokenize both splits in batches; the raw text columns are dropped so only the
# tokenizer outputs and "labels" remain.
train = dataset["train"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True, remove_columns=dataset["train"].column_names
)
dev = dataset["validation"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True, remove_columns=dataset["validation"].column_names
)

# --- Load the model before creating the collator, because the collator receives it ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
# Avoid kv-cache so biases are always recomputed
model.config.use_cache = False

# label_pad_token_id=-100 is the fill value the collator uses when it pads labels itself.
# NOTE(review): it is not documented to rewrite pad ids that are already stored in the
# tokenized labels — confirm against the DataCollatorForSeq2Seq documentation.
collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=16, shuffle=True,  collate_fn=collator)
dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)

# =========================
# 3) Autograd-JVP helpers (NaN-safe, bias-safe) — attention sublayers only
# =========================

def _filter_kwargs_for_module(module, kwargs):
    """Return the subset of ``kwargs`` whose keys are parameter names of ``module.forward``."""
    accepted = frozenset(
        name for name in inspect.signature(module.forward).parameters if name != "self"
    )
    filtered = {}
    for key, value in kwargs.items():
        if key in accepted:
            filtered[key] = value
    return filtered

def _map_args_from_hook(module, args, kwargs):
    """
    Turn the ``(args, kwargs)`` seen by a forward pre-hook into one name -> value dict.

    * Mixed call (positional + keyword): positionals are bound to the leading
      parameter names of ``module.forward`` and the keywords are laid on top.
      Previously the positionals were dropped whenever any keyword was present.
    * Purely positional call: every parameter name is filled, using the
      declared default for parameters that were not supplied.

    ``hidden_states``, ``position_bias`` and ``encoder_decoder_position_bias``
    are always removed: the caller supplies its own hidden states and cached
    biases must never be reused.
    """
    sig = inspect.signature(module.forward)
    if kwargs:
        # Only parameters that can actually receive a positional value.
        bindable = [
            name for name, p in sig.parameters.items()
            if name != "self" and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        d = dict(zip(bindable, args))
        d.update(kwargs)
    else:
        names = [n for n in sig.parameters.keys() if n != "self"]
        d = {name: (args[i] if i < len(args) else sig.parameters[name].default) for i, name in enumerate(names)}
    # We will set hidden_states; never reuse cached biases
    d.pop("hidden_states", None)
    d.pop("position_bias", None)
    d.pop("encoder_decoder_position_bias", None)
    return d

def _build_causal_mask(B, q_len, k_len, device, dtype):
    """
    Additive causal mask of shape ``(B, 1, q_len, k_len)``.

    Entries are 0 where key index <= query index and -1e9 strictly above the
    diagonal.
    """
    # True exactly at the positions that must be blocked.
    future = torch.ones((q_len, k_len), dtype=torch.bool, device=device).triu(diagonal=1)
    additive = torch.zeros((B, 1, q_len, k_len), dtype=dtype, device=device)
    additive[:, :, future] = -1e9
    return additive

def _ensure_mask_shape(mask, q_len, k_len, device, dtype, is_self):
    """
    Return an additive attention mask of shape ``(B, 1|H, q_len, k_len)``.

    * ``mask is None`` or an unsupported rank: a default mask with batch size 1
      (causal for self-attention, all zeros for cross-attention).
    * 2-D ``(B, K)``: broadcast over the query axis, values passed through as-is.
      NOTE(review): values are not converted from a 1/0 keep-mask to additive
      form — confirm callers only pass additive masks here.
    * 4-D: cropped to ``q_len`` x ``k_len`` when large enough. For
      cross-attention a single query row (``Q == 1``) is also accepted and
      broadcast over the queries, so an encoder padding mask of shape
      ``(B, 1, 1, K)`` is kept instead of being replaced by an all-zero mask.
      Anything else falls back to the default mask with the mask's batch size.
    """
    def _default(batch):
        if is_self:
            return _build_causal_mask(batch, q_len, k_len, device, dtype)
        return torch.zeros((batch, 1, q_len, k_len), dtype=dtype, device=device)

    if mask is None:
        return _default(1)

    if mask.dim() == 2:
        B = mask.size(0)
        return mask[:, None, None, :k_len].expand(B, 1, q_len, k_len).contiguous().to(device=device, dtype=dtype)

    if mask.dim() == 4:
        B, _, Q, K = mask.shape
        # A singleton query axis is a broadcastable padding mask; only trusted
        # for cross-attention, where no causal structure is required.
        broadcast_rows = (Q == 1) and not is_self
        if K >= k_len and (Q >= q_len or broadcast_rows):
            window = mask[:, :, :q_len, :k_len]
            if window.size(2) != q_len:
                window = window.expand(-1, -1, q_len, -1)
            return window.to(device=device, dtype=dtype)
        return _default(B)

    return _default(1)

@contextmanager
def force_math_sdp():
    """
    Force math SDPA kernels so higher-order grads exist; no-op if unavailable.

    Only a failure to *enter* the backend switch is swallowed. Exceptions
    raised by the ``with`` body propagate unchanged (a second ``yield`` inside
    an ``except`` would turn them into
    ``RuntimeError: generator didn't stop after throw()``).
    """
    from contextlib import ExitStack

    with ExitStack() as stack:
        try:
            stack.enter_context(
                torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True)
            )
        except Exception:
            # Switch not available in this torch build: run with the active kernels.
            pass
        # Yield exactly once; the stack restores the backend flags on exit,
        # whether the body finished normally or raised.
        yield

def _make_attn_callable(module, argdict):
    """
    Build pure fn f(x)->hidden_states for a single T5 sublayer (self/cross).

    Args:
        module: the attention sublayer; its ``forward`` signature decides which
            keyword arguments are supplied.
        argdict: captured call arguments (name -> value) or None.

    Returns:
        A function of one tensor ``x`` that calls ``module`` with ``x`` as
        ``hidden_states`` and returns the first element of the output (or the
        output itself when it is not a tuple/list).
    """
    raw = dict(argdict) if argdict is not None else {}
    sig = inspect.signature(module.forward)
    params = set(sig.parameters.keys())
    # A sublayer is treated as cross-attention iff its forward takes key_value_states.
    is_cross = "key_value_states" in params

    # Arguments that are the same for every call of the returned function.
    const_kwargs = {}
    if is_cross:
        kv = raw.get("key_value_states", raw.get("encoder_hidden_states", None))
        const_kwargs["key_value_states"] = None if kv is None else kv.float()

    # Only set these flags when the signature actually declares them.
    if "use_cache" in params:          const_kwargs["use_cache"] = False
    if "output_attentions" in params:  const_kwargs["output_attentions"] = False
    if "return_dict" in params:        const_kwargs["return_dict"] = False

    # The mask parameter is named either "attention_mask" or "mask" (or absent).
    mask_key = "attention_mask" if "attention_mask" in params else ("mask" if "mask" in params else None)
    captured_mask = raw.get(mask_key, raw.get("mask", None)) if mask_key is not None else None
    captured_kv   = const_kwargs.get("key_value_states", None)

    def _f(x):
        x = x.float()
        q_len = x.size(1)
        # Key length: encoder length for cross-attention, otherwise the query length.
        k_len = captured_kv.size(1) if (is_cross and captured_kv is not None) else q_len
        kwargs = {"hidden_states": x}
        if is_cross:
            if captured_kv is None:  # identity if no encoder states present
                return x
            kwargs["key_value_states"] = captured_kv
        if mask_key is not None:
            # The captured mask is re-shaped (or replaced by a default) for the current lengths.
            kwargs[mask_key] = _ensure_mask_shape(captured_mask, q_len, k_len, x.device, x.dtype, is_self=not is_cross)
        if "cache_position" in params:
            kwargs["cache_position"] = torch.arange(q_len, dtype=torch.long, device=x.device)
        # query_length is only passed when the signature has no cache_position.
        if "query_length" in params and "cache_position" not in params:
            kwargs["query_length"] = q_len

        # Constant arguments win over anything set above with the same name.
        kwargs.update(const_kwargs)
        kwargs = _filter_kwargs_for_module(module, kwargs)
        # Run the sublayer with autocast disabled (x was cast to float32 above).
        with torch.cuda.amp.autocast(enabled=False):
            out = module(**kwargs)
        return out[0] if isinstance(out, (tuple, list)) else out
    return _f

def _attn_sublayer_jdev(mod, X, args, k_probes, rademacher):
    """Estimate the mean squared ``Jv - v`` for one attention sublayer, or None if no probe was usable."""
    from torch.autograd.functional import jvp as autograd_jvp

    was_training = mod.training
    mod.eval()  # the training flag is restored in ``finally``
    try:
        fn = _make_attn_callable(mod, args)
        x0 = X.detach().requires_grad_(True).float()

        # Dry run: if the plain forward already fails there is nothing to probe.
        try:
            with force_math_sdp():
                fn(x0)
        except Exception:
            return None

        total, used = 0.0, 0
        with force_math_sdp():
            for _ in range(max(1, k_probes)):
                if rademacher:
                    # Random +/-1 entries.
                    v = torch.empty_like(x0).bernoulli_(0.5).mul_(2).sub_(1)
                else:
                    v = torch.randn_like(x0)
                try:
                    _, Jv = autograd_jvp(fn, (x0,), (v,), create_graph=False, strict=True)
                except Exception:
                    # Best effort: a failing probe is skipped, not fatal.
                    continue
                if Jv is None or not torch.isfinite(Jv).all():
                    continue
                jd_vec = Jv - v
                if not torch.isfinite(jd_vec).all():
                    continue
                total += float(jd_vec.pow(2).mean().item())
                used += 1
        return total / used if used > 0 else None
    finally:
        mod.train(was_training)


def compute_decoder_attention_jdev_autograd(model, self_bufs, cross_bufs, k_probes=1, rademacher=True):
    """
    For each decoder attention sublayer ℓ, estimate:
      E_v ||(Jℓ - I)v||^2 using autograd JVP.
    Non-finite probes are skipped. Buffers are cleared.

    Args:
        model: model whose ``decoder.block[idx].layer[0]`` / ``layer[1]`` are
            the self- / cross-attention sublayers.
        self_bufs, cross_bufs: ``{idx: {"X": tensor|None, "args": dict|None}}``;
            every visited entry is reset to None.
        k_probes: probes per sublayer (at least one is always drawn).
        rademacher: use +/-1 probes when True, standard normal ones otherwise.

    Returns:
        ``(self_scores, cross_scores)``: dicts idx -> score, containing only
        the sublayers for which at least one probe succeeded.
    """
    self_scores, cross_scores = {}, {}

    # Self-attention J-dev
    for idx, buf in self_bufs.items():
        X, args = buf["X"], buf["args"]
        buf["X"], buf["args"] = None, None  # clear ASAP
        if X is None:
            continue
        mod = model.decoder.block[idx].layer[0]  # T5LayerSelfAttention
        score = _attn_sublayer_jdev(mod, X, args, k_probes, rademacher)
        if score is not None:
            self_scores[idx] = score

    # Cross-attention J-dev
    for idx, buf in cross_bufs.items():
        X, args = buf["X"], buf["args"]
        buf["X"], buf["args"] = None, None
        if X is None:
            continue
        sublayers = model.decoder.block[idx].layer
        if len(sublayers) <= 1 or sublayers[1] is None:
            continue
        mod = sublayers[1]  # T5LayerCrossAttention
        score = _attn_sublayer_jdev(mod, X, args, k_probes, rademacher)
        if score is not None:
            cross_scores[idx] = score

    return self_scores, cross_scores

# =========================
# 4) Hooks to capture ONLY decoder attention inputs
# =========================
def register_decoder_attention_hooks(model):
    """
    Attach forward pre-hooks to every decoder self- and cross-attention sublayer.

    Each hook stores the detached incoming ``hidden_states`` and the mapped
    call arguments in a per-layer buffer.

    Returns:
        ``(hooks, self_bufs, cross_bufs)``: the hook handles (self-attention
        hooks first, then cross-attention hooks) and the two buffer dicts
        ``{idx: {"X": ..., "args": ...}}``, initialised to None.
    """
    dec_blocks = model.decoder.block
    layer_count = len(dec_blocks)
    self_bufs = {i: {"X": None, "args": None} for i in range(layer_count)}
    cross_bufs = {i: {"X": None, "args": None} for i in range(layer_count)}

    def _capture_into(bufs, idx):
        # One closure per (buffer dict, layer index) pair.
        def _pre_hook(module, args, kwargs):
            if "hidden_states" in kwargs:
                X = kwargs["hidden_states"]
            else:
                X = args[0] if len(args) > 0 else None
            bufs[idx]["X"] = None if X is None else X.detach()
            bufs[idx]["args"] = _map_args_from_hook(module, args, kwargs)
        return _pre_hook

    hooks = []

    # layer[0] of every block: self-attention.
    for i, block in enumerate(dec_blocks):
        handle = block.layer[0].register_forward_pre_hook(_capture_into(self_bufs, i), with_kwargs=True)
        hooks.append(handle)

    # layer[1], when present and not None: cross-attention.
    for i, block in enumerate(dec_blocks):
        if len(block.layer) <= 1 or block.layer[1] is None:
            continue
        handle = block.layer[1].register_forward_pre_hook(_capture_into(cross_bufs, i), with_kwargs=True)
        hooks.append(handle)

    return hooks, self_bufs, cross_bufs

def remove_hooks(hooks):
    """Call ``remove()`` on every handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

# =========================
# 5) Attention-only pruning utilities
# =========================
class SkipSelfAttention(nn.Module):
    """Identity replacement for T5LayerSelfAttention (also drop position_bias)."""

    def forward(self, hidden_states, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored; the two
        # trailing slots are always None.
        return hidden_states, None, None

class SkipCrossAttention(nn.Module):
    """Identity replacement for T5LayerCrossAttention (also drop position_bias)."""

    def forward(self, hidden_states, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored; the two
        # trailing slots are always None.
        return hidden_states, None, None

def prune_decoder_attention(model, self_scores, cross_scores,
                            k_self=2, k_cross=2,
                            protect_first=True, protect_last=False, verbose=True):
    """
    Replace the lowest-J-dev decoder attention sublayers by identity modules.

    Args:
        model: model exposing ``decoder.block``; modified in place.
        self_scores / cross_scores: layer index -> J-dev score.
        k_self / k_cross: how many self- / cross-attention sublayers to replace.
        protect_first / protect_last: exclude layer 0 / the last layer.
        verbose: print the pruned layer indices.

    Returns:
        ``(pruned_self, pruned_cross)``: lists of layer indices, lowest score first.
    """
    blocks = model.decoder.block
    last_idx = len(blocks) - 1

    def _lowest_first(scores):
        # Drop protected layers, then order by ascending score.
        kept = [
            (i, s) for i, s in scores.items()
            if not (protect_first and i == 0) and not (protect_last and i == last_idx)
        ]
        return sorted(kept, key=lambda pair: pair[1])

    self_ranked = _lowest_first(self_scores)
    cross_ranked = _lowest_first(cross_scores)

    pruned_self = []
    for i, _score in self_ranked[:max(0, k_self)]:
        blocks[i].layer[0] = SkipSelfAttention()
        pruned_self.append(i)

    pruned_cross = []
    for i, _score in cross_ranked[:max(0, k_cross)]:
        sublayers = blocks[i].layer
        # Only blocks that actually carry a cross-attention sublayer are touched.
        if len(sublayers) > 1 and sublayers[1] is not None:
            sublayers[1] = SkipCrossAttention()
            pruned_cross.append(i)

    if verbose:
        print(f"Pruned decoder SELF-attention at layers:  {pruned_self}")
        print(f"Pruned decoder CROSS-attention at layers: {pruned_cross}")

    return pruned_self, pruned_cross

# =========================
# 6) Training / Eval helpers
# =========================
# Maps the first word of a decoded answer onto one of the three NLI labels.
CANON = {
    "entailment": "entailment",
    "entailed": "entailment",
    "neutral": "neutral",
    "contradiction": "contradiction",
    "contradict": "contradiction",
    "contradictory": "contradiction",
    "contradicted": "contradiction",
}


def canonicalize_label(s: str):
    """
    Reduce a free-text label to its canonical form.

    Only the first whitespace-separated word of the stripped, lower-cased text
    is considered. Words missing from ``CANON`` are returned unchanged;
    ``None`` or blank input gives ``""``.
    """
    cleaned = (s or "").strip().lower()
    if not cleaned:
        return CANON.get(cleaned, cleaned)
    first_word, *_rest = cleaned.split()
    return CANON.get(first_word, first_word)

def compute_accuracy(preds, refs):
    """
    Share of predictions that match their reference after canonicalization.

    Predictions and references are paired with ``zip``; the result is divided
    by ``len(preds)`` and is 0.0 for an empty ``preds``.
    """
    matches = [
        canonicalize_label(guess) == canonicalize_label(gold)
        for guess, gold in zip(preds, refs)
    ]
    n_preds = len(preds)
    return matches.count(True) / n_preds if n_preds > 0 else 0.0

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device):
    """
    Generate answers for every batch of ``dl`` and return their accuracy.

    Side effects: ``model.config.use_cache`` is set to False and the model is
    left in eval mode. Texts are stripped and lower-cased before scoring with
    ``compute_accuracy``.
    """
    model.config.use_cache = False
    model.eval()
    preds, refs = [], []

    def _clean(texts):
        return [text.strip().lower() for text in texts]

    for batch in dl:
        # decode gold labels to text: -100 (ignored by the loss) is swapped back
        # to the pad id first so the ids are decodable.
        gold_ids = batch["labels"].clone()
        gold_ids[gold_ids == -100] = tokenizer.pad_token_id
        gold_texts = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)

        generated = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_new_tokens=4,
            use_cache=False,
        )
        generated_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)

        preds += _clean(generated_texts)
        refs += _clean(gold_texts)
    return compute_accuracy(preds, refs)

def build_optimizer(m):
    """Create an Adafactor optimizer over all parameters of ``m``."""
    adafactor_options = dict(relative_step=True, scale_parameter=True, warmup_init=True)
    return Adafactor(m.parameters(), **adafactor_options)

# =========================
# 7) Train → collect Autograd-JVP (attention) → prune → FT
# =========================
def full_finetuning_collect_attn_jdev(model, train_loader, dev_loader, device, tokenizer,
                                      jvp_k=1, jvp_every=1, epochs=4):
    """
    Stage 1: full fine-tuning while collecting autograd-JVP J-dev scores for
    the decoder self- and cross-attention sublayers.

    Args:
        model: model to train (updated in place; ``config.use_cache`` is set to False).
        train_loader / dev_loader: batches with "input_ids", "attention_mask", "labels".
        device: device the batch tensors are moved to.
        tokenizer: passed through to ``evaluate_model``.
        jvp_k: number of probes per measurement.
        jvp_every: measure every this many steps (values < 1 mean every step).
        epochs: number of passes over ``train_loader``.

    Returns:
        ``(model, self_scores, cross_scores)``; each score dict maps layer
        index -> mean of the finite measurements accumulated over the run.
    """
    print("=== Stage 1: Full FT & Autograd-JVP (Decoder Attention only) ===")
    model.config.use_cache = False
    opt = build_optimizer(model)

    hooks, self_bufs, cross_bufs = register_decoder_attention_hooks(model)
    self_sum, self_cnt   = defaultdict(float), defaultdict(int)
    cross_sum, cross_cnt = defaultdict(float), defaultdict(int)

    def _running_means(totals, counts):
        return {i: totals[i] / counts[i] for i in totals if counts[i] > 0}

    probe_period = max(1, jvp_every)
    step = 0
    try:
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                step += 1
                opt.zero_grad()

                out = model(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                    labels=batch["labels"].to(device),
                )
                loss = out.loss
                if not torch.isfinite(loss):
                    print("Loss is NaN/Inf — skipping batch.")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()

                if step % probe_period == 0:
                    self_scores, cross_scores = compute_decoder_attention_jdev_autograd(
                        model, self_bufs, cross_bufs, k_probes=jvp_k, rademacher=True
                    )
                    # Only finite measurements enter the running sums.
                    for i, v in self_scores.items():
                        if math.isfinite(v):
                            self_sum[i] += v
                            self_cnt[i] += 1
                    for i, v in cross_scores.items():
                        if math.isfinite(v):
                            cross_sum[i] += v
                            cross_cnt[i] += 1

            epoch_self = _running_means(self_sum, self_cnt)
            epoch_cross = _running_means(cross_sum, cross_cnt)
            print(f"[Epoch {epoch+1}] Decoder SELF-attn J-dev:  {epoch_self}")
            print(f"[Epoch {epoch+1}] Decoder CROSS-attn J-dev: {epoch_cross}")
            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
    finally:
        # Detach the capture hooks even when training is interrupted or fails;
        # otherwise they stay on the (shared, global) model and a re-run of the
        # cell would stack a second set on top of them.
        remove_hooks(hooks)

    final_self = _running_means(self_sum, self_cnt)
    final_cross = _running_means(cross_sum, cross_cnt)
    return model, final_self, final_cross

def prune_attention_and_finetune(model, train_loader, dev_loader, device, tokenizer,
                                 self_scores, cross_scores, k_self=num, k_cross=0, epochs=5):
    """
    Stage 2: replace the lowest-J-dev attention sublayers, then fine-tune.

    Layer 0 is always protected from pruning; the last layer is not.
    Dev accuracy is printed after each epoch.

    Returns:
        ``(model, pruned_self, pruned_cross)`` with the pruned layer indices.
    """
    print("=== Stage 2: Prune LOW-Jdev Attention (self/cross) & Fine-tuning ===")
    pruned = prune_decoder_attention(
        model,
        self_scores,
        cross_scores,
        k_self=k_self,
        k_cross=k_cross,
        protect_first=True,
        protect_last=False,
        verbose=True,
    )
    _self, _cross = pruned

    model.config.use_cache = False
    optimizer = build_optimizer(model)
    batch_keys = ("input_ids", "attention_mask", "labels")

    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs = {key: batch[key].to(device) for key in batch_keys}
            loss = model(**inputs).loss

            if torch.isfinite(loss):
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                # Nothing is back-propagated for a non-finite loss.
                print("Loss is NaN/Inf — skipping batch.")

        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] ANLI1 Acc: {acc:.4f}")
    return model, _self, _cross

# --- 8. Entrypoint ---
def main():
    """Run stage 1 (fine-tune + attention J-dev collection), then stage 2 (prune + fine-tune)."""
    global model
    shared = (train_loader, dev_loader, device, tokenizer)

    stage1 = full_finetuning_collect_attn_jdev(
        model, *shared,
        jvp_k=1, jvp_every=1, epochs=6
    )
    model, self_scores, cross_scores = stage1

    stage2 = prune_attention_and_finetune(
        model, *shared, self_scores, cross_scores,
        k_self=num, k_cross=0, epochs=5
    )
    model, pruned_self, pruned_cross = stage2


if __name__ == "__main__":
    main()


In [None]:
# ===========================
# 0. Google Drive Mount
# ===========================
# Mounts Google Drive at /content/drive; the dataset paths below point into it.
# This import is unguarded, so the cell fails where google.colab is not installed.
from google.colab import drive
drive.mount('/content/drive')

# ===========================
# 1. Imports and Setup
# ===========================
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq
)
from collections import defaultdict
import warnings, math, inspect, random, numpy as np

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Repro
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and torch (plus all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


set_seed(1234)

# ===========================
# 2. Load SVAMP Dataset
# ===========================
# Split name -> JSON file on the mounted Google Drive (only "train" and "test" exist here).
data_files = {
    "train": "/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json",
    "test":  "/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json"
}
# Loaded with the generic "json" builder; the result is indexed by the split names above.
dataset = load_dataset("json", data_files=data_files)

# ===========================
# 3. Preprocessing
# ===========================
def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8):
    """
    Tokenize a batch of SVAMP examples for seq2seq training.

    Args:
        batch: dict of columns with keys "input" (problem text) and "label".
        tokenizer: callable tokenizer exposing ``pad_token_id``.
        max_input_length: fixed length the inputs are padded/truncated to.
        max_target_length: fixed length the targets are padded/truncated to.

    Returns:
        The tokenizer output for the inputs with an extra "labels" entry.
        Padding positions inside "labels" are set to -100 so the loss ignores
        them; real tokens are left untouched.
    """
    model_inputs = tokenizer(
        batch["input"], padding="max_length", truncation=True, max_length=max_input_length
    )
    # Labels are trained as text, whatever their original type.
    targets = [str(x) for x in batch["label"]]
    target_encodings = tokenizer(
        targets, padding="max_length", truncation=True, max_length=max_target_length
    )

    # The targets were padded to a fixed length above, so they contain pad ids.
    # Replace those with -100 (the value the loss skips); otherwise every
    # padding position is trained on as if it were a real target token.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in target_encodings["input_ids"]
    ]
    return model_inputs

tokenizer = T5TokenizerFast.from_pretrained("t5-base")
# Tokenize both splits in batches and drop the raw columns.
# Note that "dev" is built from the "test" split: no separate validation file is loaded.
train = dataset["train"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset["train"].column_names)
dev   = dataset["test"].map(lambda ex: preprocess_svamp(ex, tokenizer),  batched=True, remove_columns=dataset["test"].column_names)

# Collator without a model; label_pad_token_id is left at its default.
# NOTE(review): the collator's label fill value applies to padding it adds itself; it is
# not documented to rewrite pad ids already stored in the tokenized labels — confirm
# against the DataCollatorForSeq2Seq documentation.
collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=8, shuffle=True,  collate_fn=collator)
dev_loader   = DataLoader(dev,   batch_size=8, shuffle=False, collate_fn=collator)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===========================
# 4) FD-JVP helpers (NaN-safe) — capture + deterministic replay for T5Block
# ===========================

def _filter_kwargs_for_module(module, kwargs):
    """Drop every entry of ``kwargs`` whose key is not a parameter name of ``module.forward``."""
    parameters = inspect.signature(module.forward).parameters
    accepted = {name for name in parameters if name != "self"}
    return {key: kwargs[key] for key in kwargs if key in accepted}

def _map_args_from_hook(module, args, kwargs):
    """
    Normalize to kwargs for re-calling module.forward.

    * Mixed call (positional + keyword): positionals are bound to the leading
      parameter names of ``module.forward`` and the keywords are laid on top.
      Previously the positionals were dropped whenever any keyword was present,
      which loses e.g. masks or encoder states that a caller passes positionally.
    * Purely positional call: every parameter name is filled, using the
      declared default for parameters that were not supplied.

    ``hidden_states`` is always removed because the caller sets it explicitly.
    """
    sig = inspect.signature(module.forward)
    if kwargs:
        # Only parameters that can actually receive a positional value.
        bindable = [
            name for name, p in sig.parameters.items()
            if name != "self" and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        d = dict(zip(bindable, args))
        d.update(kwargs)
    else:
        names = [n for n in sig.parameters.keys() if n != "self"]
        d = {}
        for i, name in enumerate(names):
            d[name] = args[i] if i < len(args) else sig.parameters[name].default
    d.pop("hidden_states", None)  # we set explicitly
    return d

def _build_causal_mask(B, q_len, k_len, device, dtype):
    """
    Additive causal mask of shape ``(B, 1, q_len, k_len)``.

    0 for allowed positions (key index <= query index), large negative for
    disallowed ones (stable vs -inf).
    """
    # True exactly at the disallowed positions, i.e. strictly above the diagonal.
    blocked = torch.ones((q_len, k_len), dtype=torch.bool, device=device).triu(diagonal=1)
    additive = torch.zeros((B, 1, q_len, k_len), dtype=dtype, device=device)
    additive[:, :, blocked] = -1e9
    return additive

def _build_zero_mask(B, q_len, k_len, device, dtype):
    """All-zero additive mask of shape ``(B, 1, q_len, k_len)``, i.e. nothing is masked."""
    shape = (B, 1, q_len, k_len)
    return torch.zeros(shape, device=device, dtype=dtype)

def _ensure_mask_shape(mask, q_len, k_len, device, dtype, is_self, B_fallback=1):
    """
    Produce a valid (B,1,q_len,k_len) mask; rebuild default if missing/invalid.
    Self-attn -> causal mask; Cross-attn -> zero mask.

    * ``mask is None`` or an unsupported rank: default mask with batch size
      ``B_fallback``.
    * 2-D ``(B, K)``: broadcast over the query axis, values passed through as-is.
      NOTE(review): values are not converted from a 1/0 keep-mask to additive
      form — confirm callers only pass additive masks here.
    * 4-D: cropped to ``q_len`` x ``k_len`` when large enough. For
      cross-attention a single query row (``Q == 1``) is also accepted and
      broadcast over the queries, so an encoder padding mask of shape
      ``(B, 1, 1, K)`` is kept instead of being replaced by an all-zero mask.
      Anything else falls back to the default mask with the mask's batch size.
    """
    def _default(batch):
        # Looked up lazily so only the builder that is actually needed is resolved.
        builder = _build_causal_mask if is_self else _build_zero_mask
        return builder(batch, q_len, k_len, device, dtype)

    if mask is None:
        return _default(B_fallback)

    if mask.dim() == 2:  # (B, K)
        B = mask.size(0)
        return mask[:, None, None, :k_len].expand(B, 1, q_len, k_len).contiguous().to(device=device, dtype=dtype)

    if mask.dim() == 4:
        B, _, Q, K = mask.shape
        # A singleton query axis is a broadcastable padding mask; only trusted
        # for cross-attention, where no causal structure is required.
        broadcast_rows = (Q == 1) and not is_self
        if K >= k_len and (Q >= q_len or broadcast_rows):
            window = mask[:, :, :q_len, :k_len]
            if window.size(2) != q_len:
                window = window.expand(-1, -1, q_len, -1)
            return window.to(device=device, dtype=dtype)
        return _default(B)

    return _default(B_fallback)

@torch.no_grad()
def _t5block_forward_only(block, X, argdict):
    """
    Deterministic single T5Block forward in fp32 (autocast disabled).
    Provides safe masks & cache_position/query_length when missing.
    Returns only hidden_states tensor (or None if non-finite).

    Args:
        block: the module to call; it is put in eval mode for the call and its
            previous training flag is restored afterwards.
        X: incoming hidden states, or None (then None is returned immediately).
        argdict: captured call arguments (name -> value) or None.

    Only the arguments assembled below are forwarded; other captured entries
    (for example a captured ``position_bias``) are not passed on.
    """
    if X is None:
        return None
    was_training = block.training
    block.eval()
    Xf = X.float()
    try:
        raw = dict(argdict) if argdict is not None else {}
        kwargs = {"hidden_states": Xf}

        # assumes X is (batch, seq, ...) — dim 0 is used as batch size, dim 1 as query length
        q_len = Xf.size(1)
        B = Xf.size(0)

        # Self-attention mask at block level
        attn_mask = raw.get("attention_mask", raw.get("mask", None))
        kwargs["attention_mask"] = _ensure_mask_shape(
            attn_mask, q_len, q_len, Xf.device, Xf.dtype, is_self=True, B_fallback=B
        )

        # Cross-attention (if encoder states present)
        enc_states = raw.get("encoder_hidden_states", None)
        if enc_states is not None:
            enc_states = enc_states.float()
            kwargs["encoder_hidden_states"] = enc_states
            # Key length of the cross-attention mask is the encoder sequence length.
            k_len = enc_states.size(1)
            enc_mask = raw.get("encoder_attention_mask", None)
            kwargs["encoder_attention_mask"] = _ensure_mask_shape(
                enc_mask, q_len, k_len, Xf.device, Xf.dtype, is_self=False, B_fallback=B
            )

        # Optional signature args (version-dependent)
        params = set(inspect.signature(block.forward).parameters.keys())
        if "cache_position" in params:
            # Reuse the captured positions when present, else 0..q_len-1.
            cp = raw.get("cache_position", None)
            if cp is None:
                cp = torch.arange(q_len, dtype=torch.long, device=Xf.device)
            kwargs["cache_position"] = cp
        if "query_length" in params:
            kwargs["query_length"] = q_len
        if "use_cache" in params:
            kwargs["use_cache"] = False
        if "output_attentions" in params:
            kwargs["output_attentions"] = False
        if "return_dict" in params:
            kwargs["return_dict"] = False

        # Drop anything this block's forward does not declare.
        kwargs = _filter_kwargs_for_module(block, kwargs)

        with torch.cuda.amp.autocast(enabled=False):
            out = block(**kwargs)

        # First element of a tuple/list output is taken as the hidden states.
        hs = out[0] if isinstance(out, (tuple, list)) else out
        return hs if hs is None or torch.isfinite(hs).all() else None
    finally:
        block.train(was_training)

def register_decoder_block_jvp_hooks(model):
    """
    Attach a forward pre-hook to every module in ``model.decoder.block``.

    Each hook records, for its block index, the detached incoming
    ``hidden_states`` (taken from the keyword arguments, else from the first
    positional argument) under ``"X"`` and the normalized remaining call
    arguments (via ``_map_args_from_hook``) under ``"args"``.

    Returns:
        (handles, buffers): the list of hook handles, and a dict
        ``{block_index: {"X": ..., "args": ...}}`` initialised with ``None``.
    """
    layers = model.decoder.block
    buffers = {}
    for position in range(len(layers)):
        buffers[position] = {"X": None, "args": None}

    def _make_recorder(position):
        # Factory so that each hook closes over its own block index.
        def _record(module, args, kwargs):
            if "hidden_states" in kwargs:
                incoming = kwargs["hidden_states"]
            elif len(args) > 0:
                incoming = args[0]
            else:
                incoming = None
            slot = buffers[position]
            slot["X"] = incoming.detach() if incoming is not None else None
            slot["args"] = _map_args_from_hook(module, args, kwargs)
        return _record

    handles = []
    for position, layer in enumerate(layers):
        handle = layer.register_forward_pre_hook(_make_recorder(position), with_kwargs=True)
        handles.append(handle)
    return handles, buffers

def remove_hooks(hooks):
    """Call ``remove()`` on every hook handle in ``hooks``, in iteration order."""
    for handle in hooks:
        handle.remove()

@torch.no_grad()
def compute_decoder_block_jdev(model, dec_bufs, eps=1e-3, k_probes=1):
    """
    Finite-difference estimate of E_v ||(J - I) v||^2 for each decoder block.

    For every buffer that holds a captured input X, the block is replayed via
    ``_t5block_forward_only`` at X and at X + eps * v for Gaussian probes v:
      (J - I) v ≈ (f(X + eps v) - f(X)) / eps - v
    The per-probe score is the mean of the squared deviation; the block score
    is the average over the probes that produced finite results.

    Args:
        model: object exposing ``model.decoder.block[idx]``.
        dec_bufs: dict ``{idx: {"X": tensor|None, "args": dict|None}}``.
        eps: finite-difference step size.
        k_probes: number of random probes (at least one is always drawn).

    Returns:
        dict ``{idx: score}`` containing only blocks with at least one usable
        probe. Buffers whose ``"X"`` was set are reset to ``None``.
    """
    results = {}
    for layer_idx, slot in dec_bufs.items():
        captured, call_args = slot["X"], slot["args"]
        if captured is None:
            # Nothing recorded for this block; leave the buffer untouched.
            continue

        layer = model.decoder.block[layer_idx]
        captured = captured.float()

        baseline = _t5block_forward_only(layer, captured, call_args)
        if baseline is not None:
            total, n_ok = 0.0, 0
            for _ in range(max(1, k_probes)):
                direction = torch.randn_like(captured)
                perturbed = _t5block_forward_only(layer, captured + eps * direction, call_args)
                if perturbed is None:
                    continue
                deviation = (perturbed - baseline) / eps - direction
                if torch.isfinite(deviation).all():
                    total += float(deviation.pow(2).mean().item())
                    n_ok += 1

            if n_ok > 0:
                results[layer_idx] = total / n_ok

        # Release the captured tensors whether or not a score was produced.
        slot["X"], slot["args"] = None, None
    return results

# ===========================
# 5) Pruning Utilities (decoder blocks)
# ===========================
class SkipBlock(nn.Module):
    """
    Parameter-free stand-in for a pruned T5 decoder block.

    ``forward`` performs no computation: it returns the 6-tuple
    ``(hidden_states, None, None, None, position_bias,
    encoder_decoder_position_bias)`` with the two bias arguments passed through
    unchanged. All other arguments are accepted and ignored.

    NOTE(review): the slots the surrounding T5 stack reads for the shared
    position biases depend on the installed transformers version; confirm that
    the biases placed at indices 4 and 5 are actually picked up by later blocks.
    """
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=False,
        cache_position=None,
        **kwargs,
    ):
        # Identity on the hidden states; no cache entry and no attention weights.
        empty_slots = (None, None, None)
        biases = (position_bias, encoder_decoder_position_bias)
        return (hidden_states,) + empty_slots + biases

def prune_decoder_blocks_by_jdev(blocks, jdev_scores, num_prune=4, protect_first=True, protect_last=False, verbose=True):
    """
    Replace the lowest-scoring decoder blocks with ``SkipBlock`` instances.

    Args:
        blocks: indexable, mutable container of blocks (modified in place).
        jdev_scores: dict ``{block_index: score}``; lower means closer to identity.
        num_prune: maximum number of blocks to replace (negative acts as 0).
        protect_first: never prune index 0.
        protect_last: never prune index ``len(blocks) - 1``.
        verbose: print which indices were pruned (or that pruning was skipped).

    Returns:
        List of pruned indices, ordered from lowest score upwards.
    """
    if not jdev_scores:
        if verbose:
            print("No J-dev scores available; skipping pruning.")
        return []

    protected = set()
    if protect_first:
        protected.add(0)
    if protect_last:
        protected.add(len(blocks) - 1)

    candidates = [idx for idx in jdev_scores if idx not in protected]
    # sorted() is stable, so ties keep the score dict's insertion order.
    ranked = sorted(candidates, key=lambda idx: jdev_scores[idx])
    prune_idxs = ranked[:max(0, num_prune)]

    for idx in prune_idxs:
        blocks[idx] = SkipBlock()
    if verbose:
        print(f"Pruned decoder blocks (lowest J-dev): {prune_idxs}")
    return prune_idxs

# ===========================
# 6) Eval Helper
# ===========================
def compute_accuracy(preds, refs):
    """
    Exact-match accuracy between predictions and references.

    Items are compared pairwise after ``str()``, whitespace stripping and
    lower-casing. The denominator is ``len(preds)``; an empty ``preds`` yields 0.0.
    """
    def _canon(value):
        return str(value).strip().lower()

    hits = sum(1 for guess, gold in zip(preds, refs) if _canon(guess) == _canon(gold))
    if len(preds) > 0:
        return hits / len(preds)
    return 0.0

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device):
    """
    Generate predictions for every batch of ``dl`` and return their exact-match
    accuracy (via ``compute_accuracy``) against the decoded gold labels.

    The model is put in eval mode and is not switched back. Generation uses
    ``max_new_tokens=8`` and ``use_cache=False``.
    """
    model.eval()
    hypotheses, golds = [], []
    for batch in dl:
        source_ids = batch["input_ids"].to(device)
        source_mask = batch["attention_mask"].to(device)

        # -100 marks ignored label positions; swap in the pad id so the
        # sequence can be decoded back to text.
        gold_ids = batch["labels"].clone()
        gold_ids[gold_ids == -100] = tokenizer.pad_token_id
        gold_texts = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)

        generated = model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            max_new_tokens=8,
            use_cache=False,  # stay consistent with SkipBlock semantics
        )
        hyp_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)

        for text in hyp_texts:
            hypotheses.append(text.strip().lower())
        for text in gold_texts:
            golds.append(text.strip().lower())
    return compute_accuracy(hypotheses, golds)

# ===========================
# 7) Training + FD-JVP + Pruning
# ===========================
def full_finetuning_with_jdev(train_loader, dev_loader, device, tokenizer,
                              jvp_eps=1e-3, jvp_k=1, jvp_every=1, epochs=6):
    """
    Stage 1: fine-tune a fresh ``t5-base`` while accumulating finite-difference
    J-dev scores for every decoder block.

    Args:
        train_loader / dev_loader: batches with ``input_ids``,
            ``attention_mask`` and ``labels``.
        device: device the model and batches are moved to.
        tokenizer: used by ``evaluate_model`` to decode predictions and labels.
        jvp_eps: finite-difference step passed to ``compute_decoder_block_jdev``.
        jvp_k: number of probes per block and step.
        jvp_every: collect scores every N optimizer steps (values < 1 act as 1).
        epochs: number of passes over ``train_loader``.

    Returns:
        (model, scores) where ``scores`` maps block index to the mean of all
        finite J-dev values collected during training.
    """
    print("=== Stage 1: Full FT + FD-JVP (Decoder Blocks) ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    # Avoid kv-cache complexity in replay and in SkipBlock
    model.config.use_cache = False

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

    dec_hooks, dec_bufs = register_decoder_block_jvp_hooks(model)
    totals, counts = defaultdict(float), defaultdict(int)

    def _running_means():
        # Mean J-dev per block over everything accumulated so far.
        return {i: totals[i] / counts[i] for i in totals if counts[i] > 0}

    step = 0
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            step += 1
            opt.zero_grad()

            # Full precision forward pass (no AMP) for the FD-JVP collection phase
            inputs = {
                "input_ids": batch["input_ids"].to(device),
                "attention_mask": batch["attention_mask"].to(device),
                "labels": batch["labels"].to(device),
            }
            loss = model(**inputs).loss
            if not torch.isfinite(loss):
                print("Loss is NaN/Inf — skipping batch.")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            # Only collect on every N-th step
            if step % max(1, jvp_every) != 0:
                continue
            step_scores = compute_decoder_block_jdev(model, dec_bufs, eps=jvp_eps, k_probes=jvp_k)
            for block_idx, value in step_scores.items():
                if math.isfinite(value):
                    totals[block_idx] += value
                    counts[block_idx] += 1

        print(f"[Epoch {epoch+1}] Decoder Block J-dev (mean): {_running_means()}")
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")

    remove_hooks(dec_hooks)
    return model, _running_means()

def prune_and_finetuning(model, train_loader, dev_loader, device, tokenizer,
                         dec_jdev_scores, num_prune=num, epochs=5):
    """
    Stage 2: replace the lowest-J-dev decoder blocks with ``SkipBlock`` and
    continue fine-tuning the pruned model.

    Args:
        model: model from stage 1; its ``decoder.block`` is modified in place.
        train_loader / dev_loader: batches with ``input_ids``,
            ``attention_mask`` and ``labels``.
        device: device batches are moved to.
        tokenizer: used by ``evaluate_model``.
        dec_jdev_scores: dict ``{block_index: score}`` from stage 1.
        num_prune: number of blocks to prune (block 0 is always protected).
        epochs: number of fine-tuning epochs after pruning.

    Returns:
        The pruned, fine-tuned model (same object as ``model``).
    """
    print("=== Stage 2: Prune LOW-Jdev Decoder Blocks & Fine-tuning ===")
    prune_decoder_blocks_by_jdev(
        model.decoder.block,
        dec_jdev_scores,
        num_prune=num_prune,
        protect_first=True,
        protect_last=False,
        verbose=True,
    )

    # Optimizer is built after pruning, so it only sees the remaining parameters.
    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)

    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            inputs = {
                "input_ids": batch["input_ids"].to(device),
                "attention_mask": batch["attention_mask"].to(device),
                "labels": batch["labels"].to(device),
            }
            loss = model(**inputs).loss
            if torch.isfinite(loss):
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            else:
                print("Loss is NaN/Inf — skipping batch.")

        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] SVAMP Acc: {acc:.4f}")
    return model

# ===========================
# 8. Main Entrypoint
# ===========================
def main():
    """
    Run the two-stage pipeline: full fine-tuning with FD-JVP score collection,
    then pruning of ``num`` decoder blocks followed by further fine-tuning.

    Uses the module-level ``train_loader``, ``dev_loader``, ``device``,
    ``tokenizer`` and ``num``. Returns nothing.
    """
    shared = (train_loader, dev_loader, device, tokenizer)
    tuned_model, block_scores = full_finetuning_with_jdev(
        *shared, jvp_eps=1e-3, jvp_k=1, jvp_every=1, epochs=6
    )
    prune_and_finetuning(
        tuned_model, *shared, block_scores, num_prune=num, epochs=5
    )

# Run the full pipeline only when this module is executed as the main program
# (this is also the case when the code runs as a notebook cell).
if __name__ == "__main__":
    main()


In [None]:
# Autograd JVP, SVAMP


# ===========================
# 0) (Optional) Google Drive Mount (no-op off Colab)
# ===========================
# Best-effort mount of Google Drive: any failure (for example the
# `google.colab` package not being importable) is deliberately ignored so the
# cell can still run. The dataset paths used further down point into
# /content/drive, so they only resolve if this mount succeeded.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    pass

# ===========================
# 1) Imports and Setup
# ===========================
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq
)
from collections import defaultdict
import warnings, math, inspect, random, numpy as np
from contextlib import nullcontext

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Repro
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (and every CUDA device, if CUDA is available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# Fix all RNG seeds once, before any model or data loader is created.
set_seed(1234)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===========================
# 2) Load SVAMP Dataset
# ===========================
# JSON files for the two splits, read from the mounted Google Drive.
# The "test" split is the one used as the dev/evaluation set below.
data_files = {
    "train": "/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json",
    "test":  "/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json"
}
dataset = load_dataset("json", data_files=data_files)

# ===========================
# 3) Preprocessing
# ===========================
def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8):
    """
    Tokenize a batch of SVAMP examples for seq2seq training.

    Args:
        batch: mapping with an ``"input"`` list (question text) and a
            ``"label"`` list (answers; each is converted with ``str()``).
        tokenizer: callable tokenizer exposing ``pad_token_id``.
        max_input_length: fixed length the inputs are padded/truncated to.
        max_target_length: fixed length the targets are padded/truncated to.

    Returns:
        The tokenizer output for the inputs with an added ``"labels"`` entry:
        the target token ids, where every padding position is replaced by
        -100 so that it is ignored by the loss.
    """
    model_inputs = tokenizer(
        batch["input"], padding="max_length", truncation=True, max_length=max_input_length
    )
    targets = [str(x) for x in batch["label"]]
    target_encodings = tokenizer(
        targets, padding="max_length", truncation=True, max_length=max_target_length
    )

    # The targets are already padded to a fixed length here, so a collator's
    # label_pad_token_id never touches these pad positions. Mask them
    # explicitly, otherwise the model is trained to predict pad tokens.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in sequence]
        for sequence in target_encodings["input_ids"]
    ]
    return model_inputs

# Tokenize both splits up front; the raw columns are dropped so each dataset
# only holds what preprocess_svamp returns.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
train = dataset["train"].map(lambda ex: preprocess_svamp(ex, tokenizer),
                             batched=True, remove_columns=dataset["train"].column_names)
dev   = dataset["test"].map(lambda ex: preprocess_svamp(ex, tokenizer),
                             batched=True, remove_columns=dataset["test"].column_names)

# Collator: label_pad_token_id=-100 is the fill value used for any label
# padding the collator adds itself. Padding already present in the tokenized
# targets is whatever preprocess_svamp stored. No model is passed to the collator.
collator = DataCollatorForSeq2Seq(tokenizer, model=None, label_pad_token_id=-100,
                                  padding="max_length", max_length=128)
# Only the training loader is shuffled; both use batches of 8.
train_loader = DataLoader(train, batch_size=8, shuffle=True,  collate_fn=collator)
dev_loader   = DataLoader(dev,   batch_size=8, shuffle=False, collate_fn=collator)

# ===========================
# 4) Autograd-JVP helpers — capture + deterministic replay for T5Block
# ===========================

def _filter_kwargs_for_module(module, kwargs):
    """Return the subset of ``kwargs`` whose keys are parameter names of ``module.forward``."""
    accepted = set(inspect.signature(module.forward).parameters)
    accepted.discard("self")
    kept = {}
    for key, value in kwargs.items():
        if key in accepted:
            kept[key] = value
    return kept

def _map_args_from_hook(module, args, kwargs):
    """
    Normalize a captured ``forward`` call to a plain kwargs dict for replay.

    Positional arguments are matched to the parameter names of
    ``module.forward`` in declaration order, parameters that were not passed
    get their declared default (parameters without a default are omitted),
    and explicit keyword arguments are applied on top. This also covers calls
    that mix positional and keyword arguments, where the positional values
    would otherwise be lost.

    ``hidden_states`` (set explicitly by the replay code) and both cached
    position biases are always removed from the result.

    Args:
        module: module whose ``forward`` signature is used for the mapping.
        args: positional arguments of the captured call.
        kwargs: keyword arguments of the captured call (may be empty or None).

    Returns:
        dict mapping parameter name to captured (or default) value.
    """
    sig = inspect.signature(module.forward)
    names = [n for n in sig.parameters.keys() if n != "self"]

    d = {}
    for i, name in enumerate(names):
        if i < len(args):
            d[name] = args[i]
            continue
        default = sig.parameters[name].default
        if default is not inspect.Parameter.empty:
            d[name] = default

    # Keyword arguments are authoritative over defaults.
    if kwargs:
        d.update(kwargs)

    d.pop("hidden_states", None)                  # we set explicitly
    d.pop("position_bias", None)                  # never reuse cached biases
    d.pop("encoder_decoder_position_bias", None)  # never reuse cached biases
    return d

def _build_causal_mask(B, q_len, k_len, device, dtype):
    """
    Build an additive causal mask of shape (B, 1, q_len, k_len).

    Entries strictly above the main diagonal (key index > query index) are
    -1e9, all others 0. A large finite negative is used instead of -inf to
    avoid NaNs in softmax.
    """
    future = torch.triu(
        torch.ones((q_len, k_len), dtype=torch.bool, device=device), diagonal=1
    )
    base = torch.zeros((B, 1, q_len, k_len), dtype=dtype, device=device)
    # `future` broadcasts over the batch and head dimensions.
    return base.masked_fill(future, -1e9)

def _build_zero_mask(B, q_len, k_len, device, dtype):
    """Build an all-zero (i.e. nothing masked) additive mask of shape (B, 1, q_len, k_len)."""
    shape = (B, 1, q_len, k_len)
    return torch.zeros(shape, device=device, dtype=dtype)

def _ensure_mask_shape(mask, q_len, k_len, device, dtype, is_self, B_fallback=1):
    """
    Produce a mask with trailing dimensions (q_len, k_len) for block replay.

    Behaviour by input:
      * ``None`` or an unsupported rank: a default mask of shape
        (B_fallback, 1, q_len, k_len) — causal for self-attention, all zeros
        for cross-attention.
      * 2D ``(B, K)``: values are truncated to ``k_len`` keys and repeated for
        every query position, giving (B, 1, q_len, k_len).
      * 4D ``(B, H, Q, K)`` with ``K >= k_len``: cropped to (q_len, k_len) when
        ``Q >= q_len``; when ``Q == 1`` (a mask meant to broadcast over query
        positions, e.g. a key-padding mask) it is expanded to ``q_len`` rows
        instead of being discarded.
      * Any other 4D shape: the default mask with the captured batch size.

    The result is moved to ``device`` and cast to ``dtype``.

    NOTE(review): the 2D branch copies the mask values unchanged. If callers
    ever pass a binary keep-mask (1 = attend) rather than an additive mask,
    it would need converting first — confirm against the call sites.
    """
    def _default(batch_size):
        # Resolved lazily so the builders are only needed when a default is required.
        build = _build_causal_mask if is_self else _build_zero_mask
        return build(batch_size, q_len, k_len, device, dtype)

    if mask is None:
        return _default(B_fallback)

    if mask.dim() == 2:  # (B, K)
        B = mask.size(0)
        return mask[:, None, None, :k_len].expand(B, 1, q_len, k_len).contiguous().to(device=device, dtype=dtype)

    if mask.dim() == 4:
        B, H, Q, K = mask.shape
        if K >= k_len:
            if Q >= q_len:
                return mask[:, :, :q_len, :k_len].to(device=device, dtype=dtype)
            if Q == 1:
                # Broadcastable over queries: keep the per-key values for every row.
                return mask[:, :, :, :k_len].expand(B, H, q_len, k_len).to(device=device, dtype=dtype)
        return _default(B)

    return _default(B_fallback)

def _sdp_math_ctx():
    """
    Return a context manager that restricts scaled-dot-product attention to the
    math kernel (no Flash / memory-efficient kernels), so higher-order grads exist.

    Resolution order:
      1. ``torch.nn.attention.sdpa_kernel(SDPBackend.MATH)`` when that API is
         available (it supersedes the deprecated ``torch.backends.cuda.sdp_kernel``).
      2. The legacy ``torch.backends.cuda.sdp_kernel`` with only math enabled.
      3. ``nullcontext()`` when neither can be constructed.
    """
    try:
        # Local import: this module only exists in newer PyTorch releases.
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except (ImportError, AttributeError):
        pass
    else:
        return sdpa_kernel(SDPBackend.MATH)

    try:
        return torch.backends.cuda.sdp_kernel(enable_flash=False,
                                              enable_mem_efficient=False,
                                              enable_math=True)
    except Exception:
        # Deliberate best-effort: selecting the kernel is an optimisation hint only.
        return nullcontext()

def _make_block_callable(block, argdict):
    """
    Build a function ``f(x) -> hidden_states`` that replays one decoder block.

    Args:
        block: module to replay; only keyword arguments accepted by its
            ``forward`` signature are passed.
        argdict: captured call arguments (or ``None``). Only the self-attention
            mask (``attention_mask``, else ``mask``), ``encoder_hidden_states``
            and ``encoder_attention_mask`` are read from it.

    The returned function casts ``x`` to float32, rebuilds both masks for the
    current sequence length with ``_ensure_mask_shape``, never passes captured
    position biases, forces ``use_cache``/``output_attentions``/``return_dict``
    to ``False`` where the block accepts them, and returns the first element of
    a tuple/list output (or the output itself).
    """
    captured = {} if argdict is None else dict(argdict)

    # never reuse cached biases (shape-mismatch risk during generation)
    for stale_key in ("position_bias", "encoder_decoder_position_bias"):
        captured.pop(stale_key, None)

    accepted = set(inspect.signature(block.forward).parameters.keys())

    # Captured sources for the masks and the encoder states.
    self_mask_src = captured.get("attention_mask", captured.get("mask", None))
    memory = captured.get("encoder_hidden_states", None)
    memory_mask_src = captured.get("encoder_attention_mask", None)

    # Arguments that do not depend on x.
    fixed = {}
    if memory is not None:
        fixed["encoder_hidden_states"] = memory.float()
    for flag in ("use_cache", "output_attentions", "return_dict"):
        if flag in accepted:
            fixed[flag] = False

    def _f(x):
        x = x.float()
        batch, n_query = x.size(0), x.size(1)

        call = {
            "hidden_states": x,
            "attention_mask": _ensure_mask_shape(
                self_mask_src, n_query, n_query, x.device, x.dtype, is_self=True, B_fallback=batch
            ),
        }

        # Cross-attention inputs only when encoder states were captured.
        if "encoder_hidden_states" in fixed:
            enc = fixed["encoder_hidden_states"]
            call["encoder_hidden_states"] = enc
            call["encoder_attention_mask"] = _ensure_mask_shape(
                memory_mask_src, n_query, enc.size(1), x.device, x.dtype, is_self=False, B_fallback=batch
            )

        # Positional bookkeeping differs between block signatures.
        if "cache_position" in accepted:
            call["cache_position"] = torch.arange(n_query, dtype=torch.long, device=x.device)
        elif "query_length" in accepted:
            call["query_length"] = n_query

        call.update(fixed)
        call = _filter_kwargs_for_module(block, call)

        # Use math SDP to keep higher-order derivatives available
        with _sdp_math_ctx():
            result = block(**call)

        if isinstance(result, (tuple, list)):
            return result[0]
        return result

    return _f

def register_decoder_block_jvp_hooks(model):
    """
    Register one forward pre-hook (with kwargs) per module in ``model.decoder.block``.

    On every forward call of block ``i`` the hook stores in ``buffers[i]``:
      * ``"X"``: the detached ``hidden_states`` (keyword argument if present,
        otherwise the first positional argument, otherwise ``None``);
      * ``"args"``: the remaining call arguments as normalized by
        ``_map_args_from_hook``.

    Returns:
        (handles, buffers): hook handles in block order and the buffer dict,
        whose entries start out as ``{"X": None, "args": None}``.
    """
    layers = model.decoder.block
    buffers = {}
    for position in range(len(layers)):
        buffers[position] = {"X": None, "args": None}

    def _recorder_for(position):
        # One closure per block so each hook writes to its own buffer entry.
        def _record(module, args, kwargs):
            if "hidden_states" in kwargs:
                incoming = kwargs["hidden_states"]
            elif len(args) > 0:
                incoming = args[0]
            else:
                incoming = None
            target = buffers[position]
            target["X"] = incoming.detach() if incoming is not None else None
            target["args"] = _map_args_from_hook(module, args, kwargs)
        return _record

    handles = [
        layer.register_forward_pre_hook(_recorder_for(position), with_kwargs=True)
        for position, layer in enumerate(layers)
    ]
    return handles, buffers

def remove_hooks(hooks):
    """Detach all hooks by calling ``remove()`` on each handle, in the given order."""
    for handle in hooks:
        handle.remove()

def compute_decoder_block_jdev_autograd(model, dec_bufs, k_probes=2, rademacher=True):
    """
    Autograd/func JVP estimator of E_v ||(J - I)v||^2 for each decoder block.

    For every buffer holding a captured input X, the block is replayed through
    the callable from ``_make_block_callable`` (in eval mode, so dropout is
    off), J v is computed for random probes v, and the mean of
    ``(J v - v) ** 2`` is averaged over the probes that gave finite results.

    Args:
        model: object exposing ``model.decoder.block[idx]``.
        dec_bufs: dict ``{idx: {"X": tensor|None, "args": dict|None}}``; every
            visited entry is reset to ``None`` before it is processed.
        k_probes: number of probes per block (at least one is always drawn).
        rademacher: ``True`` uses ±1 probes; otherwise Gaussian probes.

    Returns:
        dict ``{layer_idx: score}``. Blocks whose replay or JVP raised, or
        produced non-finite values for every probe, are omitted.
    """
    # Prefer torch.func.jvp in PyTorch 2+, else fall back to autograd.functional.jvp
    try:
        from torch.func import jvp as func_jvp
        use_func = True
    except Exception:
        from torch.autograd.functional import jvp as autograd_jvp
        use_func = False

    scores = {}

    for idx, buf in dec_bufs.items():
        X, args = buf["X"], buf["args"]
        # clear buffers early to limit memory
        buf["X"], buf["args"] = None, None

        if X is None:
            continue

        block = model.decoder.block[idx]
        was_training = block.training
        block.eval()  # no dropout for determinism

        try:
            f = _make_block_callable(block, args)
            x0 = X.detach().requires_grad_(True).float()

            # Quick health check. Only the values are inspected, so run it
            # without recording an autograd graph over the block.
            try:
                with torch.no_grad(), _sdp_math_ctx():
                    y0 = f(x0)
            except Exception:
                continue
            if y0 is None or not torch.isfinite(y0).all():
                continue

            acc, used = 0.0, 0
            for _ in range(max(1, k_probes)):
                if rademacher:
                    # Bernoulli(0.5) in {0, 1} mapped to {-1, +1}
                    v = torch.empty_like(x0).bernoulli_(0.5).mul_(2).sub_(1)
                else:
                    v = torch.randn_like(x0)

                # Best-effort: a probe whose JVP fails is skipped, not fatal.
                try:
                    if use_func:
                        _, Jv = func_jvp(f, (x0,), (v,))
                    else:
                        _, Jv = autograd_jvp(f, (x0,), (v,), create_graph=False, strict=False)
                except Exception:
                    continue

                if Jv is None or not torch.isfinite(Jv).all():
                    continue

                jd = Jv - v
                if not torch.isfinite(jd).all():
                    continue

                acc += float(jd.pow(2).mean().item())
                used += 1

            if used > 0:
                scores[idx] = acc / used
        finally:
            # Restore the block's original train/eval mode on every exit path.
            block.train(was_training)

    return scores

# ===========================
# 5) Pruning Utilities (decoder blocks)
# ===========================
class SkipBlock(nn.Module):
    """
    Parameter-free stand-in for a pruned T5 decoder block.

    ``forward`` returns ``hidden_states`` unchanged as the first element of a
    6-tuple whose remaining five entries are all ``None`` (including both
    position-bias slots). Every other argument is accepted and ignored.

    NOTE(review): the intent stated for the ``None`` biases is that later
    layers recompute them with the correct (Q, K) sizes; whether a later block
    can rebuild the relative position bias on its own depends on the
    transformers implementation — confirm for the installed version.
    """
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=False,
        cache_position=None,
        **kwargs,
    ):
        # Identity on the hidden states; nothing else is produced.
        return (hidden_states,) + (None,) * 5

def prune_decoder_blocks_by_jdev(blocks, jdev_scores, num_prune=4, protect_first=True, protect_last=False, verbose=True):
    """
    Swap the decoder blocks with the lowest J-dev score for ``SkipBlock``.

    Args:
        blocks: indexable, mutable container of blocks (modified in place).
        jdev_scores: dict ``{block_index: score}``; low score = near-identity.
        num_prune: upper bound on the number of replaced blocks (negative acts as 0).
        protect_first: exclude index 0 from pruning.
        protect_last: exclude index ``len(blocks) - 1`` from pruning.
        verbose: print the pruned indices, or a notice when there are no scores.

    Returns:
        The pruned indices, lowest score first.
    """
    if not jdev_scores:
        if verbose:
            print("No J-dev scores available; skipping pruning.")
        return []

    excluded = set()
    if protect_first:
        excluded.add(0)
    if protect_last:
        excluded.add(len(blocks) - 1)

    eligible = [idx for idx in jdev_scores if idx not in excluded]
    # Stable sort: equal scores keep the order in which they appear in the dict.
    eligible.sort(key=lambda idx: jdev_scores[idx])
    prune_idxs = eligible[:max(0, num_prune)]

    for idx in prune_idxs:
        blocks[idx] = SkipBlock()
    if verbose:
        print(f"Pruned decoder blocks (lowest J-dev): {prune_idxs}")
    return prune_idxs

# ===========================
# 6) Eval Helper
# ===========================
def compute_accuracy(preds, refs):
    """
    Fraction of predictions that exactly match their reference.

    Pairs are compared after ``str()``, stripping whitespace and lower-casing.
    The result is divided by ``len(preds)``; 0.0 is returned for empty ``preds``.
    """
    matches = [
        str(guess).strip().lower() == str(gold).strip().lower()
        for guess, gold in zip(preds, refs)
    ]
    total = len(preds)
    return matches.count(True) / total if total > 0 else 0.0

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device):
    """
    Generate predictions for every batch of ``dl`` and return their exact-match
    accuracy (via ``compute_accuracy``) against the decoded gold labels.

    Side effects: sets ``model.config.use_cache = False`` and puts the model in
    eval mode; neither is restored. Generation uses ``max_new_tokens=8`` and
    ``use_cache=False``.
    """
    model.config.use_cache = False  # recompute biases fresh
    model.eval()

    hypotheses, golds = [], []
    for batch in dl:
        source_ids = batch["input_ids"].to(device)
        source_mask = batch["attention_mask"].to(device)

        # -100 marks ignored label positions; replace it with the pad id so
        # the label sequence can be decoded back to text.
        gold_ids = batch["labels"].clone()
        gold_ids[gold_ids == -100] = tokenizer.pad_token_id
        gold_texts = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)

        generated = model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            max_new_tokens=8,
            use_cache=False,  # consistent with SkipBlock
        )
        hyp_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)

        for text in hyp_texts:
            hypotheses.append(text.strip().lower())
        for text in gold_texts:
            golds.append(text.strip().lower())
    return compute_accuracy(hypotheses, golds)

# ===========================
# 7) Training + Autograd-JVP + Pruning
# ===========================
def full_finetuning_with_jdev(train_loader, dev_loader, device, tokenizer,
                              jvp_k=2, jvp_every=1, rademacher=True, epochs=6):
    """
    Stage 1: fine-tune a fresh ``t5-base`` while accumulating autograd-JVP
    J-dev scores for every decoder block.

    Args:
        train_loader / dev_loader: batches with ``input_ids``,
            ``attention_mask`` and ``labels``.
        device: device the model and batches are moved to.
        tokenizer: used by ``evaluate_model`` to decode predictions and labels.
        jvp_k: number of probes per block and step.
        jvp_every: collect scores every N optimizer steps (values < 1 act as 1).
        rademacher: probe distribution passed to the estimator.
        epochs: number of passes over ``train_loader``.

    Returns:
        (model, scores) where ``scores`` maps block index to the mean of all
        finite J-dev values collected during training.
    """
    print("=== Stage 1: Full FT + Autograd-JVP (Decoder Blocks) ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    model.config.use_cache = False

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

    dec_hooks, dec_bufs = register_decoder_block_jvp_hooks(model)
    totals, counts = defaultdict(float), defaultdict(int)

    def _running_means():
        # Mean J-dev per block over everything accumulated so far.
        return {i: totals[i] / counts[i] for i in totals if counts[i] > 0}

    step = 0
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            step += 1
            opt.zero_grad()

            inputs = {
                "input_ids": batch["input_ids"].to(device),
                "attention_mask": batch["attention_mask"].to(device),
                "labels": batch["labels"].to(device),
            }
            loss = model(**inputs).loss
            if not torch.isfinite(loss):
                print("Loss is NaN/Inf — skipping batch.")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            # Only collect on every N-th step
            if step % max(1, jvp_every) != 0:
                continue
            step_scores = compute_decoder_block_jdev_autograd(
                model, dec_bufs, k_probes=jvp_k, rademacher=rademacher
            )
            for block_idx, value in step_scores.items():
                if math.isfinite(value):
                    totals[block_idx] += value
                    counts[block_idx] += 1

        print(f"[Epoch {epoch+1}] Decoder Block J-dev (mean): {_running_means()}")
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")

    remove_hooks(dec_hooks)
    return model, _running_means()

def prune_and_finetuning(model, train_loader, dev_loader, device, tokenizer,
                         dec_jdev_scores, num_prune=4, epochs=5):
    """
    Stage 2: replace the lowest-J-dev decoder blocks with ``SkipBlock`` and
    continue fine-tuning the pruned model.

    Args:
        model: model from stage 1; its ``decoder.block`` is modified in place.
        train_loader / dev_loader: batches with ``input_ids``,
            ``attention_mask`` and ``labels``.
        device: device batches are moved to.
        tokenizer: used by ``evaluate_model``.
        dec_jdev_scores: dict ``{block_index: score}`` from stage 1.
        num_prune: number of blocks to prune (block 0 is always protected).
        epochs: number of fine-tuning epochs after pruning.

    Returns:
        The pruned, fine-tuned model (same object as ``model``).
    """
    print("=== Stage 2: Prune LOW-Jdev Decoder Blocks & Fine-tuning ===")
    prune_decoder_blocks_by_jdev(
        model.decoder.block,
        dec_jdev_scores,
        num_prune=num_prune,
        protect_first=True,
        protect_last=False,
        verbose=True,
    )

    model.config.use_cache = False
    # Optimizer is built after pruning, so it only sees the remaining parameters.
    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)

    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            inputs = {
                "input_ids": batch["input_ids"].to(device),
                "attention_mask": batch["attention_mask"].to(device),
                "labels": batch["labels"].to(device),
            }
            loss = model(**inputs).loss
            if torch.isfinite(loss):
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            else:
                print("Loss is NaN/Inf — skipping batch.")

        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] SVAMP Acc: {acc:.4f}")
    return model

# ===========================
# 8) Main Entrypoint
# ===========================
def main():
    """
    Run the two-stage pipeline: full fine-tuning with autograd-JVP score
    collection, then pruning of ``num`` decoder blocks followed by further
    fine-tuning.

    Uses the module-level ``train_loader``, ``dev_loader``, ``device``,
    ``tokenizer`` and ``num``. Returns nothing.
    """
    shared = (train_loader, dev_loader, device, tokenizer)
    tuned_model, block_scores = full_finetuning_with_jdev(
        *shared, jvp_k=2, jvp_every=1, rademacher=True, epochs=6
    )
    prune_and_finetuning(
        tuned_model, *shared, block_scores, num_prune=num, epochs=5
    )

# Run the full pipeline only when this module is executed as the main program
# (this is also the case when the code runs as a notebook cell).
if __name__ == "__main__":
    main()
