In [None]:
!pip uninstall -y datasets
!pip install datasets==2.18.0
!pip install evaluate


Found existing installation: datasets 4.0.0
Uninstalling datasets-4.0.0:
  Successfully uninstalled datasets-4.0.0
Collecting datasets==2.18.0
  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets==2.18.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Collecting fsspec<=2024.2.0,>=2023.1.0 (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets==2.18.0)
  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)
Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.2.0-py3-none-any.whl (170 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.9/170.9 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow_hotfix-0.7-py3-none-any.whl (7.9 kB)
Installing collected packages: pyarrow-hotfix, fsspec, datasets
  Attempting uninstall: fsspe

In [None]:
num = 6  # Number of encoder layers to prune; read as `num_prune` by prune_and_finetuning in Stage 2

In [None]:
# Prune Layers Based on Jacobian Deviation (||J - I||) Using MNLI Dataset
# FD

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

# Suppress every UserWarning and FutureWarning for the remainder of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 2) Jacobian Deviation Hooks/Scoring
#    Compare each layer's mapping f_ℓ via ||J_ℓ - I|| (Hutchinson-style)
# ========================================================

def _flatten_tokens(x):
    """
    Collapse a tensor to two dimensions.

    - 3-D input (d0, d1, d2) becomes (d0 * d1, d2).
    - 2-D input is returned unchanged (same object).
    - Any other rank keeps dim 0 and flattens the rest.

    `reshape` is used throughout so that non-contiguous inputs (e.g. the
    result of `permute`/`transpose`) are handled; `view` raises on those.
    """
    if x.dim() == 2:
        return x
    if x.dim() == 3:
        b, s, h = x.shape
        return x.reshape(b * s, h)
    # reshape copies only when the memory layout forces it
    return x.reshape(x.size(0), -1)

def register_jac_hooks(model):
    """
    Attach a forward pre-hook and a forward hook to every module in
    `model.roberta.encoder.layer`.

    For layer i the hooks fill `activations[i]` with:
      'X'    - detached first positional input of the layer call
      'args' - tuple of the remaining positional inputs (stored as-is)
      'Y'    - detached output (first element when the output is a tuple)

    Returns (hooks, activations); `hooks` holds the handles in
    registration order (pre-hook, then post-hook, per layer).
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    for layer_no in range(len(encoder_layers)):
        activations[layer_no] = {'X': None, 'Y': None, 'args': None}

    def make_input_recorder(layer_no):
        def record_input(module, inputs):
            # Only positional inputs are visible here; non-tensors in the
            # tail are kept untouched.
            activations[layer_no]['X'] = inputs[0].detach()
            activations[layer_no]['args'] = tuple(inputs[1:])
        return record_input

    def make_output_recorder(layer_no):
        def record_output(module, inputs, output):
            primary = output[0] if isinstance(output, tuple) else output
            activations[layer_no]['Y'] = primary.detach()
        return record_output

    hooks = []
    for layer_no, block in enumerate(encoder_layers):
        hooks.append(block.register_forward_pre_hook(make_input_recorder(layer_no)))
        hooks.append(block.register_forward_hook(make_output_recorder(layer_no)))

    return hooks, activations

def remove_hooks(hooks):
    """Call `.remove()` on each handle in `hooks`, in iteration order."""
    for handle in hooks:
        handle.remove()

@torch.no_grad()
def _layer_forward_only(layer, x, args):
    """
    Call `layer(x, *args)` with gradient tracking disabled and return its
    primary result: element 0 when the call returns a tuple, otherwise the
    return value itself.
    """
    result = layer(x, *args)
    if isinstance(result, tuple):
        return result[0]
    return result

@torch.no_grad()
def compute_batch_jacdev(model, acts, eps=1e-3, k_probes=1):
    """
    Finite-difference estimate of the Jacobian deviation of each encoder layer.

    For every layer whose buffer in `acts` has both 'X' and 'Y' set, the layer
    is re-run in eval mode on the stored input and on `k_probes` perturbed
    inputs x + eps*v (v ~ N(0, I)); the score is the mean over probes of
    mean(((f(x + eps*v) - f(x)) / eps - v)^2).

    Side effects: the layer's train/eval flag is restored afterwards and the
    processed buffer entries are reset to None.

    Returns {layer_idx: score} for the layers that were scored.
    """
    encoder_layers = model.roberta.encoder.layer
    scores = {}

    for layer_no, slot in acts.items():
        x_in, y_out, extra = slot['X'], slot['Y'], slot['args']
        if x_in is None or y_out is None:
            continue

        block = encoder_layers[layer_no]
        prev_mode = block.training
        block.eval()  # dropout off so base and perturbed runs are comparable
        try:
            # Baseline is recomputed here (not taken from 'Y') so it pairs
            # with the perturbed run under identical eval-mode conditions.
            baseline = _layer_forward_only(block, x_in, extra)

            total = 0.0
            for _probe in range(k_probes):
                direction = torch.randn_like(x_in)
                shifted = _layer_forward_only(block, x_in + eps * direction, extra)
                deviation = (shifted - baseline) / eps - direction
                total += float(deviation.pow(2).mean().item())
            scores[layer_no] = total / k_probes
        finally:
            block.train(prev_mode)

        # Drop references so the next batch cannot reuse stale tensors
        slot.update(X=None, Y=None, args=None)

    return scores

# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-JacDev)
# ========================================================
class SkipFF(nn.Module):
    """
    Parameter-free stand-in assigned to a layer's `output` sub-module when
    the layer is pruned: it ignores `hidden_states` and hands back
    `input_tensor` (None when the caller omits it).
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor  # presumably the attention-block output; verify against caller
        return residual

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Prune the `num_prune` layers with the smallest score in `jac_scores`
    (ties keep the dict's insertion order).

    For each selected layer, `intermediate.dense` becomes `nn.Identity()`
    and `output` becomes `SkipFF()`. Returns the selected layer indices,
    lowest score first.
    """
    ranked = list(jac_scores.items())
    ranked.sort(key=lambda pair: pair[1])  # ascending: closest to identity first

    prune_idxs = []
    for layer_no, _score in ranked[:num_prune]:
        prune_idxs.append(layer_no)

    for layer_no in prune_idxs:
        block = model.roberta.encoder.layer[layer_no]
        block.intermediate.dense = nn.Identity()
        block.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """
    Low-rank additive update on top of a fixed weight matrix.

    `W0` is copied into a buffer (not trained). `B` (rows x r) starts as
    small Gaussian noise and `A` (r x cols) starts at zero, so the initial
    effective weight equals `W0`. Calling the module returns
    W0 + (alpha / r) * B @ A.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen_copy = W0.clone().detach()
        self.register_buffer("W0", frozen_copy)
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(torch.randn(n_rows, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Wrap `layer.output.dense` of every encoder layer with a LoRA adapter.

    Layers whose `output` has no `dense` attribute are skipped. For the
    rest, a `LoRA` built from the current dense weight is moved to that
    weight's device and the dense module's `forward` is replaced by a
    function applying the adapter's effective weight with the original bias.

    Returns {layer_idx: LoRA} for the layers that were wrapped.
    """
    adapters = {}
    for layer_no, block in enumerate(model.roberta.encoder.layer):
        if hasattr(block.output, 'dense'):
            base_weight = block.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

            # Defaults bind this iteration's layer/adapter to the closure.
            def patched_forward(x, layer=block, lora=adapter):
                return F.linear(x, lora(), layer.output.dense.bias)

            block.output.dense.forward = patched_forward
            adapters[layer_no] = adapter
    return adapters

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=128):
    """
    Tokenize premise/hypothesis pairs from `examples` with `tok`,
    truncating and padding every pair to exactly `max_length`.
    Returns whatever `tok` returns.
    """
    tokenizer_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
    }
    premises = examples['premise']
    hypotheses = examples['hypothesis']
    return tok(premises, hypotheses, **tokenizer_options)

def evaluate_model(model, dl, device):
    """
    Put `model` in eval mode, run it without gradients over every batch in
    `dl`, and return the "accuracy" value computed by the `evaluate`
    accuracy metric from argmax predictions vs. the batch 'labels'.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """
    Stage 1: fine-tune roberta-base (3 labels) with mixed precision while
    estimating each encoder layer's Jacobian deviation after every step.

    Args:
        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.
        dev_loader: loader passed to `evaluate_model` after training.
        device: torch device the model and batches are moved to.
        num_epochs: number of training epochs (default 6, as before). The
            linear LR schedule now spans exactly this many epochs; it used to
            be sized for 3 epochs, which made the learning rate 0 for epochs
            4-6.

    Returns:
        (model, last_jac) where `last_jac` maps layer index to the mean
        per-batch score of the final epoch (None if `num_epochs` is 0).
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=3).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    # The schedule must cover every optimizer step that will be taken.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)

    last_jac = None
    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                with autocast():
                    out = model(
                        input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device),
                    )
                # Backward runs outside autocast, as the AMP docs recommend.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch Jacobian deviation scores
                batch_jac = compute_batch_jacdev(model, activations, eps=1e-3, k_probes=1)
                for idx, v in batch_jac.items():
                    jac_sums[idx] += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: (jac_sums[idx] / max(1, jac_counts[idx])) for idx in jac_sums}
            print(f"[Epoch {epoch+1}] approx ||J-I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Detach hooks even if training fails, and before evaluation so eval
        # batches are not captured into the activation buffers.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MNLI Acc: {acc:.4f}")
    return model, last_jac

def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores):
    """
    Stage 2: prune the lowest-scoring layers (count taken from the
    module-level `num`) and fine-tune the pruned model for 3 epochs in full
    precision, printing dev accuracy after each epoch. Returns the model.
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    total_steps = len(train_loader) * 3
    sched = get_linear_schedule_with_warmup(opt, 0, total_steps)

    batch_keys = ('input_ids', 'attention_mask', 'labels')
    for epoch_no in range(1, 4):
        model.train()
        for batch in train_loader:
            opt.zero_grad(set_to_none=True)
            inputs = {key: batch[key].to(device) for key in batch_keys}
            loss = model(**inputs).loss
            loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch_no}] MNLI Acc: {acc:.4f}")
    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """
    Stage 3: attach LoRA adapters, freeze `model.roberta`, and train only
    the classifier head plus the adapters' A/B matrices for 3 epochs under
    mixed precision. Prints dev accuracy after each epoch; returns None.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    for frozen in model.roberta.parameters():
        frozen.requires_grad = False
    for head_param in model.classifier.parameters():
        head_param.requires_grad = True

    trainable = list(model.classifier.parameters())
    for adapter in loras.values():
        adapter.A.requires_grad = True
        adapter.B.requires_grad = True
    for adapter in loras.values():
        trainable.extend((adapter.A, adapter.B))

    opt = torch.optim.Adam(trainable, lr=2e-5)
    total_steps = len(train_loader) * 3
    sched = get_linear_schedule_with_warmup(opt, 0, total_steps)
    scaler = GradScaler()

    for epoch_no in range(1, 4):
        model.train()
        for batch in train_loader:
            opt.zero_grad(set_to_none=True)
            with autocast():
                outputs = model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    labels=batch['labels'].to(device),
                )
                scaler.scale(outputs.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch_no}] MNLI Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """
    Seed the RNGs, build MNLI train/dev loaders (first 5000 train and first
    1000 validation_matched examples), then run the three stages in order:
    full fine-tuning, prune + fine-tuning, LoRA fine-tuning.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = load_dataset("glue", "mnli", split="train[:5000]").shuffle(seed)
    dev_ds = load_dataset("glue", "mnli", split="validation_matched[:1000]")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def tokenize_split(split):
        # Tokenize, drop the raw text/index columns, expose labels as "labels".
        encoded = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["premise", "hypothesis", "idx"],
        )
        return encoded.rename_column("label", "labels")

    train = tokenize_split(train_ds)
    dev = tokenize_split(dev_ds)

    # NOTE(review): preprocess_function pads to its default max_length of 128
    # while this collator is configured with max_length=64 - confirm which
    # sequence length is intended.
    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

# Run the full pipeline only when this code is executed directly (not imported).
if __name__ == "__main__":
    main()


In [None]:
# Prune Layers Based on Jacobian Deviation (||J - I||) Using MNLI Dataset
# - Autograd JVP scorer with SDPA-math fallback, AD
# - Finite-difference fallback if higher-order grads are unavailable

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings
from contextlib import contextmanager

# Suppress every UserWarning and FutureWarning for the remainder of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 1.1) Utility: force 'math' attention kernels during JVP
# ========================================================
@contextmanager
def force_math_sdp():
    """
    Context manager that restricts scaled-dot-product attention to the
    'math' backend while active, so higher-order grads/JVP work.

    When CUDA is unavailable this is a no-op. Otherwise it prefers
    `torch.nn.attention.sdpa_kernel` and falls back to the deprecated
    `torch.backends.cuda.sdp_kernel` on PyTorch builds that lack the new API.
    """
    if not torch.cuda.is_available():
        yield
        return

    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        # Older PyTorch: only the legacy context manager exists.
        backend_ctx = torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        )
    else:
        backend_ctx = sdpa_kernel([SDPBackend.MATH])

    with backend_ctx:
        yield

# ========================================================
# 2) Jacobian Deviation Hooks/Scoring
#    Compare each layer's mapping f_ℓ via ||J_ℓ - I|| (Hutchinson-style)
# ========================================================

def _flatten_tokens(x):
    """
    Collapse a tensor to two dimensions.

    - 3-D input (d0, d1, d2) becomes (d0 * d1, d2).
    - 2-D input is returned unchanged (same object).
    - Any other rank keeps dim 0 and flattens the rest.

    `reshape` replaces the former `view` in the fallback branch because
    `view` raises on non-contiguous tensors (e.g. after `permute`).
    """
    rank = x.dim()
    if rank == 3:
        b, s, h = x.shape
        return x.reshape(b * s, h)
    if rank == 2:
        return x
    return x.reshape(x.size(0), -1)

def register_jac_hooks(model):
    """
    Register, on each module of `model.roberta.encoder.layer`, a pre-hook
    that stores the detached first positional input ('X') plus the remaining
    positional inputs ('args'), and a post-hook that stores the detached
    primary output ('Y').

    Returns (hooks, activations): the list of hook handles and the dict
    {layer_idx: {'X', 'Y', 'args'}} the hooks write into.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    for layer_no in range(len(encoder_layers)):
        activations[layer_no] = {'X': None, 'Y': None, 'args': None}

    def input_recorder_for(layer_no):
        def record(module, inputs):
            activations[layer_no]['X'] = inputs[0].detach()
            activations[layer_no]['args'] = tuple(inputs[1:])
        return record

    def output_recorder_for(layer_no):
        def record(module, inputs, output):
            if isinstance(output, tuple):
                output = output[0]
            activations[layer_no]['Y'] = output.detach()
        return record

    hooks = []
    for layer_no, block in enumerate(encoder_layers):
        hooks.append(block.register_forward_pre_hook(input_recorder_for(layer_no)))
        hooks.append(block.register_forward_hook(output_recorder_for(layer_no)))

    return hooks, activations

def remove_hooks(hooks):
    """Detach every hook by calling `.remove()` on each handle in order."""
    for handle in hooks:
        handle.remove()

@torch.no_grad()
def _layer_forward_only(layer, x, args):
    """
    Run `layer(x, *args)` without gradient tracking; return element 0 of a
    tuple result, or the result itself when it is not a tuple.
    """
    result = layer(x, *args)
    if not isinstance(result, tuple):
        return result
    return result[0]

# -------- Autograd JVP scorer (no finite differences) --------
def compute_batch_jacdev_autograd(model, acts, k_probes=1, rademacher=True, max_tokens=None):
    """
    Estimate E_v ||(J_l - I) v||^2 for each encoder layer with autograd JVPs
    (no finite differences).

    Args:
        model: object exposing `model.roberta.encoder.layer`.
        acts: {layer_idx: buffer} as filled by the forward hooks; buffers with
            'X' set to None are skipped.
        k_probes: number of random probe vectors averaged per layer.
        rademacher: if True v has entries in {+1, -1}; otherwise v ~ N(0, I).
        max_tokens: if given and a 3-D input has more than this many
            (batch*seq) rows, a random subset of rows is kept and treated as
            one sequence of that length.
            NOTE(review): the subsampled rows come from different sequences,
            so the layer is evaluated on a synthetic sequence in that case.

    Side effects: each processed buffer's 'X', 'Y' and 'args' are reset to
    None; the layer's train/eval flag is restored.

    Returns:
        {layer_idx: score}
    """
    from torch.autograd.functional import jvp

    layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in acts.items():
        X, args = buf['X'], buf['args']
        if X is None:
            continue

        # Optional token subsampling for speed
        if max_tokens is not None and X.dim() == 3:
            b, s, h = X.shape
            X2 = X.reshape(b * s, h)
            if X2.size(0) > max_tokens:
                sel = torch.randperm(X2.size(0), device=X2.device)[:max_tokens]
                X = X2[sel].unsqueeze(0)  # (1, max_tokens, H)
                if len(args) > 0 and isinstance(args[0], torch.Tensor):
                    old_mask = args[0]
                    mask_shape = (1,) * max(old_mask.dim() - 1, 1) + (X.size(1),)
                    if old_mask.dtype.is_floating_point:
                        # Float masks are presumably additive (added to the
                        # attention scores), so "attend everywhere" is 0.
                        attn = torch.zeros(mask_shape, dtype=old_mask.dtype, device=X.device)
                    else:
                        # Non-float masks keep the previous all-ones value.
                        attn = torch.ones(mask_shape, dtype=old_mask.dtype, device=X.device)
                    args = (attn,) + tuple(args[1:])

        X = X.detach().requires_grad_(True)
        layer = layers[idx]
        was_training = layer.training
        layer.eval()

        try:
            def f(inp):
                out = layer(inp, *args)
                return out[0] if isinstance(out, tuple) else out

            acc = 0.0
            # Force math attention so JVP has higher-order grad support
            with force_math_sdp():
                for _ in range(k_probes):
                    if rademacher:
                        v = torch.empty_like(X).bernoulli_(0.5).mul_(2).sub_(1)  # +/-1
                    else:
                        v = torch.randn_like(X)

                    # jvp returns (f(X), J v)
                    _, Jv = jvp(f, (X,), (v,), create_graph=False, strict=True)
                    jd_vec = Jv - v
                    acc += float(jd_vec.pow(2).mean().item())

            scores[idx] = acc / k_probes

        finally:
            layer.train(was_training)

        # Clear to avoid stale tensors
        buf['X'] = None
        buf['Y'] = None
        buf['args'] = None

    return scores

# -------- Finite-difference fallback scorer --------
@torch.no_grad()
def compute_batch_jacdev_fdiff(model, acts, eps=1e-3, k_probes=1):
    """
    Finite-difference scorer: for each layer with a stored input, average
    over `k_probes` Gaussian probes v the mean of
    ((f(x + eps*v) - f(x)) / eps - v)^2, with the layer in eval mode.

    Resets the processed buffers ('X', 'Y', 'args') to None and restores the
    layer's train/eval flag. Returns {layer_idx: score}.
    """
    encoder_layers = model.roberta.encoder.layer
    scores = {}

    for layer_no, slot in acts.items():
        x_in, extra = slot['X'], slot['args']
        if x_in is None:
            continue

        block = encoder_layers[layer_no]
        prev_mode = block.training
        block.eval()
        try:
            baseline = _layer_forward_only(block, x_in, extra)
            total = 0.0
            for _probe in range(k_probes):
                direction = torch.randn_like(x_in)
                shifted = _layer_forward_only(block, x_in + eps * direction, extra)
                deviation = (shifted - baseline) / eps - direction
                total += float(deviation.pow(2).mean().item())
            scores[layer_no] = total / k_probes
        finally:
            block.train(prev_mode)

        slot.update(X=None, Y=None, args=None)

    return scores

# -------- Safe wrapper that tries JVP then falls back --------
def compute_batch_jacdev_safe(model, activations, **kw):
    """
    Score layers with the autograd JVP scorer, falling back to finite
    differences when the efficient SDPA kernel lacks higher-order grads.

    The autograd scorer clears each layer's buffer as soon as that layer is
    scored, so a failure partway through would leave the fallback with no
    data for the earlier layers. The buffers are therefore snapshotted first
    and restored before the fallback runs.

    Any other RuntimeError is re-raised unchanged. `**kw` is forwarded to the
    autograd scorer; only `k_probes` is reused by the fallback.
    """
    # Shallow copies: the tensors are shared, only the dict slots are saved.
    snapshot = {idx: dict(buf) for idx, buf in activations.items()}
    try:
        return compute_batch_jacdev_autograd(model, activations, **kw)
    except RuntimeError as e:
        # Efficient SDPA kernel lacks higher-order grads -> fallback
        if "scaled_dot_product_efficient_attention_backward" not in str(e):
            raise
        for idx, saved in snapshot.items():
            activations[idx].update(saved)
        return compute_batch_jacdev_fdiff(
            model, activations, eps=1e-3, k_probes=max(1, kw.get("k_probes", 1))
        )

# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-JacDev)
# ========================================================
class SkipFF(nn.Module):
    """
    Identity replacement for a pruned layer's `output` sub-module: the first
    argument is discarded and `input_tensor` is returned as-is (None if it
    was not supplied).
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Select the `num_prune` lowest-scoring layers from `jac_scores` and prune
    them in place: `intermediate.dense` -> `nn.Identity()`, `output` ->
    `SkipFF()`. Returns the pruned layer indices, lowest score first.
    """
    ascending = sorted(jac_scores.items(), key=lambda pair: pair[1])
    prune_idxs = [layer_no for layer_no, _score in ascending[:num_prune]]
    for layer_no in prune_idxs:
        target = model.roberta.encoder.layer[layer_no]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """
    Holds a frozen copy of a weight matrix (`W0`, a buffer) and two trainable
    factors: `B` (rows x r, small random init) and `A` (r x cols, zero init).
    `forward()` takes no input and returns W0 + (alpha / r) * (B @ A).
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        rows, cols = W0.shape
        # B is registered before A so parameter order stays (B, A).
        self.B = nn.Parameter(torch.randn(rows, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, cols))
        self.scaling = alpha / r

    def forward(self):
        delta = self.scaling * (self.B @ self.A)
        return self.W0 + delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    For each encoder layer whose `output` has a `dense` attribute, build a
    `LoRA` from the dense weight (on the weight's device) and replace
    `dense.forward` with a function that applies the adapter's effective
    weight together with the original bias. Other layers are left untouched.

    Returns {layer_idx: LoRA}.
    """
    adapters = {}
    for layer_no, block in enumerate(model.roberta.encoder.layer):
        if hasattr(block.output, 'dense'):
            frozen_weight = block.output.dense.weight.data
            adapter = LoRA(frozen_weight, r, alpha).to(frozen_weight.device)

            # Default arguments pin this iteration's objects to the closure.
            def lora_forward(x, layer=block, lora=adapter):
                return F.linear(x, lora(), layer.output.dense.bias)

            block.output.dense.forward = lora_forward
            adapters[layer_no] = adapter
    return adapters

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=128):
    """
    Call `tok` on the 'premise' and 'hypothesis' columns of `examples` with
    truncation enabled and padding fixed to `max_length`; return its result.
    """
    first_segments = examples['premise']
    second_segments = examples['hypothesis']
    return tok(
        first_segments,
        second_segments,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )

def evaluate_model(model, dl, device):
    """
    Evaluate `model` on loader `dl`: eval mode, no gradients, argmax over the
    logits per batch, scored against the batch 'labels' with the `evaluate`
    accuracy metric. Returns the accuracy value.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predicted, gold = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gold.extend(batch['labels'].cpu().numpy())
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            batch_preds = torch.argmax(outputs.logits, -1)
            predicted.extend(batch_preds.cpu().numpy())
    summary = metric.compute(predictions=predicted, references=gold)
    return summary["accuracy"]

# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """
    Stage 1: fine-tune roberta-base (3 labels) with mixed precision and score
    every encoder layer's Jacobian deviation after each optimizer step using
    `compute_batch_jacdev_safe`.

    Args:
        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.
        dev_loader: loader passed to `evaluate_model` after training.
        device: torch device the model and batches are moved to.
        num_epochs: number of training epochs (default 6, as before). The
            linear LR schedule now spans exactly this many epochs; it used to
            be sized for 3 epochs, leaving the learning rate at 0 for the
            remaining ones.

    Returns:
        (model, last_jac) where `last_jac` maps layer index to the mean
        per-batch score of the final epoch (None if `num_epochs` is 0).
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=3).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    # The schedule must cover every optimizer step that will be taken.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)

    last_jac = None
    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                with autocast():
                    out = model(
                        input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device),
                    )
                # Backward runs outside autocast, as the AMP docs recommend.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch Jacobian deviation scores (safe JVP with fallback)
                batch_jac = compute_batch_jacdev_safe(
                    model, activations, k_probes=2, rademacher=True, max_tokens=1024
                )
                for idx, v in batch_jac.items():
                    jac_sums[idx] += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: (jac_sums[idx] / max(1, jac_counts[idx])) for idx in jac_sums}
            print(f"[Epoch {epoch+1}] approx ||J-I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Detach hooks even if training fails, and before evaluation so eval
        # batches are not captured into the activation buffers.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MNLI Acc: {acc:.4f}")
    return model, last_jac

def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores):
    """
    Stage 2: prune the layers with the lowest scores (how many is taken from
    the module-level `num`), then fine-tune for 3 epochs in full precision
    and print dev accuracy after each epoch. Returns the pruned model.
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    total_steps = len(train_loader) * 3
    sched = get_linear_schedule_with_warmup(opt, 0, total_steps)

    for epoch_no in range(1, 4):
        model.train()
        for batch in train_loader:
            opt.zero_grad(set_to_none=True)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            outputs.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch_no}] MNLI Acc: {acc:.4f}")
    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=5):
    """
    Stage 3: attach LoRA adapters, freeze `model.roberta`, and train only the
    classifier head and the adapters' A/B matrices under mixed precision.

    Args:
        model: model exposing `.roberta` and `.classifier`.
        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.
        dev_loader: loader passed to `evaluate_model` after every epoch.
        device: torch device the batches are moved to.
        r, alpha: forwarded to `apply_lora_to_all_layers`.
        num_epochs: number of epochs (default 5, as before). The linear LR
            schedule now spans exactly this many epochs; it used to be sized
            for 3, so the last two epochs ran with a learning rate of 0.

    Prints dev accuracy after each epoch; returns None.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for l in loras.values():
        l.A.requires_grad = True
        l.B.requires_grad = True

    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + [p for l in loras.values() for p in (l.A, l.B)],
        lr=2e-5
    )
    # The schedule must cover every optimizer step that will be taken.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()
    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(
                    input_ids=b['input_ids'].to(device),
                    attention_mask=b['attention_mask'].to(device),
                    labels=b['labels'].to(device),
                )
            # Backward runs outside autocast, as the AMP docs recommend.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MNLI Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """
    Entry point: seed Python/NumPy/PyTorch RNGs with 42, build tokenized
    MNLI train (first 5000) and dev (first 1000 of validation_matched)
    loaders, then run full fine-tuning, prune + fine-tuning and LoRA
    fine-tuning in sequence.
    """
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = load_dataset("glue", "mnli", split="train[:5000]").shuffle(seed)
    dev_ds = load_dataset("glue", "mnli", split="validation_matched[:1000]")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    raw_columns = ["premise", "hypothesis", "idx"]

    def encode(ex):
        return preprocess_function(ex, tokenizer)

    train = train_ds.map(encode, batched=True, remove_columns=raw_columns)
    train = train.rename_column("label", "labels")
    dev = dev_ds.map(encode, batched=True, remove_columns=raw_columns)
    dev = dev.rename_column("label", "labels")

    # NOTE(review): preprocess_function pads to its default max_length of 128
    # while this collator is configured with max_length=64 - confirm which
    # sequence length is intended.
    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

# Run the full pipeline only when this code is executed directly (not imported).
if __name__ == "__main__":
    main()


In [None]:
# Jacobian Deviation, FD

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

# Suppress every UserWarning and FutureWarning for the remainder of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Jacobian Deviation / Hook Utilities (per-layer)
# ========================================================
def remove_hooks(hooks):
    """Remove each registered hook by calling `.remove()` on its handle."""
    for handle in hooks:
        handle.remove()

def register_jac_hooks(model):
    """
    Register a forward pre-hook on each module of
    `model.roberta.encoder.layer` that stores the detached first positional
    input under 'X' and the remaining positional inputs under 'args'.

    Returns (hooks, activations): the hook handles and the dict
    {layer_idx: {'X', 'args'}} that the hooks populate.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    for layer_no in range(len(encoder_layers)):
        activations[layer_no] = {'X': None, 'args': None}

    def input_recorder_for(layer_no):
        def record(module, inputs):
            activations[layer_no]['X'] = inputs[0].detach()
            activations[layer_no]['args'] = tuple(inputs[1:])
        return record

    hooks = [
        block.register_forward_pre_hook(input_recorder_for(layer_no))
        for layer_no, block in enumerate(encoder_layers)
    ]
    return hooks, activations

@torch.no_grad()
def _layer_forward_only(layer, x, args):
    """
    Invoke `layer(x, *args)` with autograd disabled and return the primary
    output (first tuple element, or the value itself if not a tuple).
    """
    produced = layer(x, *args)
    return produced[0] if isinstance(produced, tuple) else produced

@torch.no_grad()
def compute_batch_jacdev(model, acts, eps=1e-3, k_probes=1):
    """
    Finite-difference estimate of each layer's Jacobian deviation.

    For every layer with a stored input 'X', the layer is run in eval mode
    on x and on x + eps*v for `k_probes` Gaussian probes v; the score is the
    probe-average of mean(((f(x + eps*v) - f(x)) / eps - v)^2).

    The layer's train/eval flag is restored and the processed buffer's 'X'
    and 'args' are reset to None. Returns {layer_idx: score}.
    """
    encoder_layers = model.roberta.encoder.layer
    scores = {}

    for layer_no, slot in acts.items():
        x_in, extra = slot['X'], slot['args']
        if x_in is None:
            continue

        block = encoder_layers[layer_no]
        prev_mode = block.training
        block.eval()  # dropout off for a stable estimate
        try:
            baseline = _layer_forward_only(block, x_in, extra)

            total = 0.0
            for _probe in range(k_probes):
                direction = torch.randn_like(x_in)
                shifted = _layer_forward_only(block, x_in + eps * direction, extra)
                deviation = (shifted - baseline) / eps - direction
                total += float(deviation.pow(2).mean().item())
            scores[layer_no] = total / k_probes
        finally:
            block.train(prev_mode)

        # release the captured tensors
        slot.update(X=None, args=None)

    return scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's output sub-module that bypasses the FFN result.

    ``hidden_states`` is ignored; ``input_tensor`` is handed back untouched.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Disable the feed-forward path of the ``num_prune`` lowest-scoring layers.

    Layers are ranked by ascending Jacobian-deviation score (closest to the
    identity first). For each selected layer, ``intermediate.dense`` becomes
    ``nn.Identity`` and ``output`` becomes ``SkipFF``.

    Returns:
        The indices of the pruned layers, lowest score first.
    """
    # Sorting the keys by their score is stable, so ties keep dict order.
    ranked = sorted(jac_scores, key=jac_scores.get)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update ``W0 + (alpha / r) * B @ A`` around a fixed weight.

    ``W0`` (shape ``(L, M)``) is stored as a buffer holding a detached copy.
    ``B`` (``(L, r)``) starts as small Gaussian noise and ``A`` (``(r, M)``)
    starts at zero, so the effective weight equals ``W0`` right after
    construction.

    Args:
        W0: base weight matrix to adapt.
        r: rank of the update.
        alpha: numerator of the ``alpha / r`` scaling applied to ``B @ A``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        L, M = W0.shape
        # Sample on CPU (same RNG stream as before), then match W0's device and
        # dtype so forward() never mixes devices or precisions.
        self.B = nn.Parameter((torch.randn(L, r) * 0.01).to(W0))
        self.A = nn.Parameter(torch.zeros(r, M).to(W0))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix (same shape as ``W0``)."""
        return self.W0 + self.scaling * (self.B @ self.A)

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Route every remaining FFN output projection through a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. For the
    others, a ``LoRA`` wrapper is built from ``output.dense.weight`` and the
    ``forward`` of that ``dense`` module is overridden to apply the LoRA weight
    together with the module's bias.

    Returns:
        dict ``{layer_idx: LoRA}`` for every layer that was patched.
    """
    def make_forward(owner, adapter):
        # Bind layer and adapter here so each patched forward keeps its own pair.
        def lora_forward(x):
            return F.linear(x, adapter(), owner.output.dense.bias)
        return lora_forward

    loras = {}
    for layer_idx, block in enumerate(model.roberta.encoder.layer):
        if not hasattr(block.output, 'dense'):
            continue
        base_weight = block.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
        block.output.dense.forward = make_forward(block, adapter)
        loras[layer_idx] = adapter
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ``sentence1``/``sentence2`` pairs in ``examples`` with ``tok``.

    The tokenizer is asked to truncate and to pad with ``'max_length'``
    padding, using ``max_length`` as the target length.
    """
    first, second = examples['sentence1'], examples['sentence2']
    tokenizer_kwargs = dict(truncation=True, padding='max_length', max_length=max_length)
    return tok(first, second, **tokenizer_kwargs)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over the batches of ``dl``.

    The model is switched to eval mode (and left there). Predictions are the
    argmax of the logits and are scored with the ``evaluate`` "accuracy" metric.
    """
    model.eval()
    accuracy = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = accuracy.compute(predictions=predictions, references=references)
    return result["accuracy"]


# ========================================================
# 6) Training Stages (using Jacobian deviation)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """
    Stage 1: fine-tune ``roberta-base`` (2 labels) while scoring every layer.

    Trains for 6 epochs with Adam (lr 2e-5), a linear decay schedule, mixed
    precision and gradient checkpointing. After each optimizer step the layer
    inputs captured by the forward pre-hooks are used to estimate the
    per-layer Jacobian deviation; scores are averaged over the batches of each
    epoch.

    Args:
        train_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: loader handed to ``evaluate_model`` after training.
        device: device the model and batches are moved to.

    Returns:
        (model, last_jac): the fine-tuned model and the ``{layer_idx: score}``
        dict of the final epoch.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    num_epochs = 6
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                # Autocast wraps only the forward pass and loss computation.
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward runs outside autocast, as the AMP docs recommend.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # per-batch ||J - I||^2 per layer
                batch_jac = compute_batch_jacdev(model, activations, eps=1e-3, k_probes=1)
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx]/jac_counts[idx]
                         for idx in jac_sums if jac_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx ||J - I||^2:", epoch_jac)
            last_jac = epoch_jac
    finally:
        # Always detach the capture hooks, even if training is interrupted,
        # so they do not keep recording during evaluation or later stages.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MRPC Acc: {acc:.4f}")

    return model, last_jac


def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores, num_epochs=5):
    """
    Stage 2: prune the lowest-||J - I|| layers, then fine-tune the whole model.

    The number of pruned layers is read from the module-level ``num``.

    Args:
        model: model returned by stage 1; it is modified in place.
        train_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the batches are moved to.
        jac_scores: ``{layer_idx: score}`` used to choose the layers to prune.
        num_epochs: number of fine-tuning epochs. The linear LR schedule is
            sized to the same number, so the learning rate reaches zero only
            at the end of training.

    Returns:
        The pruned, fine-tuned model (the same object that was passed in).
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The schedule previously covered 3 epochs while the loop ran 5, which left
    # the last two epochs training at lr = 0; both now use num_epochs.
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """
    Stage 3: freeze the backbone and train only LoRA factors plus classifier.

    LoRA wrappers are installed on every layer that still has an FFN output
    projection, all ``model.roberta`` parameters are frozen, and Adam
    (lr 2e-5, linear decay, mixed precision) trains the classifier and the
    LoRA ``A``/``B`` factors for 6 epochs. Accuracy on ``dev_loader`` is
    printed after each epoch.

    Args:
        model: model returned by stage 2; it is modified in place.
        train_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    num_epochs = 6
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for l in loras.values():
        l.A.requires_grad = True
        l.B.requires_grad = True

    # With the backbone frozen, no tensor entering a checkpointed layer
    # requires grad. Reentrant gradient checkpointing (still enabled from
    # stage 1) then back-propagates nothing through those layers, leaving the
    # LoRA factors without gradients. Making the embedding output require grad
    # keeps them trainable.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for l in loras.values() for p in (l.A, l.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # Autocast wraps only the forward pass and loss computation.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward runs outside autocast, as the AMP docs recommend.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Seed the RNGs, build the MRPC data loaders and run the three stages."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load & preprocess MRPC subset
    raw_train = load_dataset("glue", "mrpc", split="train").shuffle(seed).select(range(1000))
    raw_dev = load_dataset("glue", "mrpc", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize a split, drop the raw text/id columns and expose "labels".
        tokenized = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2", "idx"],
        )
        return tokenized.rename_column("label", "labels")

    train = encode(raw_train)
    dev = encode(raw_dev)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1 -> Stage 2 -> Stage 3
    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

# Run the full pipeline only when this cell/module is executed directly.
if __name__ == "__main__":
    main()


In [None]:
# Jacobian Deviation with Autograd JVP (||J - I||) on GLUE MRPC, AD
# - Uses torch.autograd.functional.jvp for (J - I)v
# - Forces math SDPA during JVP to avoid missing higher-order grads

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings
from contextlib import contextmanager

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 1.1) Utility: force 'math' attention kernels during JVP
# ========================================================
@contextmanager
def force_math_sdp():
    """
    Force 'math' scaled-dot-product attention so higher-order grads/JVP work.

    On CUDA this restricts SDPA to the math backend for the duration of the
    ``with`` block. It prefers ``torch.nn.attention.sdpa_kernel`` and falls
    back to the deprecated ``torch.backends.cuda.sdp_kernel`` on PyTorch
    builds that do not ship the newer API. No-op on CPU.
    """
    if not torch.cuda.is_available():
        yield
        return

    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        # Older PyTorch: only the legacy context manager is available.
        backend_ctx = torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        )
    else:
        backend_ctx = sdpa_kernel(SDPBackend.MATH)

    with backend_ctx:
        yield

# ========================================================
# 2) Jacobian Deviation / Hook Utilities (per-layer)
# ========================================================
def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` by calling its ``remove()`` method."""
    for handle in hooks:
        handle.remove()

def register_jac_hooks(model):
    """
    Attach a forward pre-hook to every entry of ``model.roberta.encoder.layer``.

    Each hook stores the layer's first positional input (detached) under
    ``'X'`` and the remaining positional inputs under ``'args'``, so the layer
    can later be re-run in isolation to estimate JVPs.
    NOTE(review): only positional inputs reach these hooks; anything the
    encoder passes to a layer by keyword is not captured.

    Returns:
        (hooks, activations): the list of hook handles, and a dict
        ``{layer_idx: {'X': ..., 'args': ...}}`` (both ``None`` until the
        layer runs) that the hooks overwrite on every forward call.
    """
    activations = {}
    hooks = []

    def make_recorder(slot):
        # Bind the per-layer buffer now so each hook writes to its own slot.
        def record(module, inputs):
            slot['X'] = inputs[0].detach()
            slot['args'] = tuple(inputs[1:])
        return record

    for position, block in enumerate(model.roberta.encoder.layer):
        activations[position] = {'X': None, 'args': None}
        handle = block.register_forward_pre_hook(make_recorder(activations[position]))
        hooks.append(handle)

    return hooks, activations

# (kept for completeness; not used by JVP path)
@torch.no_grad()
def _layer_forward_only(layer, x, args):
    out = layer(x, *args)
    return out[0] if isinstance(out, tuple) else out

def compute_batch_jacdev_autograd(model, acts, k_probes=2, rademacher=True):
    """
    Autograd-JVP estimate of E_v ||(J - I)v||^2 for each captured layer.

    Every layer whose input is present in ``acts`` is switched to eval mode
    and differentiated with ``torch.autograd.functional.jvp`` along
    ``k_probes`` random directions. The score is the mean squared entry of
    ``Jv - v``, averaged over the probes.

    Args:
        model: object exposing ``model.roberta.encoder.layer``.
        acts: buffers ``{idx: {'X': ..., 'args': ...}}``; entries that were
            scored are reset to ``None`` afterwards.
        k_probes: number of probe directions per layer.
        rademacher: use ±1 probes when true, Gaussian probes otherwise.

    Returns:
        dict ``{layer_idx: score}``.
    """
    from torch.autograd.functional import jvp

    encoder_layers = model.roberta.encoder.layer
    scores = {}

    for layer_idx, slot in acts.items():
        captured, extra = slot['X'], slot['args']
        if captured is None:
            continue

        point = captured.detach().requires_grad_(True)
        block = encoder_layers[layer_idx]
        prev_mode = block.training
        block.eval()  # dropout off so repeated evaluations agree

        def layer_fn(inp, block=block, extra=extra):
            # Primary output of the layer as a function of its hidden-state input.
            result = block(inp, *extra)
            if isinstance(result, tuple):
                return result[0]
            return result

        try:
            total = 0.0
            # Efficient SDPA kernels may lack the higher-order grads jvp needs
            with force_math_sdp():
                for _ in range(k_probes):
                    if rademacher:
                        probe = torch.empty_like(point).bernoulli_(0.5).mul_(2).sub_(1)  # ±1
                    else:
                        probe = torch.randn_like(point)

                    # jvp returns (f(point), J @ probe); only the second is used
                    _, j_probe = jvp(layer_fn, (point,), (probe,), create_graph=False, strict=True)
                    total += float((j_probe - probe).pow(2).mean().item())

            scores[layer_idx] = total / k_probes
        finally:
            # put the layer back in whatever mode it was in
            block.train(prev_mode)

        # release the captured tensors before the next batch
        slot['X'] = None
        slot['args'] = None

    return scores

# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's output sub-module that bypasses the FFN result.

    ``hidden_states`` is ignored; ``input_tensor`` is handed back untouched.
    """

    def forward(self, hidden_states, input_tensor=None):
        # Residual passthrough
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Disable the feed-forward path of the ``num_prune`` lowest-scoring layers.

    Layers are ranked by ascending Jacobian-deviation score (closest to the
    identity first). For each selected layer, ``intermediate.dense`` becomes
    ``nn.Identity`` and ``output`` becomes ``SkipFF``.

    Returns:
        The indices of the pruned layers, lowest score first.
    """
    # Sorting the keys by their score is stable, so ties keep dict order.
    ranked = sorted(jac_scores, key=jac_scores.get)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update ``W0 + (alpha / r) * B @ A`` around a fixed weight.

    ``W0`` (shape ``(L, M)``) is stored as a buffer holding a detached copy.
    ``B`` (``(L, r)``) starts as small Gaussian noise and ``A`` (``(r, M)``)
    starts at zero, so the effective weight equals ``W0`` right after
    construction.

    Args:
        W0: base weight matrix to adapt.
        r: rank of the update.
        alpha: numerator of the ``alpha / r`` scaling applied to ``B @ A``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        L, M = W0.shape
        # Sample on CPU (same RNG stream as before), then match W0's device and
        # dtype so forward() never mixes devices or precisions.
        self.B = nn.Parameter((torch.randn(L, r) * 0.01).to(W0))
        self.A = nn.Parameter(torch.zeros(r, M).to(W0))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix (same shape as ``W0``)."""
        return self.W0 + self.scaling * (self.B @ self.A)

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Route every remaining FFN output projection through a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. For the
    others, a ``LoRA`` wrapper is built from ``output.dense.weight`` and the
    ``forward`` of that ``dense`` module is overridden to apply the LoRA weight
    together with the module's bias.

    Returns:
        dict ``{layer_idx: LoRA}`` for every layer that was patched.
    """
    def make_forward(owner, adapter):
        # Bind layer and adapter here so each patched forward keeps its own pair.
        def lora_forward(x):
            return F.linear(x, adapter(), owner.output.dense.bias)
        return lora_forward

    loras = {}
    for layer_idx, block in enumerate(model.roberta.encoder.layer):
        if not hasattr(block.output, 'dense'):
            continue
        base_weight = block.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
        block.output.dense.forward = make_forward(block, adapter)
        loras[layer_idx] = adapter
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ``sentence1``/``sentence2`` pairs in ``examples`` with ``tok``.

    The tokenizer is asked to truncate and to pad with ``'max_length'``
    padding, using ``max_length`` as the target length.
    """
    first, second = examples['sentence1'], examples['sentence2']
    tokenizer_kwargs = dict(truncation=True, padding='max_length', max_length=max_length)
    return tok(first, second, **tokenizer_kwargs)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over the batches of ``dl``.

    The model is switched to eval mode (and left there). Predictions are the
    argmax of the logits and are scored with the ``evaluate`` "accuracy" metric.
    """
    model.eval()
    accuracy = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = accuracy.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (using Jacobian deviation)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """
    Stage 1: fine-tune ``roberta-base`` (2 labels) while scoring every layer.

    Trains for 6 epochs with Adam (lr 2e-5), a linear decay schedule, mixed
    precision and gradient checkpointing. After each optimizer step the layer
    inputs captured by the forward pre-hooks are used to estimate the
    per-layer Jacobian deviation via autograd JVPs; scores are averaged over
    the batches of each epoch.

    Args:
        train_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: loader handed to ``evaluate_model`` after training.
        device: device the model and batches are moved to.

    Returns:
        (model, last_jac): the fine-tuned model and the ``{layer_idx: score}``
        dict of the final epoch.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    num_epochs = 6
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                # Autocast wraps only the forward pass and loss computation.
                with autocast():
                    out = model(
                        input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device)
                    )
                # Backward runs outside autocast, as the AMP docs recommend.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch ||J - I||^2 per layer via JVP (dropout off inside function)
                batch_jac = compute_batch_jacdev_autograd(
                    model, activations, k_probes=2, rademacher=True
                )
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx]/jac_counts[idx]
                         for idx in jac_sums if jac_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx ||J - I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Always detach the capture hooks, even if training is interrupted,
        # so they do not keep recording during evaluation or later stages.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MRPC Acc: {acc:.4f}")

    return model, last_jac

def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores, num_epochs=5):
    """
    Stage 2: prune the lowest-||J - I|| layers, then fine-tune the whole model.

    The number of pruned layers is read from the module-level ``num``.

    Args:
        model: model returned by stage 1; it is modified in place.
        train_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the batches are moved to.
        jac_scores: ``{layer_idx: score}`` used to choose the layers to prune.
        num_epochs: number of fine-tuning epochs. The linear LR schedule is
            sized to the same number, so the learning rate reaches zero only
            at the end of training.

    Returns:
        The pruned, fine-tuned model (the same object that was passed in).
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The schedule previously covered 3 epochs while the loop ran 5, which left
    # the last two epochs training at lr = 0; both now use num_epochs.
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(
                input_ids=b['input_ids'].to(device),
                attention_mask=b['attention_mask'].to(device),
                labels=b['labels'].to(device)
            )
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """
    Stage 3: freeze the backbone and train only LoRA factors plus classifier.

    LoRA wrappers are installed on every layer that still has an FFN output
    projection, all ``model.roberta`` parameters are frozen, and Adam
    (lr 2e-5, linear decay, mixed precision) trains the classifier and the
    LoRA ``A``/``B`` factors for 6 epochs. Accuracy on ``dev_loader`` is
    printed after each epoch.

    Args:
        model: model returned by stage 2; it is modified in place.
        train_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    num_epochs = 6
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for l in loras.values():
        l.A.requires_grad = True
        l.B.requires_grad = True

    # With the backbone frozen, no tensor entering a checkpointed layer
    # requires grad. Reentrant gradient checkpointing (still enabled from
    # stage 1) then back-propagates nothing through those layers, leaving the
    # LoRA factors without gradients. Making the embedding output require grad
    # keeps them trainable.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + [p for l in loras.values() for p in (l.A, l.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            # Autocast wraps only the forward pass and loss computation.
            with autocast():
                out = model(
                    input_ids=b['input_ids'].to(device),
                    attention_mask=b['attention_mask'].to(device),
                    labels=b['labels'].to(device)
                )
            # Backward runs outside autocast, as the AMP docs recommend.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Seed the RNGs, build the MRPC data loaders and run the three stages."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load & preprocess MRPC subset
    raw_train = load_dataset("glue", "mrpc", split="train").shuffle(seed).select(range(1000))
    raw_dev = load_dataset("glue", "mrpc", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize a split, drop the raw text/id columns and expose "labels".
        tokenized = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2", "idx"],
        )
        return tokenized.rename_column("label", "labels")

    train = encode(raw_train)
    dev = encode(raw_dev)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1 -> Stage 2 -> Stage 3
    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

# Run the full pipeline only when this cell/module is executed directly.
if __name__ == "__main__":
    main()


In [None]:
# Jacobian Deviation (||J - I||) via finite differences (FD) on GLUE SST-2

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Jacobian Deviation (||J - I||) / Hook Utilities
#    Capture each layer's input so we can locally re-run it
#    to estimate JVPs via finite differences.
# ========================================================
def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` by calling its ``remove()`` method."""
    for handle in hooks:
        handle.remove()

def register_jac_hooks(model):
    """
    Attach a forward pre-hook to every entry of ``model.roberta.encoder.layer``.

    Each hook stores the layer's first positional input (detached) under
    ``'X'`` and the remaining positional inputs (attention mask, etc.) under
    ``'args'``, so the layer can be called again locally to compute
    finite-difference JVPs.
    NOTE(review): only positional inputs reach these hooks; anything the
    encoder passes to a layer by keyword is not captured.

    Returns:
        (hooks, activations): the list of hook handles, and a dict
        ``{layer_idx: {'X': ..., 'args': ...}}`` (both ``None`` until the
        layer runs) that the hooks overwrite on every forward call.
    """
    activations = {}
    hooks = []

    def make_recorder(slot):
        # Bind the per-layer buffer now so each hook writes to its own slot.
        def record(module, inputs):
            slot['X'] = inputs[0].detach()
            slot['args'] = tuple(inputs[1:])
        return record

    for position, block in enumerate(model.roberta.encoder.layer):
        activations[position] = {'X': None, 'args': None}
        handle = block.register_forward_pre_hook(make_recorder(activations[position]))
        hooks.append(handle)

    return hooks, activations

@torch.no_grad()
def _layer_forward_only(layer, x, args):
    # Disable AMP and do fp32 for numeric stability
    with torch.cuda.amp.autocast(enabled=False):
        out = layer(x.to(torch.float32), *args)
        return out[0] if isinstance(out, tuple) else out

@torch.no_grad()
def compute_batch_jacdev(model, acts, eps=1e-3, k_probes=1):
    """
    Finite-difference estimate of each layer's deviation from the identity map.

    For every layer whose input was captured in ``acts``, the layer is re-run
    in eval mode and
      (J - I)v ≈ (f(x + eps*v) - f(x))/eps - v
    is evaluated for ``k_probes`` Gaussian probes ``v``. The score is the mean
    squared entry of that vector, averaged over the probes.

    Args:
        model: object exposing ``model.roberta.encoder.layer``.
        acts: buffers ``{idx: {'X': ..., 'args': ...}}``; entries that were
            scored are reset to ``None`` afterwards.
        eps: finite-difference step size.
        k_probes: number of random probe directions per layer.

    Returns:
        dict ``{layer_idx: score}`` for the layers that had a captured input.
    """
    encoder_layers = model.roberta.encoder.layer
    scores = {}

    for layer_idx, slot in acts.items():
        captured, extra = slot['X'], slot['args']
        if captured is None:
            continue

        block = encoder_layers[layer_idx]
        prev_mode = block.training
        block.eval()  # dropout off so baseline and probes are comparable
        try:
            baseline = _layer_forward_only(block, captured, extra)

            total = 0.0
            for _ in range(k_probes):
                direction = torch.randn_like(captured)
                shifted = _layer_forward_only(block, captured + eps * direction, extra)
                residual = (shifted - baseline) / eps - direction
                total += float(residual.pow(2).mean().item())
            scores[layer_idx] = total / k_probes
        finally:
            # put the layer back in whatever mode it was in
            block.train(prev_mode)

        # release the captured tensors before the next batch
        slot['X'] = None
        slot['args'] = None

    return scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's output sub-module that bypasses the FFN result.

    ``hidden_states`` is ignored; ``input_tensor`` is handed back untouched.
    """

    def forward(self, hidden_states, input_tensor=None):
        # residual passthrough
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=6):
    """
    Disable the feed-forward path of the ``num_prune`` lowest-scoring layers.

    Layers are ranked by ascending Jacobian-deviation score (closest to the
    identity first). For each selected layer, ``intermediate.dense`` becomes
    ``nn.Identity`` and ``output`` becomes ``SkipFF``.

    Returns:
        The indices of the pruned layers, lowest score first.
    """
    # Sorting the keys by their score is stable, so ties keep dict order.
    ranked = sorted(jac_scores, key=jac_scores.get)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update ``W0 + (alpha / r) * B @ A`` around a fixed weight.

    ``W0`` (shape ``(L, M)``) is stored as a buffer holding a detached copy.
    ``B`` (``(L, r)``) starts as small Gaussian noise and ``A`` (``(r, M)``)
    starts at zero, so the effective weight equals ``W0`` right after
    construction.

    Args:
        W0: base weight matrix to adapt.
        r: rank of the update.
        alpha: numerator of the ``alpha / r`` scaling applied to ``B @ A``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        L, M = W0.shape
        # Sample on CPU (same RNG stream as before), then match W0's device and
        # dtype so forward() never mixes devices or precisions.
        self.B = nn.Parameter((torch.randn(L, r) * 0.01).to(W0))
        self.A = nn.Parameter(torch.zeros(r, M).to(W0))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix (same shape as ``W0``)."""
        return self.W0 + self.scaling * (self.B @ self.A)

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Route every remaining FFN output projection through a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. For the
    others, a ``LoRA`` wrapper is built from ``output.dense.weight`` and the
    ``forward`` of that ``dense`` module is overridden to apply the LoRA weight
    together with the module's bias.

    Returns:
        dict ``{layer_idx: LoRA}`` for every layer that was patched.
    """
    def make_forward(owner, adapter):
        # Bind layer and adapter here so each patched forward keeps its own pair.
        def lora_forward(x):
            return F.linear(x, adapter(), owner.output.dense.bias)
        return lora_forward

    loras = {}
    for layer_idx, block in enumerate(model.roberta.encoder.layer):
        if not hasattr(block.output, 'dense'):
            continue
        base_weight = block.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
        block.output.dense.forward = make_forward(block, adapter)
        loras[layer_idx] = adapter
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ``sentence1``/``sentence2`` pairs in ``examples`` with ``tok``.

    The tokenizer is asked to truncate and to pad with ``'max_length'``
    padding, using ``max_length`` as the target length.
    """
    first, second = examples['sentence1'], examples['sentence2']
    tokenizer_kwargs = dict(truncation=True, padding='max_length', max_length=max_length)
    return tok(first, second, **tokenizer_kwargs)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over the batches of ``dl``.

    The model is switched to eval mode (and left there). Predictions are the
    argmax of the logits and are scored with the ``evaluate`` "accuracy" metric.
    """
    model.eval()
    accuracy = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = accuracy.compute(predictions=predictions, references=references)
    return result["accuracy"]


# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """
    Stage 1: fine-tune ``roberta-base`` (2 labels) while scoring every layer.

    Trains for 6 epochs with Adam (lr 2e-5), a linear decay schedule, mixed
    precision and gradient checkpointing. After each optimizer step the layer
    inputs captured by the forward pre-hooks are used to estimate the
    per-layer Jacobian deviation; scores are averaged over the batches of each
    epoch.

    Args:
        train_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: loader handed to ``evaluate_model`` after training.
        device: device the model and batches are moved to.

    Returns:
        (model, last_jac): the fine-tuned model and the ``{layer_idx: score}``
        dict of the final epoch.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    num_epochs = 6
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                # Autocast wraps only the forward pass and loss computation.
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward runs outside autocast, as the AMP docs recommend.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch ||J - I||^2 per layer
                batch_jac = compute_batch_jacdev(model, activations, eps=1e-3, k_probes=1)
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx]/jac_counts[idx]
                         for idx in jac_sums if jac_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx ||J - I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Always detach the capture hooks, even if training is interrupted,
        # so they do not keep recording during evaluation or later stages.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune SST2 Acc: {acc:.4f}")

    return model, last_jac


def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores, num_epochs=5):
    """
    Stage 2: prune the lowest-||J - I|| layers, then fine-tune the whole model.

    The number of pruned layers is read from the module-level ``num``.

    Args:
        model: model returned by stage 1; it is modified in place.
        train_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: loader handed to ``evaluate_model`` after every epoch.
        device: device the batches are moved to.
        jac_scores: ``{layer_idx: score}`` used to choose the layers to prune.
        num_epochs: number of fine-tuning epochs. The linear LR schedule is
            sized to the same number, so the learning rate reaches zero only
            at the end of training.

    Returns:
        The pruned, fine-tuned model (the same object that was passed in).
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    # The schedule previously covered 3 epochs while the loop ran 5, which left
    # the last two epochs training at lr = 0; both now use num_epochs.
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] SST2 Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """Stage 3: train only the LoRA factors and the classifier head.

    Args:
        model: Classifier whose layers receive LoRA adapters via
            ``apply_lora_to_all_layers``.
        train_loader: Iterable of training batches (dicts holding
            ``input_ids``, ``attention_mask`` and ``labels`` tensors).
        dev_loader: Validation batches handed to ``evaluate_model``.
        device: Device every batch tensor is moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to the same helper.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    # Freeze everything under model.roberta; only the classifier head and the
    # LoRA factors stay trainable.
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    num_epochs = 6  # single source of truth for the loop and the LR schedule
    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # autocast should wrap only the forward pass and loss; the
            # backward pass runs outside it.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] SST2 Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the SST-2 pipeline: full fine-tune, prune + fine-tune, LoRA tune."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Shuffled 5000-example training subset; full validation split.
    raw_train = load_dataset("glue", "sst2", split="train").shuffle(seed).select(range(5000))
    raw_dev = load_dataset("glue", "sst2", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize to fixed length, then keep only the model-facing columns.
        tokenized = split.map(
            lambda ex: tokenizer(ex["sentence"], truncation=True, padding='max_length', max_length=64),
            batched=True,
        )
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(["sentence", "idx"])

    train = encode(raw_train)
    dev = encode(raw_dev)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
# SST-2: Jacobian-deviation layer pruning, autograd JVP (AD) version

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings
from contextlib import contextmanager

# Suppress every UserWarning and FutureWarning process-wide.
# NOTE(review): this also hides FutureWarning-based deprecation notices from
# the libraries used below — confirm that is intended.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 1.1) Utility: force 'math' attention kernels during JVP
# ========================================================
@contextmanager
def force_math_sdp():
    """Restrict scaled-dot-product attention to the 'math' kernel.

    Used around JVP computations so higher-order gradients are available.
    When CUDA is not available the context manager does nothing.

    Prefers ``torch.nn.attention.sdpa_kernel`` and falls back to the
    deprecated ``torch.backends.cuda.sdp_kernel`` on PyTorch builds that do
    not ship the newer API.
    """
    if not torch.cuda.is_available():
        yield
        return

    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        # Older PyTorch: only the legacy flag-based context manager exists.
        kernel_ctx = torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        )
    else:
        kernel_ctx = sdpa_kernel([SDPBackend.MATH])

    with kernel_ctx:
        yield

# ========================================================
# 2) Jacobian Deviation (||J - I||) / Hook Utilities
#    Capture each layer's input so we can locally re-run it
#    to estimate JVPs via autograd.
# ========================================================
def remove_hooks(hooks):
    """Unregister every handle in ``hooks`` by calling its ``remove()``."""
    for handle in hooks:
        handle.remove()

def register_jac_hooks(model):
    """Attach forward pre-hooks that record each encoder layer's call inputs.

    For every layer in ``model.roberta.encoder.layer`` the hook stores the
    first positional input (detached) under ``'X'`` and the remaining
    positional inputs, as a tuple, under ``'args'``.

    Returns:
        ``(hooks, activations)``: the list of hook handles and the dict
        ``{layer_idx: {'X': ..., 'args': ...}}`` the hooks write into.
    """
    encoder_layers = model.roberta.encoder.layer

    activations = {}
    for layer_idx in range(len(encoder_layers)):
        activations[layer_idx] = {'X': None, 'args': None}

    def make_pre_hook(layer_idx):
        # Factory so each hook is bound to its own layer index.
        def pre_hook(module, inputs):
            hidden, *rest = inputs
            activations[layer_idx]['X'] = hidden.detach()
            activations[layer_idx]['args'] = tuple(rest)
        return pre_hook

    hooks = [
        layer.register_forward_pre_hook(make_pre_hook(layer_idx))
        for layer_idx, layer in enumerate(encoder_layers)
    ]
    return hooks, activations

def compute_batch_jacdev_autograd(model, acts, k_probes=2, rademacher=True):
    """Estimate each captured layer's Jacobian deviation with autograd JVPs.

    For every layer whose input was captured in ``acts``, the layer is re-run
    on that input in eval mode and ``(J - I) v = J v - v`` is evaluated for
    ``k_probes`` random probes ``v``. The score is the element-wise mean of
    ``((J - I) v) ** 2`` averaged over the probes.

    Args:
        model: Model exposing its layers at ``model.roberta.encoder.layer``.
        acts: Mapping ``{layer_idx: {'X': tensor or None, 'args': tuple}}``
            as filled by the forward pre-hooks. Entries that were processed
            are reset to ``None`` (also when a probe raises).
        k_probes: Number of random probes per layer; must be >= 1.
        rademacher: Draw probes from {-1, +1} when true, otherwise from a
            standard normal.

    Returns:
        dict ``{layer_idx: float}`` for the layers that had a captured input.

    Raises:
        ValueError: If ``k_probes`` is smaller than 1 (previously 0 raised a
            ZeroDivisionError and negative values silently returned -0.0).
    """
    from torch.autograd.functional import jvp

    if k_probes < 1:
        raise ValueError(f"k_probes must be >= 1, got {k_probes}")

    layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in acts.items():
        X, args = buf['X'], buf['args']
        if X is None:
            continue

        X = X.detach().requires_grad_(True)
        layer = layers[idx]
        was_training = layer.training
        layer.eval()  # turn off dropout during the probes

        try:
            def f(inp):
                out = layer(inp, *args)
                return out[0] if isinstance(out, tuple) else out

            total = 0.0
            # Avoid efficient SDPA kernels lacking higher-order grads
            with force_math_sdp():
                for _ in range(k_probes):
                    if rademacher:
                        v = torch.empty_like(X).bernoulli_(0.5).mul_(2).sub_(1)  # ±1
                    else:
                        v = torch.randn_like(X)
                    _, Jv = jvp(f, (X,), (v,), create_graph=False, strict=True)
                    total += float((Jv - v).pow(2).mean().item())
            scores[idx] = total / k_probes
        finally:
            layer.train(was_training)
            # Always drop the captured tensors so a failed probe cannot leave
            # stale inputs behind for the next batch.
            buf['X'] = None
            buf['args'] = None

    return scores

# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's output block that skips the feed-forward path.

    The first argument is ignored; ``input_tensor`` (the residual input) is
    returned unchanged.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=6):
    """Remove the feed-forward path of the lowest-scoring layers.

    Layers are ranked by ascending ``jac_scores`` value (smallest Jacobian
    deviation first). For the first ``num_prune`` of them,
    ``intermediate.dense`` becomes ``nn.Identity`` and ``output`` becomes
    ``SkipFF``.

    Returns:
        List of pruned layer indices, in ranking order.
    """
    ranked = sorted(jac_scores, key=jac_scores.get)
    prune_idxs = list(ranked[:num_prune])

    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()

    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update ``W0 + (alpha / r) * B @ A`` of a fixed weight.

    ``W0`` is copied into a buffer (not a Parameter). ``B`` starts as small
    Gaussian noise and ``A`` as zeros, so the initial effective weight equals
    ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen = W0.clone().detach()
        self.register_buffer("W0", frozen)
        out_dim, in_dim = W0.shape
        self.B = nn.Parameter(torch.randn(out_dim, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_dim))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix."""
        low_rank = self.B @ self.A
        return self.W0 + self.scaling * low_rank

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Patch each layer's ``output.dense`` to use a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. For the
    others a ``LoRA`` module is built from the current dense weight and the
    dense module's ``forward`` is replaced by a function that applies
    ``F.linear`` with the LoRA weight and the original bias.

    Returns:
        dict ``{layer_idx: LoRA}`` for every patched layer.
    """
    def make_forward(owner, adapter):
        # Resolve owner.output.dense at call time, as the patched forward
        # always did.
        def lora_forward(x):
            return F.linear(x, adapter(), owner.output.dense.bias)
        return lora_forward

    loras = {}
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[layer_idx] = adapter
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize ``examples['sentence']``, truncated and padded to ``max_length``."""
    tokenizer_kwargs = {
        "truncation": True,
        "padding": 'max_length',
        "max_length": max_length,
    }
    return tok(examples['sentence'], **tokenizer_kwargs)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over the batches in ``dl``.

    Puts the model in eval mode, runs every batch without gradients, takes
    the argmax over the last logits dimension as the prediction and scores
    it against the batch ``labels`` with the ``evaluate`` accuracy metric.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []

    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())

    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fully fine-tune roberta-base and score layers by ||J - I||^2.

    After every optimizer step the layer inputs captured by the forward
    pre-hooks are used to estimate each layer's Jacobian deviation; the
    estimates are averaged per epoch.

    Args:
        train_loader: Iterable of training batches (dicts holding
            ``input_ids``, ``attention_mask`` and ``labels`` tensors).
        dev_loader: Validation batches handed to ``evaluate_model``.
        device: Device the model and every batch tensor are moved to.

    Returns:
        ``(model, last_jac)`` where ``last_jac`` is the final epoch's
        ``{layer_idx: mean score}`` mapping.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()

    num_epochs = 6  # single source of truth for the loop and the LR schedule
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # autocast should wrap only the forward pass and loss; the
                # backward pass runs outside it.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch ||J - I||^2 per layer (Autograd JVP)
                batch_jac = compute_batch_jacdev_autograd(model, activations, k_probes=2, rademacher=True)
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx]/jac_counts[idx]
                         for idx in jac_sums if jac_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx ||J - I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Detach the hooks even if training fails, and before evaluation so
        # validation batches are not captured needlessly.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune SST2 Acc: {acc:.4f}")

    return model, last_jac

def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores):
    """Stage 2: prune the layers closest to identity, then fine-tune.

    Args:
        model: Fine-tuned classifier; pruned in place by ``prune_jac_layers``.
        train_loader: Iterable of training batches (dicts holding
            ``input_ids``, ``attention_mask`` and ``labels`` tensors).
        dev_loader: Validation batches handed to ``evaluate_model``.
        device: Device every batch tensor is moved to.
        jac_scores: Mapping ``{layer_idx: score}`` used to rank the layers.

    Returns:
        The same ``model`` object after pruning and fine-tuning.
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    # `num` is the notebook-level pruning budget set in an earlier cell.
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    num_epochs = 5
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    # The schedule horizon must cover every epoch that is run. It used to span
    # only 3 epochs, so the linear decay hit 0 and the last two epochs trained
    # with a learning rate of zero.
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] SST2 Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """Stage 3: train only the LoRA factors and the classifier head.

    Args:
        model: Classifier whose layers receive LoRA adapters via
            ``apply_lora_to_all_layers``.
        train_loader: Iterable of training batches (dicts holding
            ``input_ids``, ``attention_mask`` and ``labels`` tensors).
        dev_loader: Validation batches handed to ``evaluate_model``.
        device: Device every batch tensor is moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to the same helper.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    # Freeze everything under model.roberta; only the classifier head and the
    # LoRA factors stay trainable.
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    num_epochs = 6  # single source of truth for the loop and the LR schedule
    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # autocast should wrap only the forward pass and loss; the
            # backward pass runs outside it.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] SST2 Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the SST-2 pipeline: full fine-tune, prune + fine-tune, LoRA tune."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Shuffled 5000-example training subset; full validation split.
    raw_train = load_dataset("glue", "sst2", split="train").shuffle(seed).select(range(5000))
    raw_dev = load_dataset("glue", "sst2", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize, then keep only the model-facing columns.
        tokenized = split.map(lambda ex: preprocess_function(ex, tokenizer), batched=True)
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(["sentence", "idx"])

    train = encode(raw_train)
    dev = encode(raw_dev)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
# CoLA: Jacobian-deviation layer pruning, finite-difference (FD) version

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

# Suppress every UserWarning and FutureWarning process-wide.
# NOTE(review): this also hides FutureWarning-based deprecation notices from
# the libraries used below — confirm that is intended.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Jacobian-Deviation / Hook Utilities (per-layer J ~ I)
# ========================================================
def register_jac_hooks(model):
    """
    Capture each RobertaLayer's input (hidden_states) + the rest of its args
    so we can re-run the layer locally to estimate (J - I)v via finite diff.
    We store:
      activations[idx]['X']   = hidden_states (fp32, detached)
      activations[idx]['args']= tuple of non-input args (kept as-is)
    """
    layers = model.roberta.encoder.layer
    activations = {i: {'X': None, 'args': None} for i in range(len(layers))}
    hooks = []

    for i, layer in enumerate(layers):
        def pre_hook(module, inputs, idx=i):
            # inputs: (hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, ...)
            xs = inputs[0]
            extra = tuple(inputs[1:]) if len(inputs) > 1 else tuple()
            activations[idx]['X'] = xs.detach().to(torch.float32)
            activations[idx]['args'] = extra

        hooks.append(layer.register_forward_pre_hook(pre_hook))

    return hooks, activations

def remove_hooks(hooks):
    """Unregister every handle in ``hooks`` by calling its ``remove()``."""
    for handle in hooks:
        handle.remove()

@torch.no_grad()
def _layer_forward_only(layer, x, args):
    """Run ``layer(x, *args)`` without autograd and return its main output.

    If the layer returns a tuple, only its first element is returned.
    """
    out = layer(x, *args)
    return out[0] if isinstance(out, tuple) else out


@torch.no_grad()
def compute_batch_jacdev(model, acts, eps=1e-3, k_probes=1, rms_norm_v=True):
    """Estimate each captured layer's ||J - I||^2 with finite differences.

    For every layer whose input ``X`` was captured in ``acts`` the layer is
    re-run in eval mode and
        (J - I) v ≈ (f(X + eps * v) - f(X)) / eps - v
    is evaluated for ``k_probes`` random Gaussian probes ``v``. The score is
    the element-wise mean of the squared result, averaged over the probes.

    Args:
        model: Model exposing its layers at ``model.roberta.encoder.layer``.
        acts: Mapping ``{layer_idx: {'X': tensor or None, 'args': tuple}}``
            as filled by the forward pre-hooks. Entries that were processed
            are reset to ``None`` (also when a probe raises).
        eps: Finite-difference step size.
        k_probes: Number of random probes per layer; must be >= 1.
        rms_norm_v: When true, rescale each sample's probe so its standard
            deviation matches that of the sample's input. This path uses
            ``view(-1, 1, 1)``, so it expects a 3-D ``X``.

    Returns:
        dict ``{layer_idx: float}`` for the layers that had a captured input.

    Raises:
        ValueError: If ``k_probes`` is smaller than 1 (previously 0 raised a
            ZeroDivisionError and negative values silently returned -0.0).
    """
    if k_probes < 1:
        raise ValueError(f"k_probes must be >= 1, got {k_probes}")

    layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in acts.items():
        X, args = buf['X'], buf['args']
        if X is None:
            continue

        layer = layers[idx]
        was_training = layer.training
        layer.eval()  # disable dropout etc.

        try:
            y0 = _layer_forward_only(layer, X, args).to(torch.float32)

            if rms_norm_v:
                # Per-sample input scale; identical for every probe, so it is
                # computed once per layer.
                x_scale = X.reshape(X.size(0), -1).std(dim=1, keepdim=True) + 1e-6

            total = 0.0
            for _ in range(k_probes):
                v = torch.randn_like(X)

                if rms_norm_v:
                    # Match the probe's scale to the input's so the finite
                    # difference is well scaled.
                    v_scale = v.reshape(v.size(0), -1).std(dim=1, keepdim=True) + 1e-6
                    v = (v / v_scale.view(-1, 1, 1)) * x_scale.view(-1, 1, 1)

                y_eps = _layer_forward_only(layer, X + eps * v, args).to(torch.float32)
                jd_vec = (y_eps - y0) / eps - v
                total += float(jd_vec.pow(2).mean().item())

            scores[idx] = total / k_probes

        finally:
            layer.train(was_training)
            # Always drop the captured tensors so a failed probe cannot leave
            # stale inputs behind for the next batch.
            buf['X'] = None
            buf['args'] = None

    return scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||) — preserve first & last
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's output block that skips the feed-forward path.

    The first argument is ignored; ``input_tensor`` (the residual input) is
    returned unchanged.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=4):
    """Remove the feed-forward path of the lowest-scoring interior layers.

    The first (index 0) and last (index L-1) layers are never pruned, and
    scored indices outside ``[0, L)`` are ignored. The remaining layers are
    ranked by ascending score and up to ``num_prune`` of them get
    ``intermediate.dense`` replaced by ``nn.Identity`` and ``output``
    replaced by ``SkipFF``.

    Returns:
        List of pruned layer indices in ranking order (empty when nothing is
        eligible).
    """
    layers = model.roberta.encoder.layer
    n_layers = len(layers)
    protected = {0, n_layers - 1}

    candidates = []
    for layer_idx, score in jac_scores.items():
        if layer_idx in protected:
            continue
        if not (0 <= layer_idx < n_layers):
            continue
        candidates.append((layer_idx, score))

    if len(candidates) == 0:
        print("[Prune] No eligible layers to prune after preserving first/last.")
        return []

    candidates.sort(key=lambda pair: pair[1])  # ascending score
    budget = max(0, min(num_prune, len(candidates)))
    prune_idxs = [layer_idx for layer_idx, _score in candidates[:budget]]

    for layer_idx in prune_idxs:
        target = layers[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()

    print(f"[Prune] Preserved layers: {sorted(protected)}; Pruned layers: {sorted(prune_idxs)}")
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update ``W0 + (alpha / r) * B @ A`` of a fixed weight.

    ``W0`` is copied into a buffer (not a Parameter). ``B`` starts as small
    Gaussian noise and ``A`` as zeros, so the initial effective weight equals
    ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen = W0.clone().detach()
        self.register_buffer("W0", frozen)
        out_dim, in_dim = W0.shape
        self.B = nn.Parameter(torch.randn(out_dim, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_dim))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix."""
        low_rank = self.B @ self.A
        return self.W0 + self.scaling * low_rank

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Patch each layer's ``output.dense`` to use a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. For the
    others a ``LoRA`` module is built from the current dense weight and the
    dense module's ``forward`` is replaced by a function that applies
    ``F.linear`` with the LoRA weight and the original bias.

    Returns:
        dict ``{layer_idx: LoRA}`` for every patched layer.
    """
    def make_forward(owner, adapter):
        # Resolve owner.output.dense at call time, as the patched forward
        # always did.
        def lora_forward(x):
            return F.linear(x, adapter(), owner.output.dense.bias)
        return lora_forward

    loras = {}
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[layer_idx] = adapter
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize a batch, truncated and padded to ``max_length``.

    This cell's CoLA pipeline tokenizes a single ``sentence`` column, so that
    field is used when present. Batches that instead carry ``sentence1`` /
    ``sentence2`` are still tokenized as sentence pairs, which was the only
    (and for these CoLA batches failing) behavior before.

    Args:
        examples: Batch mapping with either a ``sentence`` field or both
            ``sentence1`` and ``sentence2`` fields.
        tok: Tokenizer callable.
        max_length: Fixed sequence length used for truncation and padding.

    Returns:
        Whatever ``tok`` returns for the batch.
    """
    if 'sentence' in examples:
        texts = (examples['sentence'],)
    else:
        texts = (examples['sentence1'], examples['sentence2'])
    return tok(*texts,
               truncation=True,
               padding='max_length',
               max_length=max_length)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over the batches in ``dl``.

    Puts the model in eval mode, runs every batch without gradients, takes
    the argmax over the last logits dimension as the prediction and scores
    it against the batch ``labels`` with the ``evaluate`` accuracy metric.
    """
    model.eval()
    # You can switch to CoLA MCC if you prefer
    metric = evaluate.load("accuracy")
    predictions, references = [], []

    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())

    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]


# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fully fine-tune roberta-base and score layers by ||J - I||^2.

    After every optimizer step the layer inputs captured by the forward
    pre-hooks are used to estimate each layer's Jacobian deviation by finite
    differences; the estimates are averaged per epoch.

    Args:
        train_loader: Iterable of training batches (dicts holding
            ``input_ids``, ``attention_mask`` and ``labels`` tensors).
        dev_loader: Validation batches handed to ``evaluate_model``.
        device: Device the model and every batch tensor are moved to.

    Returns:
        ``(model, last_jac)`` where ``last_jac`` is the final epoch's
        ``{layer_idx: mean score}`` mapping.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()

    num_epochs = 6  # single source of truth for the loop and the LR schedule
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # autocast should wrap only the forward pass and loss; the
                # backward pass runs outside it.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch ||J - I||^2 scores (finite difference)
                batch_jac = compute_batch_jacdev(model, activations, eps=1e-3, k_probes=1)
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx] / max(1, jac_counts[idx]) for idx in jac_sums}
            print(f"[Epoch {epoch+1}] approx ||J-I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Detach the hooks even if training fails, and before evaluation so
        # validation batches are not captured needlessly.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune CoLA Acc: {acc:.4f}")

    return model, last_jac


def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores, num_prune=4):
    """Stage 2: prune the layers closest to identity, then fine-tune.

    Args:
        model: Fine-tuned classifier; pruned in place by ``prune_jac_layers``.
        train_loader: Iterable of training batches (dicts holding
            ``input_ids``, ``attention_mask`` and ``labels`` tensors).
        dev_loader: Validation batches handed to ``evaluate_model``.
        device: Device every batch tensor is moved to.
        jac_scores: Mapping ``{layer_idx: score}`` used to rank the layers.
        num_prune: Maximum number of layers to prune.

    Returns:
        The same ``model`` object after pruning and fine-tuning.
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num_prune)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    num_epochs = 5
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The schedule horizon must cover every epoch that is run. It used to span
    # only 3 epochs, so the linear decay hit 0 and the last two epochs trained
    # with a learning rate of zero.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """Stage 3: train only the LoRA factors and the classifier head.

    Args:
        model: Classifier whose layers receive LoRA adapters via
            ``apply_lora_to_all_layers``.
        train_loader: Iterable of training batches (dicts holding
            ``input_ids``, ``attention_mask`` and ``labels`` tensors).
        dev_loader: Validation batches handed to ``evaluate_model``.
        device: Device every batch tensor is moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to the same helper.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    # Freeze everything under model.roberta; only the classifier head and the
    # LoRA factors stay trainable.
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    num_epochs = 6  # single source of truth for the loop and the LR schedule
    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # autocast should wrap only the forward pass and loss; the
            # backward pass runs outside it.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the CoLA pipeline: full fine-tune, prune + fine-tune, LoRA tune."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Shuffled 5000-example training subset; full validation split.
    raw_train = load_dataset("glue", "cola", split="train").shuffle(seed).select(range(5000))
    raw_dev = load_dataset("glue", "cola", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize to fixed length, then keep only the model-facing columns.
        tokenized = split.map(
            lambda ex: tokenizer(ex["sentence"], truncation=True, padding='max_length', max_length=64),
            batched=True,
        )
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(["sentence", "idx"])

    train = encode(raw_train)
    dev = encode(raw_dev)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, jac_scores = full_finetuning(train_loader, dev_loader, device)

    # First and last layers are preserved by the pruning step; adjust
    # num_prune as desired.
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores, num_prune=4)

    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()




In [None]:
# Jacobian for CoLA (Autograd JVP version)

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings
from contextlib import contextmanager  # <-- correct import

# Suppress every UserWarning and FutureWarning process-wide.
# NOTE(review): this also hides FutureWarning-based deprecation notices from
# the libraries used below — confirm that is intended.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 1.1) Utility: force 'math' attention kernels during JVP
# ========================================================
@contextmanager
def force_math_sdp():
    """Restrict scaled-dot-product attention to the 'math' kernel.

    Used around JVP computations so higher-order gradients are available.
    When CUDA is not available the context manager does nothing.

    Prefers ``torch.nn.attention.sdpa_kernel`` and falls back to the
    deprecated ``torch.backends.cuda.sdp_kernel`` on PyTorch builds that do
    not ship the newer API.
    """
    if not torch.cuda.is_available():
        yield
        return

    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        # Older PyTorch: only the legacy flag-based context manager exists.
        kernel_ctx = torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        )
    else:
        kernel_ctx = sdpa_kernel([SDPBackend.MATH])

    with kernel_ctx:
        yield

# ========================================================
# 2) Jacobian-Deviation / Hook Utilities (per-layer J ~ I)
#    Uses autograd JVP, not finite differences.
# ========================================================
def register_jac_hooks(model):
    """Attach forward pre-hooks that record each encoder layer's call inputs.

    For every layer in ``model.roberta.encoder.layer`` the hook stores:
      activations[idx]['X']    = first positional input (detached)
      activations[idx]['args'] = tuple of the remaining positional inputs

    Returns:
        ``(hooks, activations)``: the list of hook handles and the dict the
        hooks write into.
    """
    encoder_layers = model.roberta.encoder.layer

    activations = {}
    for layer_idx in range(len(encoder_layers)):
        activations[layer_idx] = {'X': None, 'args': None}

    def make_pre_hook(layer_idx):
        # Factory so each hook is bound to its own layer index.
        def pre_hook(module, inputs):
            hidden, *rest = inputs
            activations[layer_idx]['X'] = hidden.detach()
            activations[layer_idx]['args'] = tuple(rest)
        return pre_hook

    hooks = [
        layer.register_forward_pre_hook(make_pre_hook(layer_idx))
        for layer_idx, layer in enumerate(encoder_layers)
    ]
    return hooks, activations

def remove_hooks(hooks):
    """Unregister every handle in ``hooks`` by calling its ``remove()``."""
    for handle in hooks:
        handle.remove()

def compute_batch_jacdev_autograd(model, acts, k_probes=2, rademacher=True):
    """
    Score each captured layer with a Monte-Carlo estimate of E_v ||(J - I)v||^2.

    For every buffer in `acts` that holds an input X, the layer is re-run in
    eval mode and torch.autograd.functional.jvp supplies Jv for `k_probes`
    random probes v (entries of ±1 when `rademacher`, standard normal
    otherwise). The score is the element-wise mean of (Jv - v)^2, averaged
    over the probes.

    Buffers whose 'X' is None are skipped; scored buffers are reset to None.
    Returns {layer_idx: score} with Python float scores.
    """
    from torch.autograd.functional import jvp

    encoder_layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in acts.items():
        captured, extra_args = buf['X'], buf['args']
        if captured is None:
            continue

        x_in = captured.detach().requires_grad_(True)
        block = encoder_layers[idx]
        prev_mode = block.training
        block.eval()  # dropout off so repeated evaluations are deterministic

        try:
            def run_block(inp):
                result = block(inp, *extra_args)
                if isinstance(result, tuple):
                    return result[0]
                return result

            probe_means = []
            # Efficient SDPA kernels may lack higher-order grads; use 'math'.
            with force_math_sdp():
                for _ in range(k_probes):
                    if rademacher:
                        v = torch.empty_like(x_in).bernoulli_(0.5).mul_(2).sub_(1)  # ±1
                    else:
                        v = torch.randn_like(x_in)

                    _, jv = jvp(run_block, (x_in,), (v,), create_graph=False, strict=True)
                    deviation = jv - v
                    probe_means.append(float(deviation.pow(2).mean().item()))

            scores[idx] = sum(probe_means, 0.0) / k_probes
        finally:
            # Put the layer back in whatever mode it was in.
            block.train(prev_mode)

        # Reset the buffer so the next batch starts clean.
        buf['X'] = None
        buf['args'] = None

    return scores

# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||)
# ========================================================
class SkipFF(nn.Module):
    """
    Stand-in for a layer's output block: ignores the feed-forward result
    (`hidden_states`) and returns the residual input (`input_tensor`) as-is.
    """

    def forward(self, hidden_states, input_tensor=None):
        # hidden_states is deliberately unused.
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Prune the `num_prune` layers with the smallest Jacobian-deviation score
    (i.e. closest to the identity map).

    A pruned layer loses its FFN: intermediate.dense becomes nn.Identity and
    the output block becomes SkipFF (residual passthrough).
    Returns the list of pruned layer indices, lowest score first.
    """
    # Keys ordered by ascending score; ties keep dict insertion order.
    ranked = sorted(jac_scores, key=jac_scores.get)
    prune_idxs = ranked[:num_prune]

    for idx in prune_idxs:
        block = model.roberta.encoder.layer[idx]
        block.intermediate.dense = nn.Identity()
        block.output = SkipFF()

    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """
    Low-rank additive update for a fixed weight matrix.

    Stores a detached copy of W0 as a buffer together with trainable factors
    B (rows x r, small random init) and A (r x cols, zero init). forward()
    returns W0 + (alpha / r) * (B @ A), which equals W0 right after
    construction because A starts at zero.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        # B is registered before A so the parameter order stays (B, A).
        self.B = nn.Parameter(torch.randn(n_rows, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Wrap the output.dense projection of every encoder layer with a LoRA
    adapter.

    Layers whose output block has no `dense` attribute are left alone. For
    the others, the dense module's forward is replaced by a function that
    applies F.linear with the adapter's weight and the original bias.
    Returns {layer_idx: LoRA module} for the layers that were wrapped.
    """
    loras = {}
    for idx, block in enumerate(model.roberta.encoder.layer):
        if hasattr(block.output, 'dense'):
            base_weight = block.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

            # Defaults bind this iteration's block/adapter to the closure.
            def lora_forward(x, layer=block, lora=adapter):
                return F.linear(x, lora(), layer.output.dense.bias)

            block.output.dense.forward = lora_forward
            loras[idx] = adapter
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """
    Tokenize examples['sentence'] with `tok`, truncating and padding every
    sequence to `max_length`.
    """
    tokenizer_options = dict(
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )
    return tok(examples['sentence'], **tokenizer_options)

def evaluate_model(model, dl, device):
    """
    Put `model` in eval mode, run it over every batch of `dl` without
    gradients, and return the accuracy of the argmax predictions against
    the batches' 'labels'.
    """
    # CoLA's official metric is MCC, but we'll keep accuracy for simplicity;
    # swap to evaluate.load("glue", "cola") if you want MCC.
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """
    Stage 1: fine-tune roberta-base (2 labels) with mixed precision and
    estimate a per-layer Jacobian-deviation score after every optimizer step.

    Args:
        train_loader: iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors; must support len().
        dev_loader: batches used once, after training, to report accuracy.
        device: torch device the model and batches are moved to.

    Returns:
        (model, last_jac): the fine-tuned model and {layer_idx: score}
        averaged over the batches of the final epoch.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    num_epochs = 6  # shared by the training loop and the LR schedule

    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                with autocast():
                    out = model(
                        input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device)
                    )
                # Backward is run outside autocast, as the AMP docs recommend;
                # backward ops reuse the dtypes chosen during the forward pass.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch ||J - I||^2 per layer via Autograd JVP
                batch_jac = compute_batch_jacdev_autograd(
                    model, activations, k_probes=2, rademacher=True
                )
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            # epoch-averaged ||J - I||^2 per layer
            epoch_jac = {idx: jac_sums[idx] / max(1, jac_counts[idx]) for idx in jac_sums}
            print(f"[Epoch {epoch+1}] approx ||J-I||^2:",
                  {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Always detach the hooks: they must not leak if training fails and
        # they should not keep capturing activations during evaluation.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune CoLA Acc: {acc:.4f}")

    return model, last_jac

def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores):
    """
    Stage 2: prune the layers with the lowest Jacobian-deviation score, then
    fine-tune the pruned model (no mixed precision) and report dev accuracy
    after every epoch.

    The number of pruned layers is read from the notebook-level variable
    `num`.

    Args:
        model: model returned by stage 1; modified in place.
        train_loader: training batches; must support len().
        dev_loader: batches used for the per-epoch accuracy report.
        device: torch device the batches are moved to.
        jac_scores: {layer_idx: score} used to rank layers for pruning.

    Returns:
        The same model object, pruned and fine-tuned.
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    # The schedule horizon must match the number of epochs actually trained.
    # A shorter horizon decays the learning rate to zero early and turns the
    # remaining epochs into no-op updates.
    num_epochs = 5

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(
                input_ids=b['input_ids'].to(device),
                attention_mask=b['attention_mask'].to(device),
                labels=b['labels'].to(device)
            )
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """
    Stage 3: attach LoRA adapters and train only the adapters and the
    classifier head with mixed precision, reporting dev accuracy per epoch.

    Args:
        model: model from stage 2; its output.dense forwards are replaced.
        train_loader: training batches; must support len().
        dev_loader: batches used for the per-epoch accuracy report.
        device: torch device the batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator (effective scale is alpha / r).

    Returns:
        None. The model is updated in place.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    # Freeze the encoder; only the classifier and the LoRA factors train.
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for l in loras.values():
        l.A.requires_grad = True
        l.B.requires_grad = True

    num_epochs = 6  # shared by the training loop and the LR schedule

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + [p for l in loras.values() for p in (l.A, l.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(
                    input_ids=b['input_ids'].to(device),
                    attention_mask=b['attention_mask'].to(device),
                    labels=b['labels'].to(device)
                )
            # Backward is run outside autocast, as the AMP docs recommend.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """
    Run the three-stage CoLA pipeline: seed the RNGs, build the data
    loaders from a 5000-example training subset and the validation split,
    then call full fine-tuning, prune + fine-tuning and LoRA fine-tuning.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load & preprocess CoLA subset
    train_ds = load_dataset("glue", "cola", split="train").shuffle(seed).select(range(5000))
    dev_ds   = load_dataset("glue", "cola", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def prepare(split):
        # Tokenize, then keep only the columns the model consumes.
        tokenized = split.map(lambda ex: preprocess_function(ex, tokenizer),
                              batched=True)
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(["sentence", "idx"])

    train = prepare(train_ds)
    dev = prepare(dev_ds)

    collator     = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True,  collate_fn=collator)
    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)

    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

# Run the pipeline only when this code is executed directly, not on import.
if __name__ == "__main__":
    main()


In [None]:
# Jacobian Deviation for QNLI (finite-difference JVP version)

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

# Suppress all UserWarning and FutureWarning messages from this point on.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Jacobian-Deviation / Hook Utilities (per-layer)
# ========================================================
def register_jac_hooks(model):
    """
    Attach a forward pre-hook to every block in model.roberta.encoder.layer.

    On each forward call of block `idx` the hook stores the first positional
    input (detached) under activations[idx]['X'] and the remaining
    positional inputs, as a tuple, under activations[idx]['args'], so the
    block can later be re-run on its own to estimate JVPs.

    Returns (hooks, activations).
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    hooks = []

    def make_pre_hook(idx):
        # Factory so that each hook is bound to its own layer index.
        def pre_hook(module, inputs):
            slot = activations[idx]
            slot['X'] = inputs[0].detach()
            slot['args'] = tuple(inputs[1:])
        return pre_hook

    for idx, block in enumerate(encoder_layers):
        activations[idx] = {'X': None, 'args': None}
        handle = block.register_forward_pre_hook(make_pre_hook(idx))
        hooks.append(handle)

    return hooks, activations

def remove_hooks(hooks):
    """Detach every hook handle in `hooks` by calling its remove() method."""
    for handle in hooks:
        handle.remove()

@torch.no_grad()
def _layer_forward_only(layer, x, args, use_amp=False):
    # Call the layer and return the hidden states (0th output)
    if use_amp:
        with autocast():
            out = layer(x, *args)
    else:
        out = layer(x, *args)
    return out[0] if isinstance(out, tuple) else out

@torch.no_grad()
def compute_batch_jacdev(model, acts, eps_fp32=1e-3, eps_fp16=1e-2, k_probes=1):
    """
    Score each captured layer with a finite-difference estimate of
    ||J - I||^2:
        (J - I)v  ≈  (f(x + eps v) - f(x)) / eps - v
    The score is the element-wise mean of the squared estimate, averaged
    over `k_probes` random directions v (normal draws rescaled by their
    standard deviation).

    `eps_fp16` is used when the captured input is float16/bfloat16 (the
    layer is then run under autocast), `eps_fp32` otherwise.
    Buffers whose 'X' is None are skipped; scored buffers are reset to None.
    Returns {layer_idx: score} with Python float scores.
    """
    encoder_layers = model.roberta.encoder.layer
    half_dtypes = (torch.float16, torch.bfloat16)
    scores = {}

    for idx, buf in acts.items():
        x0, extra_args = buf['X'], buf['args']
        if x0 is None:
            continue

        block = encoder_layers[idx]
        prev_mode = block.training
        block.eval()  # dropout off so both evaluations see the same function

        try:
            # Step size depends on the precision of the captured input.
            use_amp = x0.dtype in half_dtypes
            eps = eps_fp16 if use_amp else eps_fp32

            # Reference output at the unperturbed input.
            y0 = _layer_forward_only(block, x0, extra_args, use_amp=use_amp)

            probe_means = []
            for _ in range(k_probes):
                v = torch.randn_like(x0)
                # Rescale so the step magnitude stays comparable across probes.
                v = v / (v.std() + 1e-6)

                y_eps = _layer_forward_only(block, x0 + eps * v, extra_args, use_amp=use_amp)
                deviation = (y_eps - y0) / eps - v
                probe_means.append(float(deviation.pow(2).mean().item()))

            scores[idx] = sum(probe_means, 0.0) / k_probes
        finally:
            # Put the layer back in whatever mode it was in.
            block.train(prev_mode)

        # Reset the buffer so the next batch starts clean.
        buf['X'] = None
        buf['args'] = None

    return scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||)
# ========================================================
class SkipFF(nn.Module):
    """
    Stand-in for a layer's output block: ignores the feed-forward result
    (`hidden_states`) and returns the residual input (`input_tensor`) as-is.
    """

    def forward(self, hidden_states, input_tensor=None):
        # hidden_states is deliberately unused.
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Prune the `num_prune` layers whose mapping is closest to the identity
    (smallest Jacobian-deviation score).

    A pruned layer loses its FFN: intermediate.dense becomes nn.Identity and
    the output block becomes SkipFF (residual passthrough).
    Returns the list of pruned layer indices, lowest score first.
    """
    # Keys ordered by ascending score; ties keep dict insertion order.
    ranked = sorted(jac_scores, key=jac_scores.get)
    prune_idxs = ranked[:num_prune]

    for idx in prune_idxs:
        block = model.roberta.encoder.layer[idx]
        block.intermediate.dense = nn.Identity()
        block.output = SkipFF()

    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """
    Low-rank additive update for a fixed weight matrix.

    Stores a detached copy of W0 as a buffer together with trainable factors
    B (rows x r, small random init) and A (r x cols, zero init). forward()
    returns W0 + (alpha / r) * (B @ A), which equals W0 right after
    construction because A starts at zero.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        # B is registered before A so the parameter order stays (B, A).
        self.B = nn.Parameter(torch.randn(n_rows, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Wrap the output.dense projection of every encoder layer with a LoRA
    adapter.

    Layers whose output block has no `dense` attribute are left alone. For
    the others, the dense module's forward is replaced by a function that
    applies F.linear with the adapter's weight and the original bias.
    Returns {layer_idx: LoRA module} for the layers that were wrapped.
    """
    loras = {}
    for idx, block in enumerate(model.roberta.encoder.layer):
        if hasattr(block.output, 'dense'):
            base_weight = block.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

            # Defaults bind this iteration's block/adapter to the closure.
            def lora_forward(x, layer=block, lora=adapter):
                return F.linear(x, lora(), layer.output.dense.bias)

            block.output.dense.forward = lora_forward
            loras[idx] = adapter
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=128):
    """
    Tokenize (examples['question'], examples['sentence']) pairs with `tok`,
    truncating and padding every sequence to `max_length`.
    """
    tokenizer_options = dict(
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )
    return tok(examples['question'], examples['sentence'], **tokenizer_options)

def evaluate_model(model, dl, device):
    """
    Put `model` in eval mode, run it over every batch of `dl` without
    gradients, and return the accuracy of the argmax predictions against
    the batches' 'labels'.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]


# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """
    Stage 1: fine-tune roberta-base (2 labels) with mixed precision and
    estimate a per-layer Jacobian-deviation score (finite differences) after
    every optimizer step.

    Args:
        train_loader: iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors; must support len().
        dev_loader: batches used once, after training, to report accuracy.
        device: torch device the model and batches are moved to.

    Returns:
        (model, last_jac): the fine-tuned model and {layer_idx: score}
        averaged over the batches of the final epoch.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    num_epochs = 6  # shared by the training loop and the LR schedule

    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward is run outside autocast, as the AMP docs recommend;
                # backward ops reuse the dtypes chosen during the forward pass.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch ||J - I||^2 per layer
                batch_jac = compute_batch_jacdev(model, activations, eps_fp32=1e-3, eps_fp16=1e-2, k_probes=1)
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx]/max(1, jac_counts[idx]) for idx in jac_sums}
            print(f"[Epoch {epoch+1}] approx ||J-I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Always detach the hooks: they must not leak if training fails and
        # they should not keep capturing activations during evaluation.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QNLI Acc: {acc:.4f}")

    return model, last_jac


def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores):
    """
    Stage 2: prune the layers with the lowest Jacobian-deviation score, then
    fine-tune the pruned model (no mixed precision) and report dev accuracy
    after every epoch.

    The number of pruned layers is read from the notebook-level variable
    `num`.

    Args:
        model: model returned by stage 1; modified in place.
        train_loader: training batches; must support len().
        dev_loader: batches used for the per-epoch accuracy report.
        device: torch device the batches are moved to.
        jac_scores: {layer_idx: score} used to rank layers for pruning.

    Returns:
        The same model object, pruned and fine-tuned.
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    # The schedule horizon must match the number of epochs actually trained.
    # A shorter horizon decays the learning rate to zero early and turns the
    # remaining epochs into no-op updates.
    num_epochs = 5

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """
    Stage 3: attach LoRA adapters and train only the adapters and the
    classifier head with mixed precision, reporting dev accuracy per epoch.

    Args:
        model: model from stage 2; its output.dense forwards are replaced.
        train_loader: training batches; must support len().
        dev_loader: batches used for the per-epoch accuracy report.
        device: torch device the batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator (effective scale is alpha / r).

    Returns:
        None. The model is updated in place.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    # Freeze the encoder; only the classifier and the LoRA factors train.
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for l in loras.values():
        l.A.requires_grad = True
        l.B.requires_grad = True

    num_epochs = 6  # shared by the training loop and the LR schedule

    opt   = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for l in loras.values() for p in (l.A, l.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward is run outside autocast, as the AMP docs recommend.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """
    Run the three-stage QNLI pipeline: seed the RNGs, build the data
    loaders from a 5000-example training subset and the validation split,
    then call full fine-tuning, prune + fine-tuning and LoRA fine-tuning.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load & preprocess QNLI
    train_ds = load_dataset("glue", "qnli", split="train").shuffle(seed).select(range(5000))
    dev_ds   = load_dataset("glue", "qnli", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def preprocess(examples):
        # Encode each question/sentence pair to a fixed length of 128.
        return tokenizer(examples["question"],
                         examples["sentence"],
                         truncation=True,
                         padding='max_length',
                         max_length=128)

    def prepare(split):
        # Tokenize, then keep only the columns the model consumes.
        tokenized = split.map(preprocess, batched=True)
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(["question", "sentence", "idx"])

    train = prepare(train_ds)
    dev = prepare(dev_ds)

    collator     = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)

    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

# Run the pipeline only when this code is executed directly, not on import.
if __name__ == "__main__":
    main()




In [None]:
# Jacobian Deviation for QNLI (Autograd JVP version)
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings
from contextlib import contextmanager  # <-- correct import

# Suppress all UserWarning and FutureWarning messages from this point on.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 1.1) Utility: force 'math' attention kernels during JVP
# ========================================================
@contextmanager
def force_math_sdp():
    """
    Context manager that limits scaled-dot-product attention to the 'math'
    backend (flash and memory-efficient kernels disabled) so that
    higher-order grads / JVP work.

    When CUDA is not available the wrapped body runs with no changes.
    """
    if not torch.cuda.is_available():
        # Nothing to configure without CUDA.
        yield
        return

    math_only = torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=True, enable_mem_efficient=False
    )
    with math_only:
        yield

# ========================================================
# 2) Jacobian-Deviation / Hook Utilities (per-layer)
# ========================================================
def register_jac_hooks(model):
    """
    Attach a forward pre-hook to every block in model.roberta.encoder.layer.

    On each forward call of block `idx` the hook stores the first positional
    input (detached) under activations[idx]['X'] and the remaining
    positional inputs, as a tuple, under activations[idx]['args'], so the
    block can later be re-run on its own to estimate JVPs.

    Returns (hooks, activations).
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    hooks = []

    def make_pre_hook(idx):
        # Factory so that each hook is bound to its own layer index.
        def pre_hook(module, inputs):
            slot = activations[idx]
            slot['X'] = inputs[0].detach()
            slot['args'] = tuple(inputs[1:])
        return pre_hook

    for idx, block in enumerate(encoder_layers):
        activations[idx] = {'X': None, 'args': None}
        handle = block.register_forward_pre_hook(make_pre_hook(idx))
        hooks.append(handle)

    return hooks, activations

def remove_hooks(hooks):
    """Detach every hook handle in `hooks` by calling its remove() method."""
    for handle in hooks:
        handle.remove()

def compute_batch_jacdev_autograd(model, acts, k_probes=2, rademacher=True):
    """
    Score each captured layer with a Monte-Carlo estimate of E_v ||(J - I)v||^2.

    For every buffer in `acts` that holds an input X, the layer is re-run in
    eval mode and torch.autograd.functional.jvp supplies Jv for `k_probes`
    random probes v (entries of ±1 when `rademacher`, standard normal
    otherwise). The score is the element-wise mean of (Jv - v)^2, averaged
    over the probes.

    Buffers whose 'X' is None are skipped; scored buffers are reset to None.
    Returns dict: {layer_idx: score} with Python float scores.
    """
    from torch.autograd.functional import jvp

    encoder_layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in acts.items():
        captured, extra_args = buf['X'], buf['args']
        if captured is None:
            continue

        x_in = captured.detach().requires_grad_(True)
        block = encoder_layers[idx]
        prev_mode = block.training
        block.eval()  # dropout off so the JVP is evaluated on a fixed function

        try:
            def run_block(inp):
                result = block(inp, *extra_args)
                if isinstance(result, tuple):
                    return result[0]
                return result

            probe_means = []
            # Efficient SDPA kernels may lack higher-order grads; use 'math'.
            with force_math_sdp():
                for _ in range(k_probes):
                    if rademacher:
                        v = torch.empty_like(x_in).bernoulli_(0.5).mul_(2).sub_(1)  # ±1 probes
                    else:
                        v = torch.randn_like(x_in)
                    _, jv = jvp(run_block, (x_in,), (v,), create_graph=False, strict=True)
                    deviation = jv - v
                    probe_means.append(float(deviation.pow(2).mean().item()))

            scores[idx] = sum(probe_means, 0.0) / k_probes
        finally:
            # Put the layer back in whatever mode it was in.
            block.train(prev_mode)

        # Reset the buffer so the next batch starts clean.
        buf['X'] = None
        buf['args'] = None

    return scores

# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||) — preserves first & last
# ========================================================
class SkipFF(nn.Module):
    """
    Stand-in for a layer's output block: ignores the feed-forward result
    (`hidden_states`) and returns the residual input (`input_tensor`) as-is.
    """

    def forward(self, hidden_states, input_tensor=None):
        # hidden_states is deliberately unused.
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Prune the layers closest to the identity map (lowest ||J-I||), never
    touching the first (idx=0) or last (idx=L-1) Transformer block.

    Behaviour:
      • Indices 0 and L-1, and indices outside [0, L), are not eligible
      • Only layers present in jac_scores can be selected
      • At most `num_prune` layers are pruned; fewer if fewer are eligible

    A pruned layer gets intermediate.dense = nn.Identity and
    output = SkipFF. Returns the pruned indices, lowest score first
    (an empty list when nothing is eligible).
    """
    encoder_layers = model.roberta.encoder.layer
    n_layers = len(encoder_layers)
    preserve = {0, n_layers - 1}

    # Scores of the layers that are allowed to be pruned.
    candidates = {
        idx: score
        for idx, score in jac_scores.items()
        if idx not in preserve and 0 <= idx < n_layers
    }
    if not candidates:
        print("[Prune] No eligible layers to prune (after preserving first/last).")
        return []

    # Ascending by score; ties keep the order they had in jac_scores.
    ranked = sorted(candidates, key=candidates.get)

    # Clamp the request to [0, number of candidates].
    n_selected = max(0, min(num_prune, len(ranked)))
    prune_idxs = ranked[:n_selected]

    for idx in prune_idxs:
        block = encoder_layers[idx]
        # Bypass the FFN: only the residual path remains.
        block.intermediate.dense = nn.Identity()
        block.output = SkipFF()

    print(f"[Prune] Preserved layers: {sorted(list(preserve))}; Pruned layers: {sorted(prune_idxs)}")
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """
    Low-rank additive update for a fixed weight matrix.

    Stores a detached copy of W0 as a buffer together with trainable factors
    B (rows x r, small random init) and A (r x cols, zero init). forward()
    returns W0 + (alpha / r) * (B @ A), which equals W0 right after
    construction because A starts at zero.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        # B is registered before A so the parameter order stays (B, A).
        self.B = nn.Parameter(torch.randn(n_rows, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Wrap the output.dense projection of every encoder layer with a LoRA
    adapter.

    Layers whose output block has no `dense` attribute are left alone. For
    the others, the dense module's forward is replaced by a function that
    applies F.linear with the adapter's weight and the original bias.
    Returns {layer_idx: LoRA module} for the layers that were wrapped.
    """
    loras = {}
    for idx, block in enumerate(model.roberta.encoder.layer):
        if hasattr(block.output, 'dense'):
            base_weight = block.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

            # Defaults bind this iteration's block/adapter to the closure.
            def lora_forward(x, layer=block, lora=adapter):
                return F.linear(x, lora(), layer.output.dense.bias)

            block.output.dense.forward = lora_forward
            loras[idx] = adapter
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=128):
    """
    Tokenize (examples['question'], examples['sentence']) pairs with `tok`,
    truncating and padding every sequence to `max_length`.
    """
    tokenizer_options = dict(
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )
    return tok(examples['question'], examples['sentence'], **tokenizer_options)

def evaluate_model(model, dl, device):
    """
    Put `model` in eval mode, run it over every batch of `dl` without
    gradients, and return the accuracy of the argmax predictions against
    the batches' 'labels'.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """
    Stage 1: fine-tune roberta-base (2 labels) with mixed precision and
    estimate a per-layer Jacobian-deviation score (autograd JVP) after every
    optimizer step.

    Args:
        train_loader: iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors; must support len().
        dev_loader: batches used once, after training, to report accuracy.
        device: torch device the model and batches are moved to.

    Returns:
        (model, last_jac): the fine-tuned model and {layer_idx: score}
        averaged over the batches of the final epoch.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    num_epochs = 6  # shared by the training loop and the LR schedule

    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward is run outside autocast, as the AMP docs recommend;
                # backward ops reuse the dtypes chosen during the forward pass.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch ||J - I||^2 per layer (Autograd JVP)
                batch_jac = compute_batch_jacdev_autograd(model, activations, k_probes=2, rademacher=True)
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx]/max(1, jac_counts[idx]) for idx in jac_sums}
            print(f"[Epoch {epoch+1}] approx ||J-I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Always detach the hooks: they must not leak if training fails and
        # they should not keep capturing activations during evaluation.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QNLI Acc: {acc:.4f}")

    return model, last_jac

def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores,
                         num_prune=4, num_epochs=5):
    """Stage 2: prune the lowest-scoring layers, then fine-tune to recover.

    Args:
        model: Model returned by ``full_finetuning``; modified in place.
        train_loader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: Loader handed to ``evaluate_model`` after every epoch.
        device: Device each batch is moved to.
        jac_scores: Mapping of layer index to Jacobian-deviation score.
        num_prune: NOTE(review): accepted but not used -- the number of layers
            to prune is read from the notebook-level global ``num``, as in the
            other cells. Decide which of the two should win.
        num_epochs: Number of recovery epochs. The same value sizes the linear
            LR decay.

    Returns:
        The pruned, fine-tuned model (the same object that was passed in).
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The decay horizon has to cover every epoch of the loop below. With a
    # shorter horizon the LR hits 0 early and the remaining epochs make no
    # parameter updates.
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA factors + classifier.

    Args:
        model: Pruned model from stage 2; patched in place by
            ``apply_lora_to_all_layers``.
        train_loader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: Loader handed to ``evaluate_model`` after every epoch.
        device: Device each batch is moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to ``apply_lora_to_all_layers``.
        num_epochs: Number of epochs. The same value sizes the linear LR decay.

    Returns:
        None. Per-epoch accuracy is printed.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for l in loras.values() for p in (l.A, l.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 leaves gradient checkpointing enabled. With the whole backbone
    # (embeddings included) frozen, re-entrant checkpointing sees no input
    # that requires grad, so the checkpointed encoder layers would produce no
    # gradients for the LoRA factors and only the classifier would train.
    # Making the embedding output require grad keeps the graph connected.
    # (Re-entrant mode is the Transformers default unless use_reentrant=False
    # is requested -- verify for the installed version.)
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            # Forward under autocast; backward outside, per the AMP guide.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage pipeline on a 5000-example QNLI training subset.

    Seeds the Python, NumPy and PyTorch RNGs, tokenizes the data to fixed
    length 128, then runs full fine-tuning, pruning + fine-tuning, and
    LoRA-only fine-tuning in sequence.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Training uses a shuffled 5000-example subset; evaluation uses the full
    # validation split.
    train_ds = load_dataset("glue", "qnli", split="train").shuffle(seed).select(range(5000))
    dev_ds = load_dataset("glue", "qnli", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def preprocess(examples):
        return tokenizer(
            examples["question"],
            examples["sentence"],
            truncation=True,
            padding='max_length',
            max_length=128,
        )

    def to_model_inputs(ds):
        # Tokenize, rename the target column, and drop the raw text columns.
        tokenized = ds.map(preprocess, batched=True)
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(["question", "sentence", "idx"])

    train = to_model_inputs(train_ds)
    dev = to_model_inputs(dev_ds)

    # Examples are already padded to max_length by preprocess.
    collator = DataCollatorWithPadding(tokenizer, padding="max_length")

    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1: full fine-tuning plus per-layer scoring.
    model, jac_scores = full_finetuning(train_loader, dev_loader, device)

    # Stage 2: prune, then recovery fine-tuning.
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device, jac_scores, num_prune=4
    )

    # Stage 3: LoRA-only fine-tuning of the pruned model.
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
# Jacobian for QQP

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Jacobian Deviation / Hook Utilities (per-layer)
#    We capture each RobertaLayer's input (and extra args),
#    then locally eval f(x+eps v) to estimate (J - I)v.
# ========================================================
def _flatten_tokens(t):
    """Collapse a tensor to two dimensions.

    A 3-D tensor has its first two axes merged, a 2-D tensor is returned
    as-is, and any other rank keeps axis 0 and flattens the rest. Not called
    by the other functions in this cell.
    """
    rank = t.dim()
    if rank == 2:
        return t
    if rank == 3:
        first, second, last = t.shape
        return t.reshape(first * second, last)
    return t.view(t.size(0), -1)

def register_jac_hooks(model):
    """Attach a forward pre-hook to every layer in ``model.roberta.encoder.layer``.

    Each hook stores the layer's first positional input (detached and cast to
    float32) under ``'X'`` and the remaining positional inputs under
    ``'args'``, so the layer can be re-invoked later.

    Returns:
        ``(hooks, activations)``: the list of hook handles and a dict
        ``{layer_idx: {'X': ..., 'args': ...}}`` initialised to ``None``.
    """
    activations = {}
    hooks = []

    def make_pre_hook(idx):
        # A factory binds idx per layer without relying on default arguments.
        def pre_hook(module, inputs):
            # float32 copy avoids AMP/half issues during finite differences.
            activations[idx]['X'] = inputs[0].detach().to(torch.float32)
            activations[idx]['args'] = tuple(inputs[1:])
        return pre_hook

    for i, layer in enumerate(model.roberta.encoder.layer):
        activations[i] = {'X': None, 'args': None}
        hooks.append(layer.register_forward_pre_hook(make_pre_hook(i)))

    return hooks, activations

def remove_hooks(hooks):
    """Call ``remove()`` on every handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

@torch.no_grad()
def _layer_forward_only(layer, x, args):
    """Call ``layer(x, *args)`` without grad tracking and return its main output.

    If the call returns a tuple, only its first element is returned.
    """
    result = layer(x, *args)
    if isinstance(result, tuple):
        return result[0]
    return result

@torch.no_grad()
def compute_batch_jacdev(model, acts, eps=1e-3, k_probes=1):
    """Score each layer by a finite-difference estimate of ||(J - I)v||^2.

    For every layer with a captured input, a random Gaussian probe ``v`` is
    used to form ``(f(x + eps*v) - f(x)) / eps - v``; its mean squared value
    is averaged over ``k_probes`` probes.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        acts: ``{layer_idx: {'X': tensor_or_None, 'args': tuple_or_None}}`` as
            filled by the hooks. Entries that get scored are reset to ``None``.
        eps: Finite-difference step size.
        k_probes: Number of random probes per layer.

    Returns:
        ``{layer_idx: score}`` for the layers that had a captured input.
    """
    layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in acts.items():
        x_in, extra = buf['X'], buf['args']
        if x_in is None:
            continue

        layer = layers[idx]
        prev_mode = layer.training
        # Dropout must be off so the base and perturbed passes are comparable.
        layer.eval()
        try:
            base = _layer_forward_only(layer, x_in, extra).to(torch.float32)

            probe_scores = []
            for _ in range(k_probes):
                probe = torch.randn_like(x_in)
                shifted = _layer_forward_only(layer, x_in + eps * probe, extra).to(torch.float32)
                residual = (shifted - base) / eps - probe  # approximates (J - I)v
                probe_scores.append(float(residual.pow(2).mean().item()))

            scores[idx] = sum(probe_scores) / k_probes
        finally:
            layer.train(prev_mode)

        # Drop the captured tensors so they are not reused on a later step.
        buf['X'] = None
        buf['args'] = None

    return scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J - I||)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's ``output`` module that passes the residual through."""

    def forward(self, hidden_states, input_tensor=None):
        # hidden_states is ignored; only the residual input is returned.
        return input_tensor

def prune_jac_layers(model, jac_scores, num_prune=4):
    """Disable the feed-forward block of the ``num_prune`` lowest-scoring layers.

    For each selected layer, ``intermediate.dense`` is replaced by
    ``nn.Identity`` and ``output`` by ``SkipFF``.

    Returns:
        The list of pruned layer indices, lowest score first.
    """
    # Ascending by score; the stable sort keeps dict order for ties.
    ranked = sorted(jac_scores, key=jac_scores.get)
    prune_idxs = ranked[:num_prune]
    for idx in prune_idxs:
        target = model.roberta.encoder.layer[idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen copy of a weight matrix.

    For ``W0`` of shape ``(L, M)`` the module holds ``B`` of shape ``(L, r)``
    (small random init) and ``A`` of shape ``(r, M)`` (zero init), so the
    effective weight equals ``W0`` at construction time.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        out_dim, in_dim = W0.shape
        # Buffer, not parameter: the base weight is stored but never trained.
        self.register_buffer("W0", W0.clone().detach())
        self.B = nn.Parameter(0.01 * torch.randn(out_dim, r))
        self.A = nn.Parameter(torch.zeros(r, in_dim))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight ``W0 + scaling * (B @ A)``."""
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Patch ``layer.output.dense`` of every encoder layer to use a LoRA weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. The
    ``forward`` of each remaining ``dense`` module is replaced by a function
    that applies ``F.linear`` with the weight produced by a fresh ``LoRA``
    module and the original bias.

    Returns:
        ``{layer_idx: LoRA}`` for the patched layers.
    """
    def bind_forward(layer, lora):
        # The bias is looked up at call time through the layer.
        def fwd(x):
            return F.linear(x, lora(), layer.output.dense.bias)
        return fwd

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            lora = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = bind_forward(layer, lora)
            loras[idx] = lora
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=128):
    """Tokenize the 'question1'/'question2' pair with truncation and fixed-length padding."""
    first, second = examples['question1'], examples['question2']
    return tok(
        first,
        second,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over the batches in ``dl``.

    Puts the model in eval mode (it is not switched back), takes the argmax of
    the logits for each batch, and scores the predictions with the
    ``evaluate`` "accuracy" metric.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]


# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fine-tune roberta-base and score each encoder layer.

    After every optimizer step the activations captured by the hooks from
    ``register_jac_hooks`` are passed to ``compute_batch_jacdev`` and the
    resulting per-layer scores are averaged over the epoch.

    Args:
        train_loader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: Loader handed to ``evaluate_model`` once training ends.
        device: Device the model and every batch are moved to.
        num_epochs: Number of training epochs. The same value sizes the
            linear LR decay so the schedule and the loop cannot drift apart.

    Returns:
        ``(model, last_jac)``; ``last_jac`` maps layer index to the mean score
        over the final epoch, or is ``None`` when ``num_epochs`` is 0.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    for epoch in range(num_epochs):
        jac_sums, jac_counts = defaultdict(float), defaultdict(int)
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            # Only the forward pass (loss included) belongs under autocast;
            # PyTorch's AMP guide advises against running backward inside it.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

            # Per-batch ||J - I||^2 for all layers we saw this step
            batch_jac = compute_batch_jacdev(model, activations, eps=1e-3, k_probes=1)
            for idx, v in batch_jac.items():
                jac_sums[idx] += v
                jac_counts[idx] += 1

        epoch_jac = {idx: jac_sums[idx] / max(1, jac_counts[idx]) for idx in jac_sums}
        print(f"[Epoch {epoch+1}] approx ||J-I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
        # Only the most recent epoch's averages are handed to the pruning stage.
        last_jac = epoch_jac

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QQP Acc: {acc:.4f}")

    remove_hooks(hooks)
    return model, last_jac


def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores, num_epochs=5):
    """Stage 2: prune the lowest-scoring layers, then fine-tune to recover.

    The number of layers to prune is read from the notebook-level global
    ``num``.

    Args:
        model: Model returned by ``full_finetuning``; modified in place.
        train_loader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: Loader handed to ``evaluate_model`` after every epoch.
        device: Device each batch is moved to.
        jac_scores: Mapping of layer index to Jacobian-deviation score.
        num_epochs: Number of recovery epochs. The same value sizes the linear
            LR decay.

    Returns:
        The pruned, fine-tuned model (the same object that was passed in).
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The decay horizon has to cover every epoch of the loop below. With a
    # shorter horizon the LR hits 0 early and the remaining epochs make no
    # parameter updates.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}")
    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA factors + classifier.

    Args:
        model: Pruned model from stage 2; patched in place by
            ``apply_lora_to_all_layers``.
        train_loader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: Loader handed to ``evaluate_model`` after every epoch.
        device: Device each batch is moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to ``apply_lora_to_all_layers``.
        num_epochs: Number of epochs. The same value sizes the linear LR decay.

    Returns:
        None. Per-epoch accuracy is printed.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for l in loras.values() for p in (l.A, l.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 leaves gradient checkpointing enabled. With the whole backbone
    # (embeddings included) frozen, re-entrant checkpointing sees no input
    # that requires grad, so the checkpointed encoder layers would produce no
    # gradients for the LoRA factors and only the classifier would train.
    # Making the embedding output require grad keeps the graph connected.
    # (Re-entrant mode is the Transformers default unless use_reentrant=False
    # is requested -- verify for the installed version.)
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            # Forward under autocast; backward outside, per the AMP guide.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage pipeline on a 5000-example QQP training subset.

    Seeds the Python, NumPy and PyTorch RNGs, tokenizes the data to fixed
    length 128, then runs full fine-tuning, pruning + fine-tuning, and
    LoRA-only fine-tuning in sequence.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    # Training uses a shuffled 5000-example subset; evaluation uses the full
    # validation split.
    train_ds = load_dataset("glue", "qqp", split="train").shuffle(seed).select(range(5000))
    dev_ds = load_dataset("glue", "qqp", split="validation")

    def preprocess(examples):
        return tokenizer(
            examples["question1"],
            examples["question2"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    def to_model_inputs(ds):
        # Tokenize, rename the target column, and drop the raw text columns.
        tokenized = ds.map(preprocess, batched=True)
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(["question1", "question2", "idx"])

    train = to_model_inputs(train_ds)
    dev = to_model_inputs(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1 -> Stage 2 -> Stage 3.
    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()




In [None]:
# Jacobian for QQP (Autograd JVP version)

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings
from contextlib import contextmanager  # <-- ensure correct import

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 1.1) Utility: force 'math' attention kernels during JVP
# ========================================================
@contextmanager
def force_math_sdp():
    """Restrict scaled-dot-product attention to the 'math' backend.

    Used around JVP computations so that only the math kernel is selected.
    Does nothing when CUDA is unavailable.

    Prefers ``torch.nn.attention.sdpa_kernel``; falls back to the deprecated
    ``torch.backends.cuda.sdp_kernel`` on PyTorch builds that lack it.
    """
    if not torch.cuda.is_available():
        yield
        return

    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        # Older PyTorch: only the legacy context manager exists.
        ctx = torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        )
    else:
        ctx = sdpa_kernel([SDPBackend.MATH])

    with ctx:
        yield

# ========================================================
# 2) Jacobian Deviation / Hook Utilities (per-layer)
#    We capture each RobertaLayer's input (and extra args),
#    then locally evaluate JVP to estimate (J - I)v.
# ========================================================
def register_jac_hooks(model):
    """Attach a forward pre-hook to every layer in ``model.roberta.encoder.layer``.

    Each hook stores the layer's first positional input (detached) under
    ``'X'`` and the remaining positional inputs under ``'args'``, so the layer
    can be re-invoked later for autograd JVPs.

    Returns:
        ``(hooks, activations)``: the list of hook handles and a dict
        ``{layer_idx: {'X': ..., 'args': ...}}`` initialised to ``None``.
    """
    activations = {}
    hooks = []

    def make_pre_hook(idx):
        # A factory binds idx per layer without relying on default arguments.
        def pre_hook(module, inputs):
            activations[idx]['X'] = inputs[0].detach()
            activations[idx]['args'] = tuple(inputs[1:])
        return pre_hook

    for i, layer in enumerate(model.roberta.encoder.layer):
        activations[i] = {'X': None, 'args': None}
        hooks.append(layer.register_forward_pre_hook(make_pre_hook(i)))

    return hooks, activations

def remove_hooks(hooks):
    """Call ``remove()`` on every handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

def compute_batch_jacdev_autograd(model, acts, k_probes=2, rademacher=True):
    """Score each layer by the mean squared value of ``(J - I)v`` using autograd JVPs.

    For every layer with a captured input, ``Jv`` is obtained from
    ``torch.autograd.functional.jvp`` and ``(Jv - v)^2`` is averaged over all
    elements and over ``k_probes`` random probes.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        acts: ``{layer_idx: {'X': tensor_or_None, 'args': tuple_or_None}}`` as
            filled by the hooks. Entries that get scored are reset to ``None``.
        k_probes: Number of random probes per layer.
        rademacher: Draw +/-1 probes when true, standard-normal probes otherwise.

    Returns:
        ``{layer_idx: score}`` for the layers that had a captured input.
    """
    from torch.autograd.functional import jvp

    layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in acts.items():
        captured, extra = buf['X'], buf['args']
        if captured is None:
            continue

        x_in = captured.detach().requires_grad_(True)
        layer = layers[idx]
        prev_mode = layer.training
        layer.eval()  # dropout off so the JVP is deterministic

        try:
            def layer_fn(inp):
                result = layer(inp, *extra)
                if isinstance(result, tuple):
                    return result[0]
                return result

            probe_scores = []
            # Restrict attention to the math backend while differentiating.
            with force_math_sdp():
                for _ in range(k_probes):
                    probe = (
                        torch.empty_like(x_in).bernoulli_(0.5).mul_(2).sub_(1)
                        if rademacher
                        else torch.randn_like(x_in)
                    )
                    _, j_probe = jvp(layer_fn, (x_in,), (probe,), create_graph=False, strict=True)
                    residual = j_probe - probe  # (J - I)v
                    probe_scores.append(float(residual.pow(2).mean().item()))

            scores[idx] = sum(probe_scores) / k_probes
        finally:
            layer.train(prev_mode)

        # Drop the captured tensors so they are not reused on a later step.
        buf['X'] = None
        buf['args'] = None

    return scores

# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J - I||)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's ``output`` module that passes the residual through."""

    def forward(self, hidden_states, input_tensor=None):
        # hidden_states is ignored; only the residual input is returned.
        return input_tensor

def prune_jac_layers(model, jac_scores, num_prune=4):
    """Disable the feed-forward block of the ``num_prune`` lowest-scoring layers.

    For each selected layer, ``intermediate.dense`` is replaced by
    ``nn.Identity`` and ``output`` by ``SkipFF``.

    Returns:
        The list of pruned layer indices, lowest score first.
    """
    # Ascending by score; the stable sort keeps dict order for ties.
    ranked = sorted(jac_scores, key=jac_scores.get)
    prune_idxs = ranked[:num_prune]
    for idx in prune_idxs:
        target = model.roberta.encoder.layer[idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen copy of a weight matrix.

    For ``W0`` of shape ``(L, M)`` the module holds ``B`` of shape ``(L, r)``
    (small random init) and ``A`` of shape ``(r, M)`` (zero init), so the
    effective weight equals ``W0`` at construction time.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        out_dim, in_dim = W0.shape
        # Buffer, not parameter: the base weight is stored but never trained.
        self.register_buffer("W0", W0.clone().detach())
        self.B = nn.Parameter(0.01 * torch.randn(out_dim, r))
        self.A = nn.Parameter(torch.zeros(r, in_dim))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight ``W0 + scaling * (B @ A)``."""
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Patch ``layer.output.dense`` of every encoder layer to use a LoRA weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. The
    ``forward`` of each remaining ``dense`` module is replaced by a function
    that applies ``F.linear`` with the weight produced by a fresh ``LoRA``
    module and the original bias.

    Returns:
        ``{layer_idx: LoRA}`` for the patched layers.
    """
    def bind_forward(layer, lora):
        # The bias is looked up at call time through the layer.
        def fwd(x):
            return F.linear(x, lora(), layer.output.dense.bias)
        return fwd

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            lora = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = bind_forward(layer, lora)
            loras[idx] = lora
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=128):
    """Tokenize the 'question1'/'question2' pair with truncation and fixed-length padding."""
    first, second = examples['question1'], examples['question2']
    return tok(
        first,
        second,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over the batches in ``dl``.

    Puts the model in eval mode (it is not switched back), takes the argmax of
    the logits for each batch, and scores the predictions with the
    ``evaluate`` "accuracy" metric.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fine-tune roberta-base and score each encoder layer.

    After every optimizer step the activations captured by the hooks from
    ``register_jac_hooks`` are passed to ``compute_batch_jacdev_autograd`` and
    the resulting per-layer scores are averaged over the epoch.

    Args:
        train_loader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: Loader handed to ``evaluate_model`` once training ends.
        device: Device the model and every batch are moved to.
        num_epochs: Number of training epochs. The same value sizes the
            linear LR decay so the schedule and the loop cannot drift apart.

    Returns:
        ``(model, last_jac)``; ``last_jac`` maps layer index to the mean score
        over the final epoch, or is ``None`` when ``num_epochs`` is 0.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    for epoch in range(num_epochs):
        jac_sums, jac_counts = defaultdict(float), defaultdict(int)
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            # Only the forward pass (loss included) belongs under autocast;
            # PyTorch's AMP guide advises against running backward inside it.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

            # Per-batch ||J - I||^2 for all layers we saw this step (Autograd JVP)
            batch_jac = compute_batch_jacdev_autograd(
                model, activations, k_probes=2, rademacher=True
            )
            for idx, v in batch_jac.items():
                jac_sums[idx] += v
                jac_counts[idx] += 1

        epoch_jac = {idx: jac_sums[idx] / max(1, jac_counts[idx]) for idx in jac_sums}
        print(f"[Epoch {epoch+1}] approx ||J-I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
        # Only the most recent epoch's averages are handed to the pruning stage.
        last_jac = epoch_jac

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QQP Acc: {acc:.4f}")

    remove_hooks(hooks)
    return model, last_jac

def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores, num_epochs=5):
    """Stage 2: prune the lowest-scoring layers, then fine-tune to recover.

    The number of layers to prune is read from the notebook-level global
    ``num``.

    Args:
        model: Model returned by ``full_finetuning``; modified in place.
        train_loader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: Loader handed to ``evaluate_model`` after every epoch.
        device: Device each batch is moved to.
        jac_scores: Mapping of layer index to Jacobian-deviation score.
        num_epochs: Number of recovery epochs. The same value sizes the linear
            LR decay.

    Returns:
        The pruned, fine-tuned model (the same object that was passed in).
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The decay horizon has to cover every epoch of the loop below. With a
    # shorter horizon the LR hits 0 early and the remaining epochs make no
    # parameter updates.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}")
    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA factors + classifier.

    Args:
        model: Pruned model from stage 2; patched in place by
            ``apply_lora_to_all_layers``.
        train_loader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: Loader handed to ``evaluate_model`` after every epoch.
        device: Device each batch is moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to ``apply_lora_to_all_layers``.
        num_epochs: Number of epochs. The same value sizes the linear LR decay.

    Returns:
        None. Per-epoch accuracy is printed.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)

    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for l in loras.values() for p in (l.A, l.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 leaves gradient checkpointing enabled. With the whole backbone
    # (embeddings included) frozen, re-entrant checkpointing sees no input
    # that requires grad, so the checkpointed encoder layers would produce no
    # gradients for the LoRA factors and only the classifier would train.
    # Making the embedding output require grad keeps the graph connected.
    # (Re-entrant mode is the Transformers default unless use_reentrant=False
    # is requested -- verify for the installed version.)
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            # Forward under autocast; backward outside, per the AMP guide.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage pipeline on a 5000-example QQP training subset.

    Seeds the Python, NumPy and PyTorch RNGs, tokenizes the data to fixed
    length 128, then runs full fine-tuning, pruning + fine-tuning, and
    LoRA-only fine-tuning in sequence.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    # Training uses a shuffled 5000-example subset; evaluation uses the full
    # validation split.
    train_ds = load_dataset("glue", "qqp", split="train").shuffle(seed).select(range(5000))
    dev_ds = load_dataset("glue", "qqp", split="validation")

    def preprocess(examples):
        return tokenizer(
            examples["question1"],
            examples["question2"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    def to_model_inputs(ds):
        # Tokenize, rename the target column, and drop the raw text columns.
        tokenized = ds.map(preprocess, batched=True)
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(["question1", "question2", "idx"])

    train = to_model_inputs(train_ds)
    dev = to_model_inputs(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1 -> Stage 2 -> Stage 3.
    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
# Jacobian for RTE


# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Jacobian-Deviation Hooks/Scoring
#    Compare each layer's mapping f_ℓ via ||J_ℓ - I|| (Hutchinson-style)
# ========================================================

def register_jac_hooks(model):
    """Attach input- and output-capturing hooks to every encoder layer.

    For each layer in ``model.roberta.encoder.layer`` a forward pre-hook
    stores the first positional input (detached) under ``'X'`` and the other
    positional inputs under ``'args'``; a forward hook stores the detached
    output (first element if the output is a tuple) under ``'Y'``.

    Returns:
        ``(hooks, activations)``: the list of hook handles (pre-hook then
        forward hook per layer) and a dict
        ``{layer_idx: {'X': ..., 'Y': ..., 'args': ...}}`` initialised to ``None``.
    """
    activations = {}
    hooks = []

    def make_hooks(idx):
        # A factory binds idx per layer without relying on default arguments.
        def on_enter(module, inputs):
            activations[idx]['X'] = inputs[0].detach()
            activations[idx]['args'] = tuple(inputs[1:])

        def on_exit(module, inputs, output):
            main_out = output[0] if isinstance(output, tuple) else output
            activations[idx]['Y'] = main_out.detach()

        return on_enter, on_exit

    for i, layer in enumerate(model.roberta.encoder.layer):
        activations[i] = {'X': None, 'Y': None, 'args': None}
        on_enter, on_exit = make_hooks(i)
        hooks.append(layer.register_forward_pre_hook(on_enter))
        hooks.append(layer.register_forward_hook(on_exit))

    return hooks, activations

def remove_hooks(hooks):
    """Call ``remove()`` on every handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

@torch.no_grad()
def _layer_forward_only(layer, x, args):
    """Call ``layer(x, *args)`` without grad tracking and return its main output.

    If the call returns a tuple, only its first element is returned.
    """
    result = layer(x, *args)
    if isinstance(result, tuple):
        return result[0]
    return result

@torch.no_grad()
def compute_batch_jacdev(model, acts, eps=1e-3, k_probes=1):
    """Score each layer by a finite-difference estimate of ||(J - I)v||^2.

    For every layer with both a captured input and a captured output, a random
    Gaussian probe ``v`` is used to form ``(f(x + eps*v) - f(x)) / eps - v``;
    its mean squared value is averaged over ``k_probes`` probes. The captured
    output is only checked for presence; the base output is recomputed with
    the layer in eval mode.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        acts: ``{layer_idx: {'X': ..., 'Y': ..., 'args': ...}}`` as filled by
            the hooks. Entries that get scored are reset to ``None``.
        eps: Finite-difference step size.
        k_probes: Number of random probes per layer.

    Returns:
        ``{layer_idx: score}`` for the layers that were scored.
    """
    layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in acts.items():
        x_in, y_seen, extra = buf['X'], buf['Y'], buf['args']
        if x_in is None or y_seen is None:
            continue

        layer = layers[idx]
        prev_mode = layer.training
        # Dropout must be off so the base and perturbed passes are comparable.
        layer.eval()
        try:
            base = _layer_forward_only(layer, x_in, extra)

            probe_scores = []
            for _ in range(k_probes):
                probe = torch.randn_like(x_in)
                shifted = _layer_forward_only(layer, x_in + eps * probe, extra)
                residual = (shifted - base) / eps - probe  # approximates (J - I)v
                probe_scores.append(float(residual.pow(2).mean().item()))
            scores[idx] = sum(probe_scores) / k_probes
        finally:
            layer.train(prev_mode)

        # Drop the captured tensors so they are not reused on a later step.
        buf['X'] = None
        buf['Y'] = None
        buf['args'] = None

    return scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's output block that skips the FFN.

    ``hidden_states`` is ignored; the residual input is returned untouched.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor
        return residual

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Disable the feed-forward path of the layers closest to the identity map.

    The ``num_prune`` layers with the smallest score in ``jac_scores`` are
    selected.  For each, ``intermediate.dense`` becomes ``nn.Identity`` and
    the output block becomes ``SkipFF`` (residual passthrough).

    Returns:
        list of pruned layer indices, lowest score first.
    """
    # Stable ascending sort of layer indices by their score
    ranked = sorted(jac_scores, key=lambda layer_idx: jac_scores[layer_idx])
    prune_idxs = list(ranked[:num_prune])

    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()

    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen copy of a weight matrix.

    ``W0`` is stored as a buffer (not trained).  ``B`` (rows x r) starts as
    small Gaussian noise and ``A`` (r x cols) starts at zero, so the initial
    effective weight equals ``W0``.  ``forward()`` takes no input and returns
    ``W0 + (alpha / r) * B @ A``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen = W0.clone().detach()
        self.register_buffer("W0", frozen)
        n_rows, n_cols = W0.shape
        b_init = torch.randn(n_rows, r) * 0.01
        a_init = torch.zeros(r, n_cols)
        self.B = nn.Parameter(b_init)
        self.A = nn.Parameter(a_init)
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Attach a ``LoRA`` adapter to ``output.dense`` of every encoder layer.

    Layers whose output block has no ``dense`` attribute are skipped.  For
    the others, ``dense.forward`` is replaced by a function that applies
    the adapter's effective weight together with the existing dense bias.

    Returns:
        dict {layer_idx: LoRA} with the adapters that were created.
    """
    adapters = {}
    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(enc_layer.output, 'dense'):
            continue

        base_weight = enc_layer.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

        # Defaults bind this iteration's layer/adapter to the closure.
        def lora_forward(x, layer=enc_layer, lora=adapter):
            effective_weight = lora()
            return F.linear(x, effective_weight, layer.output.dense.bias)

        enc_layer.output.dense.forward = lora_forward
        adapters[layer_idx] = adapter
    return adapters


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=128):
    """Tokenize the 'sentence1'/'sentence2' pair columns with ``tok``.

    Sequences are truncated and padded to exactly ``max_length``.
    """
    tokenizer_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
    }
    first, second = examples['sentence1'], examples['sentence2']
    return tok(first, second, **tokenizer_options)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over the batches in ``dl``.

    Puts the model in eval mode, takes the argmax of the logits for each
    batch without tracking gradients, and scores the collected predictions
    against the 'labels' column with the ``evaluate`` accuracy metric.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predicted = torch.argmax(logits, -1)
            predictions.extend(predicted.cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]


# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fine-tune roberta-base and estimate per-layer Jacobian deviation.

    Trains a 2-label ``RobertaForSequenceClassification`` with Adam, a linear
    LR schedule and AMP loss scaling.  After every optimizer step the
    activations captured by the hooks are scored with
    ``compute_batch_jacdev`` and averaged per layer over the epoch.

    Args:
        train_loader: iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: loader handed to ``evaluate_model`` after training.
        device: torch device the model and batches are moved to.

    Returns:
        (model, last_jac): the trained model and the {layer_idx: score}
        averages of the final epoch.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    num_epochs = 6  # single source of truth for the loop AND the LR schedule

    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                # Only the forward pass (incl. the loss) runs under autocast;
                # running backward inside the autocast region is discouraged
                # by the PyTorch AMP guidance.
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch Jacobian deviation scores
                batch_jac = compute_batch_jacdev(model, activations, eps=1e-3, k_probes=1)
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx]/max(1, jac_counts[idx]) for idx in jac_sums}
            print(f"[Epoch {epoch+1}] approx ||J-I||^2:", {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Detach the hooks even if training raises, so the model is never
        # left capturing activations on later forward passes.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune RTE Acc: {acc:.4f}")

    return model, last_jac


def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores):
    """Stage 2: prune the lowest-||J-I|| layers, then fine-tune the whole model.

    The number of pruned layers comes from the module-level ``num``.  The
    linear LR schedule is sized from the same epoch count the loop uses;
    previously it was sized for 3 epochs while the loop ran 5, so the
    learning rate had already decayed to zero for the last two epochs.

    Args:
        model: model returned by Stage 1 (modified in place).
        train_loader: training batches ('input_ids', 'attention_mask', 'labels').
        dev_loader: loader used for the per-epoch accuracy report.
        device: torch device batches are moved to.
        jac_scores: {layer_idx: score} used to pick the layers to prune.

    Returns:
        The pruned and fine-tuned model (same object as ``model``).
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    num_epochs = 5  # shared by the training loop and the LR schedule
    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """Stage 3: train only the LoRA adapters and the classifier head.

    Adapters are attached with ``apply_lora_to_all_layers``; all
    ``model.roberta`` parameters are frozen while the classifier and the
    adapter matrices A/B stay trainable.

    Stage 1 leaves gradient checkpointing enabled on the model.  With the
    embeddings frozen, no encoder-layer input requires grad, and re-entrant
    checkpointing then skips the backward pass through the layers, so the
    adapters would silently receive no gradient.  When checkpointing is
    active the embedding output is therefore made to require grad.

    Args:
        model: pruned model from Stage 2 (modified in place).
        train_loader: training batches ('input_ids', 'attention_mask', 'labels').
        dev_loader: loader used for the per-epoch accuracy report.
        device: torch device batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator (effective scale is ``alpha / r``).
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    num_epochs = 6  # shared by the training loop and the LR schedule

    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for adapter in loras.values():
        adapter.A.requires_grad = True
        adapter.B.requires_grad = True

    # Keep gradients flowing into checkpointed layers despite frozen embeddings.
    if (getattr(model, "is_gradient_checkpointing", False)
            and hasattr(model, "enable_input_require_grads")):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for adapter in loras.values() for p in (adapter.A, adapter.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            # Forward under autocast; backward runs outside the region.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Seed the RNGs, build the GLUE/RTE loaders and run the three stages."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer and raw RTE splits (train is shuffled with the fixed seed)
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    train_ds = load_dataset("glue", "rte", split="train").shuffle(seed)
    dev_ds   = load_dataset("glue", "rte", split="validation")

    def preprocess(examples):
        return tokenizer(
            examples["sentence1"],
            examples["sentence2"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    def prepare(split):
        # tokenize -> rename the label column -> drop the raw text/id columns
        encoded = split.map(preprocess, batched=True)
        encoded = encoded.rename_column("label", "labels")
        return encoded.remove_columns(["sentence1", "sentence2", "idx"])

    train = prepare(train_ds)
    dev = prepare(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader   = DataLoader(dev,  batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1 -> Stage 2 -> Stage 3
    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

# Run the pipeline defined in main() when this cell/script is executed directly.
if __name__ == "__main__":
    main()


In [None]:
# Jacobian for RTE (Autograd JVP version)

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings
from contextlib import contextmanager  # <-- correct import

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 1.1) Utility: force 'math' attention kernels during JVP
# ========================================================
@contextmanager
def force_math_sdp():
    """
    Restrict scaled-dot-product attention to the 'math' backend inside the
    ``with`` block, so higher-order grads / JVP work.  No-op without CUDA.

    Uses ``torch.nn.attention.sdpa_kernel`` when the installed PyTorch
    provides it and falls back to the older (deprecated)
    ``torch.backends.cuda.sdp_kernel`` otherwise.
    """
    if not torch.cuda.is_available():
        yield
        return

    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        # Older PyTorch: only the legacy context manager is available.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        ):
            yield
    else:
        # Enables the math backend only (every fused kernel is disabled).
        with sdpa_kernel(SDPBackend.MATH):
            yield

# ========================================================
# 2) Jacobian-Deviation Hooks/Scoring (Autograd JVP)
#    Compare each layer's mapping f_ℓ via ||J_ℓ - I|| using Hutchinson probes
# ========================================================
def register_jac_hooks(model):
    """
    Install a forward pre-hook on every encoder layer that records the
    layer's first positional input (detached) as 'X' and the remaining
    positional inputs as 'args', so the layer can be re-run locally later.

    Returns:
        (hooks, activations): the hook handles and the
        {layer_idx: {'X', 'args'}} buffer dict the hooks write into.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    hooks = []

    def make_pre_hook(layer_idx):
        # Factory so each hook is bound to its own layer index.
        def pre_hook(module, inputs):
            slot = activations[layer_idx]
            slot['X'] = inputs[0].detach()
            slot['args'] = tuple(inputs[1:])
        return pre_hook

    for layer_idx in range(len(encoder_layers)):
        activations[layer_idx] = {'X': None, 'args': None}

    for layer_idx, enc_layer in enumerate(encoder_layers):
        handle = enc_layer.register_forward_pre_hook(make_pre_hook(layer_idx))
        hooks.append(handle)

    return hooks, activations

def remove_hooks(hooks):
    """Call ``remove()`` on each hook handle so the hooks stop firing."""
    for hook_handle in hooks:
        hook_handle.remove()

def compute_batch_jacdev_autograd(model, acts, k_probes=2, rademacher=True):
    """
    Estimate E_v ||(J_ℓ - I)v||^2 per layer with autograd JVPs.

    For each layer with a captured input, the layer is re-run in eval mode
    and ``Jv`` is obtained from ``torch.autograd.functional.jvp``; the score
    is the mean squared entry of ``Jv - v`` averaged over ``k_probes``.

    Args:
        model: object exposing ``model.roberta.encoder.layer``.
        acts: {layer_idx: {'X', 'args'}} buffers; scored entries are reset
            to None afterwards.
        k_probes: number of probe vectors per layer.
        rademacher: draw ±1 probes when True, Gaussian probes otherwise.

    Returns:
        dict {layer_idx: score} for the layers that were scored.
    """
    from torch.autograd.functional import jvp

    encoder_layers = model.roberta.encoder.layer
    results = {}

    def draw_probe(reference):
        if rademacher:
            # Bernoulli(0.5) in {0, 1} mapped to {-1, +1}
            return torch.empty_like(reference).bernoulli_(0.5).mul_(2).sub_(1)
        return torch.randn_like(reference)

    for layer_idx, captured in acts.items():
        x_in, extra_args = captured['X'], captured['args']
        if x_in is None:
            continue

        x_in = x_in.detach().requires_grad_(True)
        enc_layer = encoder_layers[layer_idx]
        prev_mode = enc_layer.training
        enc_layer.eval()  # dropout off for a stable JVP

        try:
            def layer_fn(inp):
                result = enc_layer(inp, *extra_args)
                if isinstance(result, tuple):
                    return result[0]
                return result

            total = 0.0
            # Efficient SDPA kernels lack higher-order grads; use 'math' here
            with force_math_sdp():
                for _ in range(k_probes):
                    probe = draw_probe(x_in)
                    _, jac_probe = jvp(
                        layer_fn, (x_in,), (probe,), create_graph=False, strict=True
                    )
                    deviation = jac_probe - probe
                    total += float(deviation.pow(2).mean().item())

            results[layer_idx] = total / k_probes
        finally:
            # Put the layer back in whatever mode it was in
            enc_layer.train(prev_mode)

        # Release the consumed buffers
        captured['X'] = None
        captured['args'] = None

    return results

# ========================================================
# 3) Pruning Utilities with SkipFF (prune low-||J-I||)
# ========================================================
class SkipFF(nn.Module):
    """Output-block stand-in that bypasses the feed-forward computation.

    The first argument is ignored and ``input_tensor`` (the residual branch)
    is returned as is.
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_jac_layers(model, jac_scores, num_prune=4):
    """
    Strip the FFN from the ``num_prune`` layers with the smallest score.

    Layers are ranked by ascending Jacobian deviation (closest to identity
    first).  Each selected layer gets ``nn.Identity`` as
    ``intermediate.dense`` and ``SkipFF`` as its output block.

    Returns:
        list of the pruned layer indices in ranking order.
    """
    by_score = sorted(jac_scores.items(), key=lambda pair: pair[1])
    prune_idxs = []
    for layer_idx, _score in by_score[:num_prune]:
        prune_idxs.append(layer_idx)

    for layer_idx in prune_idxs:
        enc_layer = model.roberta.encoder.layer[layer_idx]
        enc_layer.intermediate.dense = nn.Identity()
        enc_layer.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Frozen base weight plus a trainable rank-``r`` correction.

    A detached copy of ``W0`` is kept as a buffer.  ``B`` is initialised
    with small Gaussian noise and ``A`` with zeros, so ``forward()``
    initially returns ``W0``.  In general it returns
    ``W0 + (alpha / r) * (B @ A)``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        rows, cols = W0.shape
        noise = torch.randn(rows, r)
        self.B = nn.Parameter(noise * 0.01)
        self.A = nn.Parameter(torch.zeros(r, cols))
        self.scaling = alpha / r

    def forward(self):
        update = self.scaling * (self.B @ self.A)
        return self.W0 + update

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap ``output.dense`` of each encoder layer with a ``LoRA`` adapter.

    A layer is skipped when its output block exposes no ``dense``
    attribute.  Otherwise ``dense.forward`` is overridden to run
    ``F.linear`` with the adapter's effective weight and the dense bias.

    Returns:
        dict mapping layer index to the ``LoRA`` module created for it.
    """
    loras = {}
    encoder_layers = model.roberta.encoder.layer
    for idx, enc_layer in enumerate(encoder_layers):
        has_dense = hasattr(enc_layer.output, 'dense')
        if has_dense:
            frozen_weight = enc_layer.output.dense.weight.data
            adapter = LoRA(frozen_weight, r, alpha).to(frozen_weight.device)

            # Bind the current layer/adapter via defaults (avoids late binding)
            def patched_forward(x, layer=enc_layer, lora=adapter):
                return F.linear(x, lora(), layer.output.dense.bias)

            enc_layer.output.dense.forward = patched_forward
            loras[idx] = adapter
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=128):
    """Encode sentence pairs from ``examples`` using tokenizer ``tok``.

    Reads the 'sentence1' and 'sentence2' columns; output is truncated and
    padded to ``max_length``.
    """
    sentence_pair = (examples['sentence1'], examples['sentence2'])
    return tok(
        *sentence_pair,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )

def evaluate_model(model, dl, device):
    """Compute classification accuracy of ``model`` on loader ``dl``.

    The model is switched to eval mode; predictions are the argmax over the
    logits of each batch (no gradient tracking) and are compared with the
    'labels' column via the ``evaluate`` accuracy metric.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in dl:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
            }
            all_labels.extend(batch['labels'].cpu().numpy())
            outputs = model(**inputs)
            batch_preds = torch.argmax(outputs.logits, -1).cpu().numpy()
            all_preds.extend(batch_preds)
    scores = metric.compute(predictions=all_preds, references=all_labels)
    return scores["accuracy"]

# ========================================================
# 6) Training Stages (Jacobian-Deviation scoring)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fine-tune roberta-base and estimate per-layer ||J - I||^2.

    Trains a 2-label ``RobertaForSequenceClassification`` with Adam, a linear
    LR schedule and AMP loss scaling.  After each optimizer step the layer
    inputs captured by the hooks are scored with
    ``compute_batch_jacdev_autograd`` (2 Rademacher probes) and averaged per
    layer over the epoch.

    Args:
        train_loader: iterable of batches holding 'input_ids',
            'attention_mask' and 'labels' tensors.
        dev_loader: loader handed to ``evaluate_model`` after training.
        device: torch device the model and batches are moved to.

    Returns:
        (model, last_jac): the trained model and the {layer_idx: score}
        averages of the final epoch.
    """
    print("=== Stage 1: Full Finetuning & Jacobian-Deviation Estimation ===")
    num_epochs = 6  # single source of truth for the loop AND the LR schedule

    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    hooks, activations = register_jac_hooks(model)
    last_jac = None

    try:
        for epoch in range(num_epochs):
            jac_sums, jac_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad(set_to_none=True)
                # Only the forward pass (incl. the loss) runs under autocast;
                # running backward inside the autocast region is discouraged
                # by the PyTorch AMP guidance.
                with autocast():
                    out = model(
                        input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device)
                    )
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Per-batch ||J - I||^2 per layer (Autograd JVP)
                batch_jac = compute_batch_jacdev_autograd(
                    model, activations, k_probes=2, rademacher=True
                )
                for idx, v in batch_jac.items():
                    jac_sums[idx]   += v
                    jac_counts[idx] += 1

            epoch_jac = {idx: jac_sums[idx] / max(1, jac_counts[idx]) for idx in jac_sums}
            print(f"[Epoch {epoch+1}] approx ||J-I||^2:",
                  {k: round(v, 6) for k, v in epoch_jac.items()})
            last_jac = epoch_jac
    finally:
        # Detach the hooks even if training raises, so the model is never
        # left capturing activations on later forward passes.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune RTE Acc: {acc:.4f}")

    return model, last_jac

def prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores):
    """Stage 2: prune the lowest-||J-I|| layers, then fine-tune the whole model.

    The number of pruned layers comes from the module-level ``num``.  The
    linear LR schedule is sized from the same epoch count the loop uses;
    previously it was sized for 3 epochs while the loop ran 5, so the
    learning rate had already decayed to zero for the last two epochs.

    Args:
        model: model returned by Stage 1 (modified in place).
        train_loader: training batches ('input_ids', 'attention_mask', 'labels').
        dev_loader: loader used for the per-epoch accuracy report.
        device: torch device batches are moved to.
        jac_scores: {layer_idx: score} used to pick the layers to prune.

    Returns:
        The pruned and fine-tuned model (same object as ``model``).
    """
    print("=== Stage 2: Prune (Low-||J-I||) & Finetuning ===")
    prune_idxs = prune_jac_layers(model, jac_scores, num_prune=num)
    print("Pruned layers (lowest ||J-I||):", prune_idxs)

    num_epochs = 5  # shared by the training loop and the LR schedule
    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(
                input_ids=b['input_ids'].to(device),
                attention_mask=b['attention_mask'].to(device),
                labels=b['labels'].to(device)
            )
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """Stage 3: train only the LoRA adapters and the classifier head.

    Adapters are attached with ``apply_lora_to_all_layers``; all
    ``model.roberta`` parameters are frozen while the classifier and the
    adapter matrices A/B stay trainable.

    Stage 1 leaves gradient checkpointing enabled on the model.  With the
    embeddings frozen, no encoder-layer input requires grad, and re-entrant
    checkpointing then skips the backward pass through the layers, so the
    adapters would silently receive no gradient.  When checkpointing is
    active the embedding output is therefore made to require grad.

    Args:
        model: pruned model from Stage 2 (modified in place).
        train_loader: training batches ('input_ids', 'attention_mask', 'labels').
        dev_loader: loader used for the per-epoch accuracy report.
        device: torch device batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator (effective scale is ``alpha / r``).
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    num_epochs = 6  # shared by the training loop and the LR schedule

    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for adapter in loras.values():
        adapter.A.requires_grad = True
        adapter.B.requires_grad = True

    # Keep gradients flowing into checkpointed layers despite frozen embeddings.
    if (getattr(model, "is_gradient_checkpointing", False)
            and hasattr(model, "enable_input_require_grads")):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for adapter in loras.values() for p in (adapter.A, adapter.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad(set_to_none=True)
            # Forward under autocast; backward runs outside the region.
            with autocast():
                out = model(
                    input_ids=b['input_ids'].to(device),
                    attention_mask=b['attention_mask'].to(device),
                    labels=b['labels'].to(device)
                )
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Seed the RNGs, prepare the GLUE/RTE data and run Stages 1-3 in order."""
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    # Tokenizer and raw RTE splits (train is shuffled with the fixed seed)
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    train_ds = load_dataset("glue", "rte", split="train").shuffle(seed)
    dev_ds   = load_dataset("glue", "rte", split="validation")

    def preprocess(examples):
        first, second = examples["sentence1"], examples["sentence2"]
        return tokenizer(
            first, second, truncation=True, padding="max_length", max_length=128
        )

    raw_columns = ["sentence1", "sentence2", "idx"]

    def to_model_inputs(split):
        # tokenize -> rename the label column -> drop the raw text/id columns
        tokenized = split.map(preprocess, batched=True)
        renamed = tokenized.rename_column("label", "labels")
        return renamed.remove_columns(raw_columns)

    train = to_model_inputs(train_ds)
    dev = to_model_inputs(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader   = DataLoader(dev,  batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1 -> Stage 2 -> Stage 3
    model, jac_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jac_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

# Run the pipeline defined in main() when this cell/script is executed directly.
if __name__ == "__main__":
    main()


In [None]:
# Jacobian for STS-B

import numpy as np
import random
import math
import warnings

# ── Robust NumPy 'copy' kwarg patch (NumPy 1.x and 2.x; idempotent; no recursion) ──
# Many libs call np.array(..., copy=False). In NumPy 2.0 this is *strict* and may raise
# ValueError if a copy is required. We relax this by removing copy=False (but honor copy=True).
# Grab the original C-level np.array so the wrapper below never calls itself.
# If the import path is unavailable the patch is skipped entirely.
try:
    from numpy.core.multiarray import array as _np_array_c
except Exception:
    _np_array_c = None

# Install the wrapper at most once: the flag stored on the numpy module makes
# re-running this cell a no-op.
if _np_array_c is not None and not getattr(np, "_array_copy_patched_relaxed", False):
    def _np_array_relaxed(obj, *args, **kwargs):
        """np.array replacement that tolerates ``copy=False`` / ``copy`` kwargs.

        ``copy=False`` is dropped before calling the C implementation; any
        other call is forwarded unchanged and retried without ``copy`` if it
        raises ValueError (only when ``copy`` was passed) or TypeError.
        """
        # If caller asked for copy=False (strict in NumPy 2.0), drop it to allow fallback
        if "copy" in kwargs and kwargs["copy"] is False:
            kwargs = dict(kwargs)
            kwargs.pop("copy", None)
            return _np_array_c(obj, *args, **kwargs)
        try:
            # Normal path
            return _np_array_c(obj, *args, **kwargs)
        except ValueError as e:
            # If strict 'copy' triggered a failure, retry without it
            if "copy" in kwargs:
                kwargs = dict(kwargs)
                kwargs.pop("copy", None)
                return _np_array_c(obj, *args, **kwargs)
            # ValueError unrelated to 'copy': propagate unchanged
            raise
        except TypeError:
            # NumPy < 2.0 may not accept 'copy' kwarg—remove it if present
            # (the retry re-raises if the TypeError had another cause)
            kwargs = dict(kwargs)
            kwargs.pop("copy", None)
            return _np_array_c(obj, *args, **kwargs)

    # Point np.array at the wrapper, which delegates to the C-level
    # implementation captured above (so earlier patches of np.array are
    # bypassed), and record that the patch is in place.
    np.array = _np_array_relaxed
    np._array_copy_patched_relaxed = True
# ───────────────────────────────────────────────────────────────────────────

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
import evaluate

from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from datasets import load_dataset

from collections import defaultdict

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ─── 1) Jacobian-Deviation (JD) Hooks ──────────────────────────────────────
def register_jd_hooks(model):
    """
    Record, for each adjacent pair of encoder layers (i, i+1), the output of
    layer i's ``output.dense`` as 'X' and of layer i+1's ``output.dense``
    as 'Y'.

    The tensors are stored WITHOUT detaching: the autograd graph between
    them is needed later by ``torch.autograd.grad``.

    Returns:
        (hooks, activations): the hook handles and the
        {pair_idx: {'X', 'Y'}} buffer dict the hooks write into.
    """
    encoder_layers = model.roberta.encoder.layer
    num_pairs = len(encoder_layers) - 1
    activations = {pair: {'X': None, 'Y': None} for pair in range(num_pairs)}
    hooks = []

    def make_capture(pair, slot):
        # Factory so each hook is bound to its own pair index and slot.
        def capture(module, inp, out):
            activations[pair][slot] = out  # keep graph
        return capture

    for pair in range(num_pairs):
        upstream_dense = encoder_layers[pair].output.dense
        downstream_dense = encoder_layers[pair + 1].output.dense
        hooks.append(upstream_dense.register_forward_hook(make_capture(pair, 'X')))
        hooks.append(downstream_dense.register_forward_hook(make_capture(pair, 'Y')))

    return hooks, activations

def remove_hooks(hooks):
    """Remove each registered hook via its handle's ``remove()`` method."""
    for registered in hooks:
        registered.remove()

@torch.no_grad()
def _numel(t):
    """Return the total number of elements in ``t`` as a Python int."""
    element_count = t.numel()
    return int(element_count)

def _flatten(t):
    """Return ``t`` reshaped to a single dimension."""
    flat = t.reshape(-1)
    return flat

def _safe_item(x):
    """Convert a scalar tensor to float, replacing NaN with 0 and ±inf with ±1e6."""
    finite = torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
    return float(finite.item())


def compute_batch_jd(
    activations,
    num_probes: int = 1,
    sample_pairs: int | None = None
):
    """
    JD(J) = ||J - I||_F^2 / D with Hutchinson:
      ||J||_F^2  ≈  E[ ||J^T z||^2 ]
      tr(J)      ≈  E[ z^T J z ] = E[ (J^T z)·z ]
      ||I||_F^2  =  D  (D = dim of X)

    We compute VJP g = J^T z via autograd.grad(outputs=Y, inputs=X, grad_outputs=z).
    Do the *reductions* in float32 to avoid AMP/FP16 overflow.
    """
    all_idxs = [i for i, buf in activations.items() if buf['X'] is not None and buf['Y'] is not None]
    if sample_pairs is not None and sample_pairs < len(all_idxs) and all_idxs:
        with torch.no_grad():
            device = activations[all_idxs[0]]['X'].device
            perm = torch.randperm(len(all_idxs), device=device)
        chosen = [all_idxs[int(i)] for i in perm[:sample_pairs]]
    else:
        chosen = all_idxs

    scores = {}
    if not chosen:
        return scores

    for idx in chosen:
        X = activations[idx]['X']
        Y = activations[idx]['Y']

        # Must match shape and require grads
        if (
            X is None or Y is None or X.shape != Y.shape
            or (not getattr(X, "requires_grad", False))
            or (not getattr(Y, "requires_grad", False))
        ):
            scores[idx] = 0.0
            continue

        D = int(X.numel())
        D32 = torch.tensor(D, device=Y.device, dtype=torch.float32)

        jd_sum32 = torch.zeros((), device=Y.device, dtype=torch.float32)

        for _ in range(num_probes):
            # grad_outputs must match outputs dtype; keep AMP-friendly here
            z = torch.randn_like(Y)

            g = torch.autograd.grad(
                outputs=Y,
                inputs=X,
                grad_outputs=z,
                retain_graph=True,
                allow_unused=True,
                create_graph=False
            )[0]

            if g is None:
                g = torch.zeros_like(X)

            # Do the heavy reductions in float32 to prevent overflow
            with torch.cuda.amp.autocast(enabled=False):
                g_flat32 = g.reshape(-1).float()
                z_flat32 = z.reshape(-1).float()

                # Clean NaNs/infs defensively (rare, but keeps scores finite)
                g_flat32 = torch.nan_to_num(g_flat32, nan=0.0, posinf=1e6, neginf=-1e6)
                z_flat32 = torch.nan_to_num(z_flat32, nan=0.0, posinf=1e6, neginf=-1e6)

                term_normJ = torch.dot(g_flat32, g_flat32)      # ||J^T z||^2
                term_trJ   = torch.dot(g_flat32, z_flat32)      # z^T J z

                jd_probe32 = term_normJ - 2.0 * term_trJ + D32  # float32 math
                # Final safety: clamp extreme explosions
                jd_probe32 = torch.clamp(jd_probe32, min=-1e12, max=1e12)

            jd_sum32 += jd_probe32

        jd_mean32 = jd_sum32 / max(1, num_probes)
        jd32 = jd_mean32 / D32  # normalize by dimensionality

        # Ensure finite scalar
        jd32 = torch.nan_to_num(jd32, nan=0.0, posinf=1e6, neginf=-1e6)
        scores[idx] = float(jd32.item())

        # Clear references so the graph can be freed after backward
        activations[idx]['X'] = None
        activations[idx]['Y'] = None

    # Also clear any non-chosen to avoid leaking graph to the next step
    for i in activations:
        activations[i]['X'] = None
        activations[i]['Y'] = None

    return scores




# ─── 2) Pruning Utilities (prune low-JD → downstream layer) ────────────────
class SkipFF(nn.Module):
    """Residual passthrough used in place of a pruned layer's output block."""

    def forward(self, hidden_states, input_tensor=None):
        # ``hidden_states`` is intentionally unused
        residual = input_tensor
        return residual



def prune_jd_layers(model, jd_scores, num_prune=4):
    """
    Sort pairs ascending by JD (closest to identity first), then prune the
    *downstream* layer (j = i+1) of each pair, but NEVER the first two
    (0, 1) or last two (L-2, L-1) encoder layers.

    Args:
        model: object exposing ``model.roberta.encoder.layer``.
        jd_scores: {pair_idx i: JD score} for the pair (i, i+1).
        num_prune: maximum number of layers to prune; values <= 0 prune
            nothing.  (Previously the stop condition was only checked after
            the first selection, so ``num_prune=0`` still pruned one layer.)

    Returns:
        list of pruned layer indices in selection order (may be shorter than
        ``num_prune`` when too few candidates are allowed).
    """
    L = len(model.roberta.encoder.layer)
    # Need at least 5 layers to have any room after reserving 2 + 2
    assert L >= 5, "Need at least 5 layers to preserve first/last two and still prune."

    # Candidate downstream layers are j in [2, L-3]
    allowed = set(range(2, L - 2))
    target = min(num_prune, len(allowed))
    if target <= 0:
        return []

    prune_idxs = []
    # Pair indices are unique dict keys, so each j = i+1 appears at most once.
    for i, _score in sorted(jd_scores.items(), key=lambda x: x[1]):
        j = i + 1
        if j not in allowed:
            continue
        prune_idxs.append(j)
        if len(prune_idxs) >= target:
            break

    # Apply pruning: drop the FFN, keep the residual path
    for idx in prune_idxs:
        layer = model.roberta.encoder.layer[idx]
        layer.intermediate.dense = nn.Identity()
        layer.output = SkipFF()

    return prune_idxs




# ─── 3) LoRA Modules ──────────────────────────────────────────────────────
class LoRA(nn.Module):
    """Trainable low-rank delta added to a frozen snapshot of ``W0``.

    ``W0`` is cloned into a buffer.  ``B`` has W0's row count and ``r``
    columns (small Gaussian init); ``A`` has ``r`` rows and W0's column
    count (zero init).  ``forward()`` returns ``W0 + (alpha / r) * B @ A``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        out_dim, in_dim = W0.shape
        self.B = nn.Parameter(torch.randn(out_dim, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_dim))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        scaled_delta = self.scaling * delta
        return self.W0 + scaled_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Give every encoder layer that still has ``output.dense`` a LoRA adapter.

    Layers whose output block lacks a ``dense`` attribute are left alone.
    For the rest, ``dense.forward`` is swapped for a function computing
    ``F.linear`` with the adapter's effective weight and the dense bias.

    Returns:
        dict {layer_idx: LoRA} of the adapters created.
    """
    created = {}
    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(enc_layer.output, 'dense'):
            continue

        weight = enc_layer.output.dense.weight.data
        adapter = LoRA(weight, r, alpha)
        adapter = adapter.to(weight.device)

        # Defaults capture this iteration's objects for the closure.
        def dense_with_lora(x, layer=enc_layer, lora=adapter):
            bias = layer.output.dense.bias
            return F.linear(x, lora(), bias)

        enc_layer.output.dense.forward = dense_with_lora
        created[layer_idx] = adapter
    return created

# ─── 4) STS-B Evaluation ───────────────────────────────────────────────────
def evaluate_stsb(model, dataloader, device):
    """Score ``model`` on ``dataloader`` with the GLUE STS-B metric.

    The model runs in eval mode without gradient tracking.  Predictions are
    the logits with the last dimension squeezed; references are the
    'labels' values as floats (the first element is used when a label is
    itself a sequence).  Returns whatever the metric's ``compute`` returns.
    """
    model.eval()
    metric = evaluate.load("glue", "stsb")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dataloader:
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits

            batch_preds = logits.squeeze(-1).cpu().tolist()
            if not isinstance(batch_preds, list):
                batch_preds = [batch_preds]
            predictions.extend(batch_preds)

            for label in batch["labels"].cpu().tolist():
                is_nested = isinstance(label, (list, tuple, np.ndarray))
                references.append(float(label[0]) if is_nested else float(label))
    return metric.compute(predictions=predictions, references=references)

# ─── 5) Training Stages (JD-based scoring) ─────────────────────────────────
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fine-tune roberta-base on the training set while scoring layers.

    Loads ``roberta-base`` with a single output label, trains it for 6 epochs
    with AdamW (lr 1e-5), a linear schedule and AMP gradient scaling, and for
    every batch asks ``compute_batch_jd`` for per-layer scores that are
    averaged per epoch. After training the model is evaluated once with
    ``evaluate_stsb`` and the hooks are removed.

    Args:
        train_loader: Training batches with "input_ids", "attention_mask", "labels".
        dev_loader: Evaluation batches passed to ``evaluate_stsb``.
        device: Device for the model and the batch tensors.

    Returns:
        ``(model, last_jd)`` where ``last_jd`` maps layer index to the mean
        score of the final epoch (``None`` only if no epoch ran).
    """
    num_epochs = 6
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=1
    )
    model = model.to(device)

    # Gradient checkpointing is deliberately left disabled during JD scoring
    # (original note: the hooks need grad paths).

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_jd_hooks(model)
    last_jd = None

    for epoch in range(num_epochs):
        score_totals = defaultdict(float)
        score_counts = defaultdict(int)
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs = {
                key: batch[key].to(device)
                for key in ("input_ids", "attention_mask", "labels")
            }
            with torch.set_grad_enabled(True), autocast():
                out = model(**inputs)
                # Scores are taken before backward(), while the graph exists.
                batch_jd = compute_batch_jd(activations, num_probes=1, sample_pairs=4)

            for layer_idx, score in batch_jd.items():
                score_totals[layer_idx] += score
                score_counts[layer_idx] += 1

            scaler.scale(out.loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        epoch_jd = {
            layer_idx: total / max(1, score_counts[layer_idx])
            for layer_idx, total in score_totals.items()
        }
        print(f"[Epoch {epoch+1}] approx Jacobian Deviation:", {k: round(v, 6) for k, v in epoch_jd.items()})
        last_jd = epoch_jd

    metrics = evaluate_stsb(model, dev_loader, device)
    print(f"STS-B Pearson: {metrics['pearson']:.4f}, Spearman: {metrics['spearmanr']:.4f}")
    remove_hooks(hooks)
    return model, last_jd

def prune_and_finetuning(model, train_loader, dev_loader, device, jd_scores, num_epochs=5):
    """Stage 2: prune layers selected by ``prune_jd_layers``, then fine-tune.

    The number of layers to prune comes from the module-level ``num``. The
    remaining model is trained with AdamW (lr 2e-5) and a linear schedule, and
    evaluated with ``evaluate_stsb`` after every epoch.

    Args:
        model: Model to prune and train; it is modified in place.
        train_loader: Training batches with "input_ids", "attention_mask", "labels".
        dev_loader: Evaluation batches passed to ``evaluate_stsb``.
        device: Device the batch tensors are moved to.
        jd_scores: Per-layer scores forwarded to ``prune_jd_layers``.
        num_epochs: Number of fine-tuning epochs (default 5, the loop length
            used before). The LR schedule is sized from this same value.

    Returns:
        The same ``model`` object after pruning and fine-tuning.
    """
    prune_idxs = prune_jd_layers(model, jd_scores, num_prune=num)
    print("Pruned layers (lowest-JD pairs → pruned downstream):", prune_idxs)

    opt = torch.optim.AdamW(model.parameters(), lr=2e-5)
    # The schedule must span exactly the steps that are run. It used to be
    # sized for 3 epochs while the loop ran 5, so the linearly decayed LR hit
    # zero after epoch 3 and the last two epochs made no effective updates.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            out = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            )
            out.loss.backward()
            opt.step()
            sched.step()

        # evaluate_stsb leaves the model in eval mode; train() above resets it.
        metrics = evaluate_stsb(model, dev_loader, device)
        print(f"[Prune Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}")
    return model

def lora_only_finetuning(model, train_loader, dev_loader, device):
    """Stage 3: train only the classifier head and freshly attached LoRA factors.

    Adapters are attached with ``apply_lora_to_all_layers``; every parameter
    under ``model.roberta`` is frozen, the classifier parameters and the LoRA
    ``A``/``B`` factors are trainable. Training runs for 6 epochs with AdamW
    (lr 2e-5), a linear schedule and AMP gradient scaling, with an
    ``evaluate_stsb`` report after each epoch.

    Args:
        model: Model to adapt; modified in place.
        train_loader: Training batches with "input_ids", "attention_mask", "labels".
        dev_loader: Evaluation batches passed to ``evaluate_stsb``.
        device: Device the batch tensors are moved to.

    Returns:
        None.
    """
    num_epochs = 6
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model)

    # Freeze the encoder, then re-enable only the head and the LoRA factors.
    for param in model.roberta.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = True
    lora_params = []
    for adapter in loras.values():
        for factor in (adapter.A, adapter.B):
            factor.requires_grad = True
            lora_params.append(factor)

    optimizer = torch.optim.AdamW(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5,
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs = {
                key: batch[key].to(device)
                for key in ("input_ids", "attention_mask", "labels")
            }
            with autocast():
                out = model(**inputs)
            scaler.scale(out.loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        metrics = evaluate_stsb(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}")

# ─── 6) Main Entrypoint ────────────────────────────────────────────────────
def main():
    """Run the STS-B pipeline: full fine-tuning, pruning, then LoRA tuning.

    Seeds the Python, NumPy and torch RNGs with 42, tokenizes the GLUE STS-B
    train/validation splits to fixed-length (128) inputs, builds the data
    loaders and runs the three training stages in order.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    train_ds = load_dataset("glue", "stsb", split="train").shuffle(seed)
    dev_ds = load_dataset("glue", "stsb", split="validation")

    raw_columns = ["sentence1", "sentence2", "label", "idx"]
    model_columns = ["input_ids", "attention_mask", "labels"]

    def preprocess(ex):
        return tokenizer(
            ex["sentence1"],
            ex["sentence2"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    def to_float_label(row):
        # Store the score under "labels" as a plain Python float.
        return {"labels": float(row["label"])}

    def prepare(split):
        split = split.map(preprocess, batched=True)
        split = split.map(to_float_label, batched=False)
        split = split.remove_columns(raw_columns)
        split.set_format("torch", columns=model_columns)
        return split

    train_ds = prepare(train_ds)
    dev_ds = prepare(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev_ds, batch_size=16, shuffle=False, collate_fn=collator)

    model, jd_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jd_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


# Run the pipeline only when executed as a script / notebook cell.
if __name__ == "__main__":
    main()


In [None]:
# Jacobian using Autograd JVP

# Jacobian → Energy Distance for STS-B (Autograd JVP version)

import numpy as np
import random
import math
import warnings

# ── Robust NumPy 'copy' kwarg patch (NumPy 1.x and 2.x; idempotent; no recursion) ──
# Grab the original C-level array constructor; if it cannot be imported the
# patch below is skipped entirely.
try:
    from numpy.core.multiarray import array as _np_array_c
except Exception:
    _np_array_c = None

# Install the wrapper at most once per process (guarded by a marker on np).
if _np_array_c is not None and not getattr(np, "_array_copy_patched_relaxed", False):
    def _np_array_relaxed(obj, *args, **kwargs):
        """Call the original ``np.array`` with a tolerant ``copy`` keyword.

        ``copy=False`` is dropped before calling. Otherwise the call is tried
        as given; on ValueError it is retried without ``copy`` only if ``copy``
        was passed, and on TypeError it is always retried without ``copy``.
        """

        def _call_without_copy():
            trimmed = {key: val for key, val in kwargs.items() if key != "copy"}
            return _np_array_c(obj, *args, **trimmed)

        if kwargs.get("copy", True) is False:
            return _call_without_copy()
        try:
            return _np_array_c(obj, *args, **kwargs)
        except ValueError:
            if "copy" not in kwargs:
                raise
            return _call_without_copy()
        except TypeError:
            return _call_without_copy()

    np.array = _np_array_relaxed
    np._array_copy_patched_relaxed = True
# ───────────────────────────────────────────────────────────────────────────

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
import evaluate

from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from datasets import load_dataset
from collections import defaultdict
from contextlib import contextmanager

# Suppress all UserWarning and FutureWarning messages process-wide so the
# training logs stay readable.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ─── 0) Make JVP work with SDPA by forcing math kernels (CUDA only) ─────────
@contextmanager
def force_math_sdp():
    """Restrict scaled-dot-product attention to the math backend on CUDA.

    When CUDA is available, the body of the ``with`` block runs with only the
    math SDPA backend enabled (flash and memory-efficient kernels disabled).
    Without CUDA this is a no-op context manager.

    The supported ``torch.nn.attention.sdpa_kernel`` API is preferred;
    ``torch.backends.cuda.sdp_kernel`` is deprecated and is used only as a
    fallback on torch builds that lack the newer API.
    """
    if not torch.cuda.is_available():
        yield
        return

    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        sdpa_kernel = None

    if sdpa_kernel is not None:
        with sdpa_kernel([SDPBackend.MATH]):
            yield
    else:
        # Older torch: fall back to the legacy (now deprecated) switch.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        ):
            yield

# ─── 1) JVP hooks: capture per-layer inputs so we can re-run locally ───────
def register_jvp_hooks(model):
    """
    Register a forward pre-hook on every layer of ``model.roberta.encoder.layer``.

    Each hook records the layer's first positional input (detached) under
    "X" and the remaining positional inputs as a tuple under "args", so the
    layer can later be re-run on its own. Keyword arguments are not captured.

    Returns ``(hooks, activations)``: the list of hook handles and a dict
    ``{layer_idx: {"X": ..., "args": ...}}`` whose values start as None and
    are overwritten on every forward pass.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    hooks = []

    def _make_capture(layer_idx):
        def _capture(module, inputs):
            slot = activations[layer_idx]
            slot["X"] = inputs[0].detach()      # overwritten on each forward
            slot["args"] = tuple(inputs[1:])    # empty tuple if no extras

        return _capture

    for layer_idx, layer in enumerate(encoder_layers):
        activations[layer_idx] = {"X": None, "args": None}
        handle = layer.register_forward_pre_hook(_make_capture(layer_idx))
        hooks.append(handle)

    return hooks, activations

def remove_hooks(hooks):
    """Call ``remove()`` on every hook handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

# ─── 2) Helpers for ED over vector samples ─────────────────────────────────
@torch.no_grad()
def _to_samples(x):
    """
    Flatten (B, S, H) → (B*S, H) and cast to float32 for stable distance calcs.
    If x is (N, H) already, returns it unchanged (float32).
    """
    x = x.to(torch.float32)
    if x.dim() == 3:   # (B, S, H)
        b, s, h = x.shape
        return x.reshape(b * s, h)
    return x  # assume (N, H)

@torch.no_grad()
def _energy_distance(a, b, max_samples=2048):
    """
    Biased energy distance:
        ED = 2 E||A-B|| - E||A-A'|| - E||B-B'||
    a, b: (Na, D) and (Nb, D) float32 tensors.
    """
    if a.numel() == 0 or b.numel() == 0:
        return 0.0

    Na, Nb = a.shape[0], b.shape[0]
    if Na > max_samples:
        a = a[torch.randperm(Na, device=a.device)[:max_samples]]
        Na = a.shape[0]
    if Nb > max_samples:
        b = b[torch.randperm(Nb, device=b.device)[:max_samples]]
        Nb = b.shape[0]

    dab = torch.cdist(a, b, p=2)                       # (Na, Nb)
    daa = torch.cdist(a, a, p=2) if Na > 1 else torch.zeros((), device=a.device)
    dbb = torch.cdist(b, b, p=2) if Nb > 1 else torch.zeros((), device=b.device)

    term_ab = dab.mean()
    term_aa = daa.mean() if Na > 1 else torch.tensor(0.0, device=a.device)
    term_bb = dbb.mean() if Nb > 1 else torch.tensor(0.0, device=b.device)

    ed = 2.0 * term_ab - term_aa - term_bb
    ed = torch.nan_to_num(ed, nan=0.0, posinf=1e6, neginf=-1e6)
    return float(ed.item())

# ─── 3) Core scorer: ED between Jv and v (i.e., || distribution gap ||) ────
def compute_batch_jed_autograd(
    model,
    activations,
    k_probes: int = 2,
    rademacher: bool = True,
    max_tokens: int | None = 2048,
    max_samples: int = 2048,
):
    """
    Score each encoder layer by the energy distance between Jv and v.

    For every entry of ``activations`` with a captured input X, draw
    ``k_probes`` probe tensors v shaped like X (entries ±1 when ``rademacher``
    is true, otherwise standard normal), compute the Jacobian-vector product
    Jv of the layer at X with ``torch.autograd.functional.jvp``, flatten Jv
    and v into row samples with ``_to_samples`` and score the layer with
    ``_energy_distance(Jv samples, v samples)``. A small score means the Jv
    samples are distributed like the v samples.

    Args:
        model: Model exposing ``model.roberta.encoder.layer``, indexable by
            the keys of ``activations``.
        activations: Mapping ``{layer_idx: {"X": tensor or None, "args": tuple}}``.
            Entries whose "X" is None are skipped; processed entries have "X"
            and "args" reset to None.
        k_probes: Number of probe tensors drawn per layer.
        rademacher: Use ±1 probes if true, ``torch.randn_like`` otherwise.
        max_tokens: If not None and a 3-D X has more than this many rows
            (dim0 * dim1), a random subset of ``max_tokens`` rows is used.
        max_samples: Forwarded to ``_energy_distance``.

    Returns:
        {layer_idx: score} with one float per processed layer.
    """
    # Local import: jvp is only needed by this function.
    from torch.autograd.functional import jvp

    layers = model.roberta.encoder.layer
    scores = {}

    for idx, buf in activations.items():
        X, args = buf["X"], buf["args"]
        if X is None:
            # Nothing captured for this layer since the last reset.
            continue

        # Optional token subsampling to keep JVP compute bounded
        if X.dim() == 3 and max_tokens is not None:
            b, s, h = X.shape
            if b * s > max_tokens:
                flat = X.reshape(b * s, h)
                sel = torch.randperm(flat.size(0), device=flat.device)[:max_tokens]
                # The selected rows may come from different original sequences;
                # they are treated as one sequence from here on.
                X = flat[sel].unsqueeze(0)  # fake batch of 1 with max_tokens seq len
                # best-effort attention mask recreation if present
                # NOTE(review): assumes args[0] is the attention mask and that an
                # all-ones (1, seq) tensor is acceptable to the layer — confirm.
                if len(args) > 0 and isinstance(args[0], torch.Tensor):
                    attn = torch.ones((1, X.size(1)), dtype=args[0].dtype, device=X.device)
                    args = (attn,) + args[1:]

        # Fresh leaf tensor so the JVP is taken w.r.t. this layer's input only.
        X = X.detach().requires_grad_(True)
        layer = layers[idx]
        was_training = layer.training
        layer.eval()  # disable dropout for stable JVP

        try:
            # Layer as a function of its hidden-state input only; the captured
            # extra positional args are held fixed.
            # NOTE(review): if capture pre-hooks are still registered on the
            # layer, this call re-triggers them; the buffer is reset below.
            def f(inp):
                out = layer(inp, *args)
                return out[0] if isinstance(out, tuple) else out

            # Collect samples across probes
            A_list = []  # Jv
            B_list = []  # v

            with force_math_sdp():
                for _ in range(k_probes):
                    if rademacher:
                        v = torch.empty_like(X).bernoulli_(0.5).mul_(2).sub_(1)  # ±1
                    else:
                        v = torch.randn_like(X)

                    # First return value (the layer output) is not needed.
                    _, Jv = jvp(f, (X,), (v,), create_graph=False, strict=True)
                    # Flatten to (N, H) sample sets
                    A_list.append(_to_samples(Jv))
                    B_list.append(_to_samples(v))

            # With k_probes == 0 both sets are empty and the score is 0.0.
            A = torch.cat(A_list, dim=0) if len(A_list) else torch.empty(0, X.size(-1), device=X.device)
            B = torch.cat(B_list, dim=0) if len(B_list) else torch.empty(0, X.size(-1), device=X.device)
            scores[idx] = _energy_distance(A, B, max_samples=max_samples)

        finally:
            # Restore the layer's original train/eval mode even on failure.
            layer.train(was_training)

        # Clear to avoid stale tensors
        # (not reached if the block above raised).
        buf["X"] = None
        buf["args"] = None

    return scores

# ─── 4) Pruning: remove layers with smallest ED(Jv, v) ─────────────────────
class SkipFF(nn.Module):
    """Parameter-free stand-in for a layer's output block.

    Ignores ``hidden_states`` and hands back ``input_tensor`` unchanged
    (``None`` when it is not supplied).
    """

    def forward(self, hidden_states, input_tensor=None):
        # Residual passthrough: the first argument is discarded.
        passthrough = input_tensor
        return passthrough

def prune_jed_layers(model, jed_scores, num_prune=4):
    """
    Prune the ``num_prune`` layers with the smallest ED(Jv, v) score.

    Layer indices are ordered by ascending score (ties keep the order of
    ``jed_scores``). For each selected layer, ``intermediate.dense`` becomes
    ``nn.Identity()`` and ``output`` becomes a ``SkipFF`` passthrough, which
    removes the feed-forward block. ``model`` is modified in place.

    Returns the list of pruned layer indices, lowest score first.
    """
    ranked = sorted(jed_scores, key=jed_scores.get)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        block = model.roberta.encoder.layer[layer_idx]
        block.intermediate.dense = nn.Identity()
        block.output = SkipFF()
    return prune_idxs

# ─── 5) LoRA Modules (unchanged) ───────────────────────────────────────────
class LoRA(nn.Module):
    """Low-rank weight adapter: ``W0 + (alpha / r) * (B @ A)``.

    ``W0`` is a buffer holding a detached copy of the given weight; ``B`` and
    ``A`` are trainable factors whose product has the shape of ``W0``.

    Args:
        W0: 2-D weight tensor to adapt.
        r: Rank of the low-rank update.
        alpha: Scaling numerator; the update is multiplied by ``alpha / r``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        rows, cols = W0.shape
        # A is zero-initialised, so the update B @ A vanishes at construction
        # and forward() returns W0 until training changes A.
        self.B = nn.Parameter(torch.randn(rows, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the adapted weight matrix."""
        update = torch.matmul(self.B, self.A)
        return self.W0 + self.scaling * update

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Patch each encoder layer's ``output.dense`` to use a LoRA-adapted weight.

    Layers of ``model.roberta.encoder.layer`` whose ``output`` has no
    ``dense`` attribute are skipped. For the others a ``LoRA`` module is
    created from the dense weight (on the same device) and the dense module's
    ``forward`` is overridden to apply ``F.linear`` with the adapted weight
    and the original bias.

    Args:
        model: Model exposing ``model.roberta.encoder.layer``.
        r: LoRA rank.
        alpha: LoRA alpha.

    Returns:
        dict ``{layer_idx: LoRA}`` for every patched layer. The adapters are
        not registered as submodules of ``model``; this dict is the handle to
        their parameters.
    """

    def _patched_forward(target_layer, adapter):
        def forward(x):
            # Bias is resolved through the layer at call time.
            return F.linear(x, adapter(), target_layer.output.dense.bias)

        return forward

    loras = {}
    for position, enc_layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(enc_layer.output, "dense"):
            continue
        weight = enc_layer.output.dense.weight.data
        adapter = LoRA(weight, r, alpha).to(weight.device)
        enc_layer.output.dense.forward = _patched_forward(enc_layer, adapter)
        loras[position] = adapter
    return loras

# ─── 6) STS-B Evaluation (Pearson/Spearman) ────────────────────────────────
def evaluate_stsb(model, dataloader, device):
    """Compute the GLUE STS-B metric for ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there) and run without
    gradients. Predictions are the logits with the last dimension squeezed;
    references are the batch labels converted to floats.

    Args:
        model: Model called with ``input_ids``/``attention_mask`` that returns
            an object with ``logits``.
        dataloader: Iterable of batches with "input_ids", "attention_mask"
            and "labels" tensors.
        device: Device the input tensors are moved to.

    Returns:
        The result of ``metric.compute`` for the "glue"/"stsb" metric.
    """
    model.eval()
    metric = evaluate.load("glue", "stsb")
    all_preds = []
    all_refs = []
    with torch.no_grad():
        for batch in dataloader:
            result = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            values = result.logits.squeeze(-1).cpu().tolist()
            all_preds.extend(values if isinstance(values, list) else [values])
            for label in batch["labels"].cpu().tolist():
                # Unwrap sequence-valued labels to their first element.
                if isinstance(label, (list, tuple, np.ndarray)):
                    label = label[0]
                all_refs.append(float(label))
    return metric.compute(predictions=all_preds, references=all_refs)

# ─── 7) Training Stages (use JVP-based ED for scoring) ─────────────────────
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fine-tune roberta-base and estimate per-layer JVP-ED scores.

    Trains a single-label ``roberta-base`` for 6 epochs with AdamW (lr 2e-5),
    a linear schedule and AMP gradient scaling. After each forward pass the
    layer inputs captured by the JVP hooks are scored with
    ``compute_batch_jed_autograd``; scores are averaged per epoch. The model
    is evaluated once at the end and the hooks are removed.

    Args:
        train_loader: Training batches with "input_ids", "attention_mask", "labels".
        dev_loader: Evaluation batches passed to ``evaluate_stsb``.
        device: Device for the model and the batch tensors.

    Returns:
        ``(model, last_jed)`` where ``last_jed`` maps layer index to the mean
        score over the final epoch.
    """
    print("=== Stage 1: Full Finetuning & JVP-ED Estimation ===")
    num_epochs = 6
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=1)
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_jvp_hooks(model)
    last_jed = None

    for epoch in range(num_epochs):
        totals = defaultdict(float)
        counts = defaultdict(int)
        model.train()
        for batch in train_loader:
            optimizer.zero_grad(set_to_none=True)
            inputs = {
                key: batch[key].to(device)
                for key in ("input_ids", "attention_mask", "labels")
            }
            with autocast():
                out = model(**inputs)

            # Score the captured layer inputs outside autocast, before backward.
            batch_jed = compute_batch_jed_autograd(
                model,
                activations,
                k_probes=2,
                rademacher=True,
                max_tokens=2048,
                max_samples=2048,
            )
            for layer_idx, score in batch_jed.items():
                totals[layer_idx] += score
                counts[layer_idx] += 1

            scaler.scale(out.loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        epoch_jed = {
            layer_idx: total / max(1, counts[layer_idx])
            for layer_idx, total in totals.items()
        }
        print(f"[Epoch {epoch+1}] JVP-ED:", {k: round(v, 6) for k, v in epoch_jed.items()})
        last_jed = epoch_jed

    metrics = evaluate_stsb(model, dev_loader, device)
    print(f"STS-B Pearson: {metrics['pearson']:.4f}, Spearman: {metrics['spearmanr']:.4f}")
    remove_hooks(hooks)
    return model, last_jed

def prune_and_finetuning(model, train_loader, dev_loader, device, jed_scores, num_epochs=5):
    """Stage 2: prune the lowest JVP-ED layers, then fine-tune the model.

    ``prune_jed_layers`` removes the feed-forward block of the ``num``
    lowest-scoring layers (``num`` is the module-level setting). The model is
    then trained with AdamW (lr 1e-5) and a linear schedule, and evaluated
    with ``evaluate_stsb`` after every epoch.

    Args:
        model: Model to prune and train; modified in place.
        train_loader: Training batches with "input_ids", "attention_mask", "labels".
        dev_loader: Evaluation batches passed to ``evaluate_stsb``.
        device: Device the batch tensors are moved to.
        jed_scores: ``{layer_idx: score}`` used to pick the layers to prune.
        num_epochs: Number of fine-tuning epochs (default 5, the loop length
            used before). The LR schedule is sized from this same value.

    Returns:
        The same ``model`` object after pruning and fine-tuning.
    """
    print("=== Stage 2: Prune (Lowest JVP-ED) & Finetuning ===")
    prune_idxs = prune_jed_layers(model, jed_scores, num_prune=num)
    print("Pruned layers (lowest JVP-ED):", prune_idxs)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # Size the schedule to the steps that are actually run. It used to cover
    # only 3 epochs of a 5-epoch loop, so the linearly decayed LR reached zero
    # after epoch 3 and the remaining epochs made no effective updates.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad(set_to_none=True)
            out = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            )
            out.loss.backward()
            opt.step()
            sched.step()

        # evaluate_stsb leaves the model in eval mode; train() above resets it.
        metrics = evaluate_stsb(model, dev_loader, device)
        print(f"[Prune Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}")
    return model

def lora_only_finetuning(model, train_loader, dev_loader, device):
    """Stage 3: fine-tune only the classifier head and the LoRA factors.

    Attaches adapters via ``apply_lora_to_all_layers``, freezes every
    parameter under ``model.roberta``, and trains the classifier parameters
    plus each adapter's ``A`` and ``B`` for 6 epochs (AdamW lr 2e-5, linear
    schedule, AMP gradient scaling). ``evaluate_stsb`` is reported per epoch.

    Args:
        model: Model to adapt; modified in place.
        train_loader: Training batches with "input_ids", "attention_mask", "labels".
        dev_loader: Evaluation batches passed to ``evaluate_stsb``.
        device: Device the batch tensors are moved to.

    Returns:
        None.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    num_epochs = 6
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model)

    for weight in model.roberta.parameters():
        weight.requires_grad = False
    for weight in model.classifier.parameters():
        weight.requires_grad = True

    # Collect the adapter factors (A then B per layer) and mark them trainable.
    adapter_factors = []
    for adapter in loras.values():
        adapter.A.requires_grad = True
        adapter.B.requires_grad = True
        adapter_factors.extend((adapter.A, adapter.B))

    trainable = list(model.classifier.parameters()) + adapter_factors
    optimizer = torch.optim.AdamW(trainable, lr=2e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad(set_to_none=True)
            inputs = {
                key: batch[key].to(device)
                for key in ("input_ids", "attention_mask", "labels")
            }
            with autocast():
                out = model(**inputs)
            scaler.scale(out.loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        metrics = evaluate_stsb(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}")

# ─── 8) Main Entrypoint ────────────────────────────────────────────────────
def main():
    """Run the JVP-ED pipeline on STS-B: fine-tune + score, prune, LoRA-tune.

    Seeds the Python, NumPy and torch RNGs with 42, tokenizes the GLUE STS-B
    train/validation splits to length-128 inputs, builds the data loaders and
    runs the three stages in sequence.
    """
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    splits = {
        "train": load_dataset("glue", "stsb", split="train").shuffle(seed),
        "dev": load_dataset("glue", "stsb", split="validation"),
    }

    def preprocess(ex):
        return tokenizer(
            ex["sentence1"], ex["sentence2"],
            truncation=True, padding="max_length", max_length=128
        )

    def relabel(row):
        # Copy the score into "labels" as a plain Python float.
        return {"labels": float(row["label"])}

    for split_name, split_ds in splits.items():
        split_ds = split_ds.map(preprocess, batched=True)
        split_ds = split_ds.map(relabel, batched=False)
        # Keep only the columns the model consumes.
        split_ds = split_ds.remove_columns(["sentence1", "sentence2", "label", "idx"])
        split_ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
        splits[split_name] = split_ds

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(splits["train"], batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(splits["dev"], batch_size=16, shuffle=False, collate_fn=collator)

    model, jed_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, jed_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


# Run the pipeline only when executed as a script / notebook cell.
if __name__ == "__main__":
    main()




