In [None]:
!pip uninstall -y datasets
!pip install datasets==2.18.0
!pip install evaluate


Found existing installation: datasets 4.0.0
Uninstalling datasets-4.0.0:
  Successfully uninstalled datasets-4.0.0
Collecting datasets==2.18.0
  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets==2.18.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Collecting fsspec<=2024.2.0,>=2023.1.0 (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets==2.18.0)
  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)


In [None]:
# Colab-only: mount Google Drive so the dataset JSON files under
# /content/drive/MyDrive/... can be read (this import fails outside Colab).
from google.colab import drive
drive.mount('/content/drive')


In [None]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler
from collections import defaultdict
import warnings
import math

# Globally silence every UserWarning and FutureWarning to keep training logs short.
# NOTE(review): this also hides deprecation notices (e.g. for torch.cuda.amp) — confirm intended.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- 1. Conditional ER Hook Utilities ---
def register_conditional_er_hooks(model):
    """Attach forward hooks that record the latest output of every T5 block.

    Three groups of hooks are installed:
      * every encoder block (``model.encoder.block[i]``),
      * every decoder block (``model.decoder.block[i]``),
      * sub-layer ``layer[1]`` of every decoder block, treated as cross-attention.

    Each hook stores a detached copy of its module's hidden states (the first
    element when the output is a tuple) under the block index, overwriting the
    previous value on every forward pass.

    Returns:
        Three ``(handles, activations)`` pairs: encoder, decoder, cross-attention.
    """
    def _make_recorder(store, key):
        # Bind store/key at creation time so each hook writes its own slot.
        def _record(module, inputs, output):
            hidden = output[0] if isinstance(output, tuple) else output
            store[key] = hidden.detach()
        return _record

    encoder_blocks = model.encoder.block
    decoder_blocks = model.decoder.block

    enc_acts = dict.fromkeys(range(len(encoder_blocks)))
    dec_acts = dict.fromkeys(range(len(decoder_blocks)))
    cross_acts = dict.fromkeys(range(len(decoder_blocks)))

    enc_hooks = [
        blk.register_forward_hook(_make_recorder(enc_acts, pos))
        for pos, blk in enumerate(encoder_blocks)
    ]
    dec_hooks = [
        blk.register_forward_hook(_make_recorder(dec_acts, pos))
        for pos, blk in enumerate(decoder_blocks)
    ]
    cross_hooks = [
        blk.layer[1].register_forward_hook(_make_recorder(cross_acts, pos))
        for pos, blk in enumerate(decoder_blocks)
    ]
    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)

def remove_hooks(hook_sets):
    """Detach every hook handle in ``hook_sets``.

    ``hook_sets`` is an iterable of ``(handles, activations)`` pairs, as
    returned by ``register_conditional_er_hooks``; only the handles are used.
    """
    all_handles = (handle for handles, _acts in hook_sets for handle in handles)
    for handle in all_handles:
        handle.remove()

def compute_conditional_batch_entropy(prev_acts, curr_acts):
    """Approximate conditional ER for each adjacent layer pair.

    For each pair of consecutive layers ``(i, i+1)``, the step-to-step change of
    each layer's activations is flattened per sample (``dX`` for layer ``i``,
    ``dY`` for layer ``i+1``).  The score of pair ``i`` is the mean over all
    samples of ``cos(dY, dX) ** 2``.

    Args:
        prev_acts: dict ``layer_index -> tensor or None`` from the previous step.
        curr_acts: dict ``layer_index -> tensor or None`` from the current step.
            Tensors are expected to have the batch on dim 0.

    Returns:
        dict ``i -> score``.  A pair is omitted in any of these cases:
          * one of its tensors is missing;
          * its shapes changed between steps;
          * no sample yields a finite cosine.
    """
    er_scores = {}
    for i in range(len(curr_acts) - 1):
        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]
        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]
        if any(t is None for t in (prev_X, prev_Y, curr_X, curr_Y)):
            continue
        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:
            continue
        B = curr_X.size(0)
        if B < 1:
            continue
        # float32 keeps the norms of long flattened vectors from overflowing when
        # activations are half precision; reshape tolerates non-contiguous inputs.
        dX = (curr_X.float() - prev_X.float()).reshape(B, -1)
        dY = (curr_Y.float() - prev_Y.float()).reshape(B, -1)
        # One cosine per sample, over every row (row 0 included).
        cos = F.cosine_similarity(dY, dX, dim=1, eps=1e-8)
        # Drop samples whose cosine is NaN/inf (e.g. NaN activations) instead of
        # discarding the whole pair.
        cos = cos[torch.isfinite(cos)]
        if cos.numel() == 0:
            continue
        er = (cos * cos).mean().item()
        if math.isfinite(er):
            er_scores[i] = er
    return er_scores


# --- 2. Pruning Utilities ---
class SkipBlock(nn.Module):
    """Identity stand-in for a pruned T5 block.

    Returns ``hidden_states`` unchanged.  It must still hand on the shared
    relative-position biases: T5 computes them only in block 0, and each later
    block reuses the tensors returned by the previous block.  Putting ``None``
    in those slots makes every block after the skipped one fall back to a zero
    bias.

    Output layout mirrors ``T5Block``:
    ``(hidden_states, [past_key_value if use_cache], position_bias,
    [self-attn weights if output_attentions],
    [encoder_decoder_position_bias, [cross-attn weights]] if encoder_hidden_states)``.

    NOTE(review): this targets transformers versions whose ``T5Stack`` inserts
    ``None`` for the cache slot when ``use_cache`` is False and then reads
    ``position_bias`` from ``layer_outputs[2]``.  Confirm this against the
    installed version.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Kept for API compatibility; the block has no parameters.
        self.hidden_size = hidden_size

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
        **kwargs,
    ):
        outputs = (hidden_states,)
        if use_cache:
            # Hand the (unchanged) cache back so the stack's bookkeeping stays intact.
            outputs += (past_key_value,)
        outputs += (position_bias,)
        if output_attentions:
            outputs += (None,)  # no self-attention weights
        if encoder_hidden_states is not None:
            outputs += (encoder_decoder_position_bias,)
            if output_attentions:
                outputs += (None,)  # no cross-attention weights
        return outputs

def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):
    """Replace the blocks that follow the highest-ER layer pairs with SkipBlock.

    ``er_scores`` maps a pair index ``i`` (layers ``i`` and ``i+1``) to its
    score.  The ``num_prune`` highest-scoring pairs are taken.  For each,
    ``blocks[i+1]`` is replaced in place by ``SkipBlock(hidden_size)``, so with
    non-negative pair indices block 0 is never replaced.  Targets past the end
    of ``blocks`` are ignored.

    Returns:
        The replaced indices, ordered from highest to lowest score.
    """
    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)
    targets = []
    for pair_idx, _score in ranked[:num_prune]:
        victim = pair_idx + 1
        if victim < len(blocks):
            targets.append(victim)
    for victim in targets:
        blocks[victim] = SkipBlock(hidden_size)
    return targets

# --- 3. Data Processing ---
def make_t5_nli_prompt(premise, hypothesis):
    """Render a premise/hypothesis pair as a T5-style NLI input string."""
    return "nli premise: {} hypothesis: {}".format(premise, hypothesis)

def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batch of NLI examples for T5 seq2seq training.

    Args:
        batch: columnar dict with ``premise``, ``hypothesis`` and ``label``
            lists (as passed by ``Dataset.map(batched=True)``).
        tokenizer: called on lists of strings; its ``pad_token_id`` marks
            target padding.
        max_input_length: fixed padded/truncated length of the inputs.
        max_target_length: fixed padded/truncated length of the labels.

    Returns:
        The tokenizer output for the prompts, plus ``labels`` in which target
        padding is replaced by -100 so the loss ignores it.
    """
    label_list = ["entailment", "neutral", "contradiction"]

    def to_label_text(x):
        # Class ids (ints or digit strings) 0..2 map to names.  Anything else is
        # kept as lower-cased text: textual labels, or ids such as -1, which
        # must not wrap around via negative indexing.  bool is excluded
        # because it subclasses int.
        if isinstance(x, str) and x.strip().isdigit():
            x = int(x.strip())
        if isinstance(x, int) and not isinstance(x, bool) and 0 <= x < len(label_list):
            return label_list[x]
        return str(x).strip().lower()

    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    targets = [to_label_text(x) for x in batch['label']]
    target = tokenizer(targets, padding="max_length", truncation=True, max_length=max_target_length)
    pad_id = tokenizer.pad_token_id
    # Labels are fully padded here, so no collator will mask them later; mask the
    # pad ids now so they are excluded from the loss.
    model_inputs["labels"] = [
        [-100 if tok == pad_id else tok for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

def compute_accuracy(preds, refs):
    """Fraction of positions where ``preds[k] == refs[k]``.

    Pairs are formed with ``zip`` and the match count is divided by
    ``len(preds)``.  Returns 0 for empty ``preds`` instead of raising
    ZeroDivisionError, matching the guarded variants elsewhere in this notebook.
    """
    if not preds:
        return 0
    correct = sum(1 for p, r in zip(preds, refs) if p == r)
    return correct / len(preds)

def evaluate_model(model, dl, tokenizer, device, label_texts, max_new_tokens=None):
    """Exact-match accuracy of greedy generations against the gold labels.

    Args:
        model: seq2seq model exposing ``eval()`` and ``generate()``.
        dl: iterable of batches with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors (label padding may be -100).
        tokenizer: provides ``batch_decode`` and ``pad_token_id``.
        device: device the input tensors are moved to.
        label_texts: unused; kept for call compatibility.
        max_new_tokens: generation budget.  ``None`` (default) uses, per batch,
            the token length of the longest gold label.  A hard-coded budget of
            2 would cut off any label that tokenizes into more pieces.

    Returns:
        Accuracy after stripping and lower-casing predictions and references.
    """
    model.eval()
    pad_id = tokenizer.pad_token_id
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = pad_id
            budget = max_new_tokens
            if budget is None:
                # Non-pad tokens (incl. EOS) of the longest target in the batch.
                budget = max(1, int((label_ids != pad_id).sum(dim=1).max().item()))
            outputs = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                max_new_tokens=budget,
            )
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend(p.strip().lower() for p in pred_texts)
            refs.extend(r.strip().lower() for r in ref_texts)
    return compute_accuracy(preds, refs)

# --- 4. Training Loops ---

def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts,
                    num_epochs=6, lr=3e-4):
    """Stage 1: fine-tune ``t5-base`` and estimate per-layer conditional ER.

    Forward hooks record every encoder block, decoder block and decoder
    cross-attention sub-layer.  After each optimisation step, the recorded
    activations are compared with the previous step's via
    ``compute_conditional_batch_entropy`` and averaged over the epoch.

    Args:
        train_loader, dev_loader: batches with input_ids, attention_mask, labels.
        device: torch device for the model and the batches.
        tokenizer, label_texts: forwarded to ``evaluate_model``.
        num_epochs: number of training epochs (default 6, as before).
        lr: AdamW learning rate (default 3e-4, as before).

    Returns:
        ``(model, enc_er, dec_er, cross_er)``.  Each ER dict maps a layer-pair
        index to its mean score in the last epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler()
    # The linear decay must span every step actually run.  A schedule sized for
    # 3 epochs of a 6-epoch run hits lr=0 halfway and the rest does nothing.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)

    hook_sets = register_conditional_er_hooks(model)
    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets
    tracked = {"enc": enc_acts, "dec": dec_acts, "cross": cross_acts}
    last_er = dict.fromkeys(tracked)

    try:
        for epoch in range(num_epochs):
            er_sums = {name: defaultdict(float) for name in tracked}
            er_counts = {name: defaultdict(int) for name in tracked}
            prev_acts = dict.fromkeys(tracked)  # no previous step at epoch start
            model.train()
            for batch in train_loader:
                opt.zero_grad()
                with autocast():
                    outputs = model(input_ids=batch['input_ids'].to(device),
                                    attention_mask=batch['attention_mask'].to(device),
                                    labels=batch['labels'].to(device))
                    loss = outputs.loss
                # Backward runs outside the autocast region, as the AMP docs recommend.
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()
                for name, acts in tracked.items():
                    if prev_acts[name] is not None:
                        batch_er = compute_conditional_batch_entropy(prev_acts[name], acts)
                        for idx, v in batch_er.items():
                            er_sums[name][idx] += v
                            er_counts[name][idx] += 1
                    # Snapshot this step; the hooks overwrite `acts` on the next forward.
                    prev_acts[name] = {i: t.clone() if t is not None else None for i, t in acts.items()}

            epoch_er = {
                name: {idx: er_sums[name][idx] / er_counts[name][idx]
                       for idx in er_sums[name] if er_counts[name][idx] > 0}
                for name in tracked
            }
            print(f"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_er['enc']}")
            print(f"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_er['dec']}")
            print(f"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_er['cross']}")
            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
            last_er = epoch_er
    finally:
        # Detach the hooks even if training is interrupted, so they do not keep recording.
        remove_hooks(hook_sets)
    return model, last_er["enc"], last_er["dec"], last_er["cross"]

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores,
                         cross_er_scores, tokenizer, label_texts,
                         num_epochs=5, lr=5e-4, num_prune=4, prune_encoder=False):
    """Stage 2: replace high-ER blocks with SkipBlock, then keep fine-tuning.

    Decoder blocks are always pruned.  Encoder pruning is opt-in via
    ``prune_encoder=True``.  ``cross_er_scores`` is accepted for call
    compatibility but is not used.  ``None`` score dicts prune nothing.

    Returns:
        The same model object, pruned in place and fine-tuned.
    """
    print("=== Stage 2: Prune (High-ER) & Fine-tuning ===")
    d_model = model.config.d_model
    if prune_encoder:
        enc_prune_idxs = prune_er_layers(model.encoder.block, enc_er_scores or {},
                                         num_prune=num_prune, hidden_size=d_model)
        print("Pruned encoder layers (highest ER):", enc_prune_idxs)
    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {},
                                     num_prune=num_prune, hidden_size=d_model)
    print("Pruned decoder layers (highest ER):", dec_prune_idxs)

    # Optimizer is built after pruning so it only sees the surviving parameters.
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    # Span all epochs; a 2-epoch schedule left the last epochs at lr=0.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch['labels'].to(device))
            outputs.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
        print(f"[Prune FT Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}")
    return model

# --- 5. Main Entrypoint ---

def main():
    """Load e-SNLI from Drive, then run stage 1 (full FT + ER) and stage 2 (pruning + FT).

    Returns:
        The pruned and fine-tuned model.
    """
    data_files = {
        "train": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json",
        "validation": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json",
        "test": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json"
    }
    raw_datasets = load_dataset("json", data_files=data_files)
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")
    label_texts = ["entailment", "neutral", "contradiction"]

    # Make DataLoader shuffling and model init reproducible; the subsets below already use seed 42.
    torch.manual_seed(42)

    train_ds = raw_datasets["train"].shuffle(seed=42).select(range(10000))
    dev_ds = raw_datasets["validation"].shuffle(seed=42).select(range(2000))

    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True, remove_columns=train_ds.column_names)
    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                     batched=True, remove_columns=dev_ds.column_names)

    # preprocess_function already pads inputs and labels to fixed lengths, so the
    # collator only needs to tensorize them.
    # NOTE(review): padding="max_length" with max_length=128 is avoided because, in
    # recent transformers versions, it also pads `labels` to 128 tokens.  That
    # inflates every decoder pass and the decoder/cross activations used for ER.
    collator = DataCollatorForSeq2Seq(tokenizer, model=None)
    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(
        train_loader, dev_loader, device, tokenizer, label_texts)
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device,
        enc_er_scores, dec_er_scores, cross_er_scores,
        tokenizer, label_texts)
    return model

# Run both stages when this cell executes (in a notebook kernel __name__ is "__main__").
if __name__ == "__main__":
    main()


In [None]:
# Prune the decoder
# CQA experiment: fine-tune t5-base, estimate conditional ER, then replace
# high-ER decoder blocks with SkipBlock and fine-tune again.

# Mount Google Drive if on Colab (this import fails outside Colab)
from google.colab import drive
drive.mount('/content/drive')

from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler
from collections import defaultdict
import warnings
import math

# Globally silence every UserWarning and FutureWarning to keep training logs short.
# NOTE(review): this also hides deprecation notices (e.g. for torch.cuda.amp) — confirm intended.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- 1. Load CQA Data ---
# Only train and test splits are loaded for CQA; the "test" split doubles as
# the dev set used for per-epoch evaluation (dev = dataset["test"]).
data_files = {
    "train": "/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json",
    "test":  "/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json"
}
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Functions ---

def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):
    """Tokenize a batch of CQA examples for T5 seq2seq training.

    Inputs read ``question: ... choices: a, b, ...``.  When ``use_cot`` is set
    and the batch has an ``abstractive_explanation`` column, the input also
    gets `` rationale: <explanation>``.  Targets are the stripped answer
    strings.

    Args:
        batch: columnar dict with ``question``, ``choices`` (lists of strings)
            and ``answer`` lists.
        tokenizer: called on lists of strings; its ``pad_token_id`` marks padding.
        max_input_length / max_target_length: fixed padded/truncated lengths.
        use_cot: include the rationale in the input when available.

    Returns:
        Tokenizer output for the inputs, plus ``labels`` with pad ids replaced
        by -100 so the loss ignores padding.
    """
    questions, choice_lists = batch['question'], batch['choices']
    if use_cot and 'abstractive_explanation' in batch:
        inputs = [
            f"question: {q} choices: {', '.join(choices)} rationale: {exp}"
            for q, choices, exp in zip(questions, choice_lists, batch['abstractive_explanation'])
        ]
    else:
        inputs = [
            f"question: {q} choices: {', '.join(choices)}"
            for q, choices in zip(questions, choice_lists)
        ]
    targets = [str(ans).strip() for ans in batch['answer']]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(targets, padding="max_length", truncation=True, max_length=max_target_length)
    pad_id = tokenizer.pad_token_id
    # Labels are fully padded here, so mask the pad ids explicitly.
    model_inputs["labels"] = [
        [-100 if tok == pad_id else tok for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

tokenizer = T5TokenizerFast.from_pretrained("t5-base")
USE_COT = False  # Set to True to include abstractive_explanation (training inputs only)

# Rationales (when USE_COT) are added only to training inputs; the dev split is
# always encoded without them.
train = dataset["train"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=USE_COT),
                            batched=True, remove_columns=dataset["train"].column_names)
dev   = dataset["test"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=False),
                            batched=True, remove_columns=dataset["test"].column_names)

# preprocess_cqa already pads inputs and labels to fixed lengths, so the collator
# only needs to tensorize them.
# NOTE(review): padding="max_length" with max_length=128 is avoided because, in
# recent transformers versions, it also pads `labels` to 128 tokens.  That
# lengthens every decoder pass and the decoder/cross activations fed to the ER estimate.
collator = DataCollatorForSeq2Seq(tokenizer, model=None)
train_loader = DataLoader(train, batch_size=32, shuffle=True, collate_fn=collator)
dev_loader   = DataLoader(dev, batch_size=32, shuffle=False, collate_fn=collator)

# --- 3. Conditional ER Hook Utilities ---
def register_conditional_er_hooks(model):
    """Attach forward hooks that record the latest output of every T5 block.

    Three groups of hooks are installed:
      * every encoder block,
      * every decoder block,
      * sub-layer ``layer[1]`` of every decoder block (the cross-attention).

    Each hook stores a detached copy of the module's hidden states (first tuple
    element when the output is a tuple) under the block index, overwriting the
    previous value on every forward pass.

    Returns:
        Three ``(handles, activations)`` pairs: encoder, decoder, cross-attention.
    """
    def _make_recorder(store, key):
        # Bind store/key at creation time so each hook writes its own slot.
        def _record(module, inputs, output):
            hidden = output[0] if isinstance(output, tuple) else output
            store[key] = hidden.detach()
        return _record

    encoder_blocks = model.encoder.block
    decoder_blocks = model.decoder.block

    enc_acts = dict.fromkeys(range(len(encoder_blocks)))
    dec_acts = dict.fromkeys(range(len(decoder_blocks)))
    cross_acts = dict.fromkeys(range(len(decoder_blocks)))

    enc_hooks = [
        blk.register_forward_hook(_make_recorder(enc_acts, pos))
        for pos, blk in enumerate(encoder_blocks)
    ]
    dec_hooks = [
        blk.register_forward_hook(_make_recorder(dec_acts, pos))
        for pos, blk in enumerate(decoder_blocks)
    ]
    cross_hooks = [
        blk.layer[1].register_forward_hook(_make_recorder(cross_acts, pos))
        for pos, blk in enumerate(decoder_blocks)
    ]
    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)

def remove_hooks(hook_sets):
    """Detach every hook handle in ``hook_sets`` (an iterable of ``(handles, acts)`` pairs)."""
    all_handles = (handle for handles, _acts in hook_sets for handle in handles)
    for handle in all_handles:
        handle.remove()

def compute_conditional_batch_entropy(prev_acts, curr_acts):
    """Approximate conditional ER for each adjacent layer pair.

    For each pair ``(i, i+1)``, the step-to-step change of each layer's
    activations is flattened per sample (``dX`` for layer ``i``, ``dY`` for
    layer ``i+1``).  The score is the mean over all samples of
    ``cos(dY, dX) ** 2``.

    Args:
        prev_acts: dict ``layer_index -> tensor or None`` from the previous step.
        curr_acts: dict ``layer_index -> tensor or None`` from the current step
            (batch on dim 0).

    Returns:
        dict ``i -> score``.  A pair is omitted in any of these cases:
          * one of its tensors is missing;
          * its shapes changed between steps;
          * no sample yields a finite cosine.
    """
    er_scores = {}
    for i in range(len(curr_acts) - 1):
        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]
        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]
        if any(t is None for t in (prev_X, prev_Y, curr_X, curr_Y)):
            continue
        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:
            continue
        B = curr_X.size(0)
        if B < 1:
            continue
        # float32 avoids norm overflow for half-precision activations; reshape
        # tolerates non-contiguous hook outputs.
        dX = (curr_X.float() - prev_X.float()).reshape(B, -1)
        dY = (curr_Y.float() - prev_Y.float()).reshape(B, -1)
        # One cosine per sample, every row included (row 0 was previously skipped).
        cos = F.cosine_similarity(dY, dX, dim=1, eps=1e-8)
        cos = cos[torch.isfinite(cos)]  # drop NaN/inf samples only
        if cos.numel() == 0:
            continue
        er = (cos * cos).mean().item()
        if math.isfinite(er):
            er_scores[i] = er
    return er_scores

# --- 4. Pruning Utilities ---
class SkipBlock(nn.Module):
    """Identity stand-in for a pruned T5 block.

    Returns ``hidden_states`` unchanged, but still forwards the shared
    relative-position biases.  T5 computes them only in block 0, and each later
    block reuses what the previous block returned.  ``None`` in those slots
    would strip the bias from every block after the skipped one.

    Output layout mirrors ``T5Block``:
    ``(hidden_states, [past_key_value if use_cache], position_bias,
    [self-attn weights if output_attentions],
    [encoder_decoder_position_bias, [cross-attn weights]] if encoder_hidden_states)``.

    NOTE(review): this targets transformers versions whose ``T5Stack`` inserts
    ``None`` for the cache slot when ``use_cache`` is False and reads
    ``position_bias`` from ``layer_outputs[2]``.  Confirm this against the
    installed version.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Kept for API compatibility; the block has no parameters.
        self.hidden_size = hidden_size

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
        **kwargs,
    ):
        outputs = (hidden_states,)
        if use_cache:
            # Hand the (unchanged) cache back so the stack's bookkeeping stays intact.
            outputs += (past_key_value,)
        outputs += (position_bias,)
        if output_attentions:
            outputs += (None,)  # no self-attention weights
        if encoder_hidden_states is not None:
            outputs += (encoder_decoder_position_bias,)
            if output_attentions:
                outputs += (None,)  # no cross-attention weights
        return outputs

def prune_er_layers(blocks, er_scores, num_prune=4, hidden_size=768):
    """Swap ``blocks[i+1]`` for ``SkipBlock(hidden_size)`` for the top-ER pairs ``i``.

    ``er_scores`` maps pair index ``i`` (layers ``i``/``i+1``) to a score.  The
    ``num_prune`` highest-scoring pairs are chosen, and out-of-range targets
    are ignored.  Replacement happens in place.

    Returns:
        The replaced indices, highest score first.
    """
    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)
    targets = []
    for pair_idx, _score in ranked[:num_prune]:
        victim = pair_idx + 1
        if victim < len(blocks):
            targets.append(victim)
    for victim in targets:
        blocks[victim] = SkipBlock(hidden_size)
    return targets

# --- 5. Training/Eval/ER Pipeline ---
def compute_accuracy(preds, refs):
    """Exact-match accuracy: matches over ``zip(preds, refs)`` divided by ``len(preds)``; 0 if empty."""
    if len(preds) == 0:
        return 0
    hits = sum(p == l for p, l in zip(preds, refs))
    return hits / len(preds)

def evaluate_model(model, dl, tokenizer, device, max_new_tokens=None):
    """Exact-match accuracy of greedy generations against the gold answers.

    Args:
        model: seq2seq model exposing ``eval()`` and ``generate()``.
        dl: iterable of batches with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors (label padding may be -100).
        tokenizer: provides ``batch_decode`` and ``pad_token_id``.
        device: device the input tensors are moved to.
        max_new_tokens: generation budget.  ``None`` (default) uses, per batch,
            the token length of the longest gold answer.  A fixed budget of 4
            would cut off longer answers and count them as wrong.

    Returns:
        Accuracy after stripping and lower-casing predictions and references.
    """
    model.eval()
    pad_id = tokenizer.pad_token_id
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = pad_id
            budget = max_new_tokens
            if budget is None:
                # Non-pad tokens (incl. EOS) of the longest target in the batch.
                budget = max(1, int((label_ids != pad_id).sum(dim=1).max().item()))
            outputs = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                max_new_tokens=budget,
            )
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend(p.strip().lower() for p in pred_texts)
            refs.extend(r.strip().lower() for r in ref_texts)
    return compute_accuracy(preds, refs)

def full_finetuning(train_loader, dev_loader, device, tokenizer, num_epochs=6, lr=3e-4):
    """Stage 1: fine-tune ``t5-base`` and estimate per-layer conditional ER.

    Forward hooks record every encoder block, decoder block and decoder
    cross-attention sub-layer.  After each step, the activations are compared
    with the previous step's via ``compute_conditional_batch_entropy`` and
    averaged per epoch.

    Args:
        train_loader, dev_loader: batches with input_ids, attention_mask, labels.
        device: torch device for the model and batches.
        tokenizer: forwarded to ``evaluate_model``.
        num_epochs: training epochs (default 6, as before).
        lr: AdamW learning rate (default 3e-4, as before).

    Returns:
        ``(model, enc_er, dec_er, cross_er)`` with last-epoch mean ER per layer
        pair (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler()
    # The decay must span every step actually run; a 3-epoch schedule in a
    # 6-epoch run reaches lr=0 halfway and the remaining epochs do nothing.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)

    hook_sets = register_conditional_er_hooks(model)
    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets
    tracked = {"enc": enc_acts, "dec": dec_acts, "cross": cross_acts}
    last_er = dict.fromkeys(tracked)

    try:
        for epoch in range(num_epochs):
            er_sums = {name: defaultdict(float) for name in tracked}
            er_counts = {name: defaultdict(int) for name in tracked}
            prev_acts = dict.fromkeys(tracked)  # no previous step at epoch start
            model.train()
            for batch in train_loader:
                opt.zero_grad()
                with autocast():
                    outputs = model(input_ids=batch['input_ids'].to(device),
                                    attention_mask=batch['attention_mask'].to(device),
                                    labels=batch['labels'].to(device))
                    loss = outputs.loss
                # Backward runs outside the autocast region, as the AMP docs recommend.
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()
                for name, acts in tracked.items():
                    if prev_acts[name] is not None:
                        batch_er = compute_conditional_batch_entropy(prev_acts[name], acts)
                        for idx, v in batch_er.items():
                            er_sums[name][idx] += v
                            er_counts[name][idx] += 1
                    # Snapshot this step; the hooks overwrite `acts` on the next forward.
                    prev_acts[name] = {i: t.clone() if t is not None else None for i, t in acts.items()}

            epoch_er = {
                name: {idx: er_sums[name][idx] / er_counts[name][idx]
                       for idx in er_sums[name] if er_counts[name][idx] > 0}
                for name in tracked
            }
            print(f"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_er['enc']}")
            print(f"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_er['dec']}")
            print(f"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_er['cross']}")
            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
            last_er = epoch_er
    finally:
        # Detach the hooks even if training is interrupted.
        remove_hooks(hook_sets)
    return model, last_er["enc"], last_er["dec"], last_er["cross"]

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer,
                         num_epochs=5, lr=3e-4, num_prune=4, prune_encoder=False):
    """Stage 2: replace high-ER blocks with SkipBlock, then keep fine-tuning.

    Decoder blocks are always pruned; encoder pruning is opt-in via
    ``prune_encoder=True``.  ``None`` score dicts prune nothing.

    Returns:
        The same model object, pruned in place and fine-tuned.
    """
    print("=== Stage 2: Prune (High-ER) & Fine-tuning ===")
    d_model = model.config.d_model
    if prune_encoder:
        enc_prune_idxs = prune_er_layers(model.encoder.block, enc_er_scores or {},
                                         num_prune=num_prune, hidden_size=d_model)
        print("Pruned encoder layers (highest ER):", enc_prune_idxs)
    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {},
                                     num_prune=num_prune, hidden_size=d_model)
    print("Pruned decoder layers (highest ER):", dec_prune_idxs)

    # Built after pruning so the optimizer only sees surviving parameters.
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    # Span all epochs; a 2-epoch schedule left the last epochs at lr=0.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch['labels'].to(device))
            outputs.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] CQA Acc: {acc:.4f}")
    return model

# --- 6. Entrypoint ---
def main():
    """Run both stages on the module-level CQA loaders and tokenizer.

    Stage 1 fine-tunes t5-base while collecting conditional ER.  Stage 2 prunes
    decoder blocks chosen by that ER and fine-tunes again.  The cross-attention
    ER from stage 1 is not passed on.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    stage1 = full_finetuning(train_loader, dev_loader, device, tokenizer)
    model, enc_er_scores, dec_er_scores, _unused_cross_er = stage1
    # --- PRUNING AND CONTINUED FINETUNING ---
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device,
        enc_er_scores, dec_er_scores, tokenizer,
    )

# Run both stages when this cell executes (in a notebook kernel __name__ is "__main__").
if __name__ == "__main__":
    main()


In [None]:
# Only prune decoder
# ANLI1 experiment: same ER-guided decoder pruning pipeline on the ANLI1 splits.

# --- Mount Google Drive if using Colab (this import fails outside Colab) ---
from google.colab import drive
drive.mount('/content/drive')

# --- Standard Imports ---
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup, Adafactor
)
from torch.cuda.amp import autocast
from collections import defaultdict
import warnings
import math
import random
import numpy as np

# Globally silence every UserWarning and FutureWarning to keep training logs short.
# NOTE(review): this also hides deprecation notices (e.g. for torch.cuda.amp) — confirm intended.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ------------- Repro -------------
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices when present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed Python, NumPy and PyTorch RNGs once, before any data loading or model creation.
set_seed(1234)

# --- 1. Load ANLI1 Dataset ---
# Train/validation/test JSON files on the mounted Drive (Colab path).
data_files = {
    "train":      "/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json",
    "validation": "/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json",
    "test":       "/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json"
}
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Function ---
def make_t5_nli_prompt(premise, hypothesis):
    """Render a premise/hypothesis pair as a T5-style NLI input string."""
    return "nli premise: {} hypothesis: {}".format(premise, hypothesis)

def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batch of ANLI examples for T5 seq2seq training.

    Args:
        batch: columnar dict with ``premise``, ``hypothesis`` and ``label`` lists.
        tokenizer: called with input texts and with ``text_target=``; its
            ``pad_token_id`` marks target padding.
        max_input_length / max_target_length: fixed padded/truncated lengths.

    Returns:
        Tokenizer output for the prompts, plus ``labels`` with pad ids replaced
        by -100 so the loss ignores padding.
    """
    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    label_list = ["entailment", "neutral", "contradiction"]

    # robust label -> string: class ids 0..2 (int or digit string) map to names;
    # anything else (textual labels, ids like -1) is kept as lower-cased text.
    labels_str = []
    for x in batch['label']:
        sx = str(x).strip()
        if sx.isdigit() and int(sx) < len(label_list):
            labels_str.append(label_list[int(sx)])
        else:
            labels_str.append(sx.lower())

    # Fixed padding to keep hook tensor shapes consistent across steps
    model_inputs = tokenizer(
        inputs, padding="max_length", truncation=True, max_length=max_input_length
    )
    target = tokenizer(
        text_target=labels_str, padding="max_length", truncation=True, max_length=max_target_length
    )
    # Labels are already padded to max_target_length here, so a collator that
    # pads only up to the longest label adds nothing and will not insert -100.
    # Mask the pad ids explicitly so the loss ignores them.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if tok == pad_id else tok for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

# Tokenizer
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# Map datasets
train = dataset["train"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True, remove_columns=dataset["train"].column_names
)
dev = dataset["validation"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True, remove_columns=dataset["validation"].column_names
)

# --- Load model before creating the collator (so collator can mask label pads -> -100) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
# Avoid dealing with past_key_value in custom blocks
model.config.use_cache = False

# Collator that converts pad tokens in labels to -100
collator = DataCollatorForSeq2Seq(
    tokenizer, model=model, label_pad_token_id=-100
)

train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)

# --- 3. Conditional ER Hook Utilities ---
def register_conditional_er_hooks(model):
    """Attach forward hooks that record the latest output of every T5 block.

    Three groups of hooks are installed:
      * every encoder block,
      * every decoder block,
      * sub-layer ``layer[1]`` (cross-attention) of each decoder block that has one.

    Decoder blocks without such a sub-layer, e.g. a SkipBlock, get a ``None``
    placeholder in the cross-attention handle list.  Each hook stores a
    detached copy of the module's hidden states (first tuple element when the
    output is a tuple) under the block index.

    Returns:
        Three ``(handles, activations)`` pairs: encoder, decoder, cross-attention.
    """
    def _make_recorder(store, key):
        # Bind store/key at creation time so each hook writes its own slot.
        def _record(module, inputs, output):
            hidden = output[0] if isinstance(output, tuple) else output
            store[key] = hidden.detach()
        return _record

    encoder_blocks, decoder_blocks = model.encoder.block, model.decoder.block

    enc_acts = dict.fromkeys(range(len(encoder_blocks)))
    dec_acts = dict.fromkeys(range(len(decoder_blocks)))
    cross_acts = dict.fromkeys(range(len(decoder_blocks)))

    enc_hooks = [
        blk.register_forward_hook(_make_recorder(enc_acts, pos))
        for pos, blk in enumerate(encoder_blocks)
    ]
    dec_hooks = [
        blk.register_forward_hook(_make_recorder(dec_acts, pos))
        for pos, blk in enumerate(decoder_blocks)
    ]

    cross_hooks = []
    for pos, blk in enumerate(decoder_blocks):
        has_cross = hasattr(blk, "layer") and len(blk.layer) > 1
        cross_hooks.append(
            blk.layer[1].register_forward_hook(_make_recorder(cross_acts, pos)) if has_cross else None
        )

    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)

def remove_hooks(hook_sets):
    """Detach every hook handle in ``hook_sets`` (``(handles, acts)`` pairs), skipping ``None`` placeholders."""
    for handles, _acts in hook_sets:
        for handle in (h for h in handles if h is not None):
            handle.remove()

def compute_conditional_batch_entropy(prev_acts, curr_acts):
    """
    Cos^2 between step-wise deltas, averaged over batch.

    For each adjacent layer pair ``(i, i+1)``, ``dX``/``dY`` are the per-sample
    flattened changes between ``prev_acts`` and ``curr_acts``.  A pair is scored
    only when all four tensors exist, the shapes match across steps, and neither
    delta contains NaN.  Non-finite per-sample cosines are dropped before
    averaging, and a non-finite average is discarded.
    """
    def _row_cos_squares(dX, dY):
        # Squared cosine of each row pair, skipping NaN rows / non-finite results.
        squares = []
        for row in range(dX.size(0)):
            if torch.isnan(dX[row]).any() or torch.isnan(dY[row]).any():
                continue
            c = F.cosine_similarity(
                dY[row].unsqueeze(0), dX[row].unsqueeze(0), dim=1, eps=1e-8
            ).item()
            if math.isnan(c) or math.isinf(c):
                continue
            squares.append(c * c)
        return squares

    scores = {}
    for i in range(len(curr_acts) - 1):
        px, py = prev_acts[i], prev_acts[i + 1]
        cx, cy = curr_acts[i], curr_acts[i + 1]
        if px is None or py is None or cx is None or cy is None:
            continue
        if px.shape != cx.shape or py.shape != cy.shape:
            continue
        batch = cx.size(0)
        dX = (cx - px).view(batch, -1)
        dY = (cy - py).view(batch, -1)
        if batch < 1 or torch.isnan(dX).any() or torch.isnan(dY).any():
            continue
        squares = _row_cos_squares(dX, dY)
        if not squares:
            continue
        er = sum(squares) / len(squares)
        if not (math.isnan(er) or math.isinf(er)):
            scores[i] = er
    return scores

# --- 4. Pruning Utilities ---
class SkipBlock(nn.Module):
    """
    Minimal drop-in replacement for a pruned T5 block that forwards hidden_states.

    It must still hand on the shared relative-position biases: T5 computes them
    only in block 0, and every later block reuses what the previous block
    returned.  The previous tuple put them in the last two slots, where the
    stack never reads them.

    Output layout mirrors ``T5Block``:
    ``(hidden_states, [past_key_value if use_cache], position_bias,
    [self-attn weights if output_attentions],
    [encoder_decoder_position_bias, [cross-attn weights]] if encoder_hidden_states)``.

    NOTE(review): this targets transformers versions whose ``T5Stack`` inserts
    ``None`` for the cache slot when ``use_cache`` is False and reads
    ``position_bias`` from ``layer_outputs[2]``.  Confirm this against the
    installed version.
    """
    def __init__(self):
        super().__init__()

    def forward(
        self,
        hidden_states,
        attention_mask=None,                 # (in decoder this is 'causal_mask' positionally)
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=False,
        cache_position=None,                 # accepted because newer HF passes it
        **kwargs,
    ):
        outputs = (hidden_states,)
        if use_cache:
            # Hand the (unchanged) cache back so the stack's bookkeeping stays intact.
            outputs += (past_key_value,)
        outputs += (position_bias,)
        if output_attentions:
            outputs += (None,)  # no self-attention weights
        if encoder_hidden_states is not None:
            outputs += (encoder_decoder_position_bias,)
            if output_attentions:
                outputs += (None,)  # no cross-attention weights
        return outputs

def prune_er_layers(blocks, er_scores, num_prune=4):
    """
    Replace the blocks that follow the highest-ER adjacent pairs with SkipBlock.

    An ER score keyed by ``i`` describes the pair (i, i+1). The ``num_prune``
    pairs with the largest scores (descending ER is treated as "more
    redundant") mark block ``i + 1`` for removal. Indices that fall outside
    ``blocks`` are ignored.

    Args:
        blocks: mutable, indexable container of blocks (modified in place).
        er_scores: dict pair index -> ER score.
        num_prune: how many top-scoring pairs to consider.

    Returns:
        Sorted list of the block indices that were replaced.
    """
    ranked = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)
    n_blocks = len(blocks)
    targets = sorted({pair_idx + 1 for pair_idx, _score in ranked[:num_prune] if pair_idx + 1 < n_blocks})
    for target in targets:
        blocks[target] = SkipBlock()
    return targets

# --- 5. Training/Eval/ER Pipeline ---
def canonicalize_label(s: str):
    """
    Map a free-form label string to a canonical NLI label.

    Only the first whitespace-separated word of the stripped, lower-cased
    input is inspected. Known variants ("entailed", "contradict",
    "contradictory", "contradicted") map to their canonical form, and any
    other word is returned unchanged. None or blank input yields "".
    """
    words = (s or "").strip().lower().split()
    if not words:
        return ""
    head = words[0]
    synonyms = {
        "entailment": "entailment",
        "entailed": "entailment",
        "neutral": "neutral",
        "contradiction": "contradiction",
        "contradict": "contradiction",
        "contradictory": "contradiction",
        "contradicted": "contradiction",
    }
    return synonyms.get(head, head)

def compute_accuracy(preds, refs):
    """
    Fraction of ``preds`` whose canonical label equals that of the aligned
    reference. Pairs are formed with ``zip``, so only the common prefix is
    compared, while the denominator is always ``len(preds)``. Returns 0 for an
    empty ``preds``.
    """
    if len(preds) == 0:
        return 0
    hits = sum(canonicalize_label(p) == canonicalize_label(r) for p, r in zip(preds, refs))
    return hits / len(preds)

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device):
    """
    Generate a label for every example in ``dl`` and return the accuracy
    against the gold labels, as computed by ``compute_accuracy``.

    Gold label ids equal to -100 are mapped back to the pad id before
    decoding. Generation is limited to 4 new tokens and runs with
    ``use_cache=False`` to stay consistent with SkipBlock.
    """
    model.eval()
    predictions, references = [], []
    for batch in dl:
        gold = batch["labels"].clone()
        gold[gold == -100] = tokenizer.pad_token_id  # -100 is not a decodable token id
        references += [t.strip().lower() for t in tokenizer.batch_decode(gold, skip_special_tokens=True)]

        generated = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_new_tokens=4,       # enough for "contradiction"
            use_cache=False,
        )
        predictions += [t.strip().lower() for t in tokenizer.batch_decode(generated, skip_special_tokens=True)]

    return compute_accuracy(predictions, references)

def build_optimizer_and_scheduler(model, train_steps):
    """
    Create the optimizer (and optional LR scheduler) for T5 fine-tuning.

    Uses Adafactor with relative_step=True, scale_parameter=True and
    warmup_init=True, the setup HF recommends for T5. That configuration
    derives its own learning rate, so no external scheduler is returned. If
    you switch to AdamW, use a lower LR (e.g. 5e-5) and keep gradient
    clipping enabled.

    Args:
        model: module whose parameters are optimized.
        train_steps: total training steps. Currently unused; kept so a
            step-based scheduler can be plugged in without changing callers.

    Returns:
        (optimizer, None)
    """
    optimizer = Adafactor(
        model.parameters(),
        scale_parameter=True,
        relative_step=True,
        warmup_init=True,
    )
    return optimizer, None

def full_finetuning(model, train_loader, dev_loader, device, tokenizer):
    """
    Stage 1: fine-tune ``model`` for 6 epochs while estimating conditional ER.

    Forward hooks record the output of every encoder block, every decoder
    block and every decoder cross-attention sub-layer. After each optimizer
    step, the current batch's activations are compared with the previous
    batch's via ``compute_conditional_batch_entropy``. The per-pair scores are
    averaged over the epoch and printed together with the dev accuracy.

    The hooks are removed in a ``finally`` clause, so an exception or a
    notebook interrupt does not leave them attached to ``model``.

    Args:
        model: seq2seq model exposing ``encoder.block`` / ``decoder.block``.
        train_loader, dev_loader: loaders yielding dicts with 'input_ids',
            'attention_mask' and 'labels'.
        device: device the batch tensors are moved to.
        tokenizer: passed to ``evaluate_model`` for decoding.

    Returns:
        (model, enc_er, dec_er, cross_er): the trained model and the last
        epoch's mean ER dicts (pair index i -> score for layers (i, i+1)),
        or None for each dict if no epoch ran.
    """
    print("=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===")

    opt, sched = build_optimizer_and_scheduler(model, len(train_loader) * 3)
    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)
    last_enc_er, last_dec_er, last_cross_er = None, None, None

    def _accumulate(sums, counts, prev_acts, curr_acts):
        # Fold one (previous batch, current batch) ER estimate into the epoch totals.
        if prev_acts is None:
            return
        for idx, v in compute_conditional_batch_entropy(prev_acts, curr_acts).items():
            sums[idx] += v
            counts[idx] += 1

    def _snapshot(acts):
        # The hooks overwrite these buffers on every forward pass, so keep copies.
        return {i: (t.clone() if t is not None else None) for i, t in acts.items()}

    def _epoch_means(sums, counts):
        return {idx: sums[idx] / counts[idx] for idx in sums if counts[idx] > 0}

    try:
        for epoch in range(6):
            enc_er_sums, enc_er_counts = defaultdict(float), defaultdict(int)
            dec_er_sums, dec_er_counts = defaultdict(float), defaultdict(int)
            cross_er_sums, cross_er_counts = defaultdict(float), defaultdict(int)

            model.train()
            prev_enc_acts, prev_dec_acts, prev_cross_acts = None, None, None

            for batch in train_loader:
                opt.zero_grad()

                with autocast(enabled=False):  # turn off AMP while debugging stability
                    outputs = model(
                        input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        labels=batch['labels'].to(device)
                    )
                    loss = outputs.loss

                if not torch.isfinite(loss):
                    print("Loss is NaN/Inf — skipping this batch.")
                    continue

                loss.backward()
                # Grad clipping (helps avoid LN/attn blowups)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                # sched may be None (Adafactor with relative_step schedules internally)
                opt.step()
                if sched is not None:
                    sched.step()

                # --- ER accumulations (acts were captured by this batch's forward) ---
                _accumulate(enc_er_sums, enc_er_counts, prev_enc_acts, enc_acts)
                _accumulate(dec_er_sums, dec_er_counts, prev_dec_acts, dec_acts)
                _accumulate(cross_er_sums, cross_er_counts, prev_cross_acts, cross_acts)

                prev_enc_acts = _snapshot(enc_acts)
                prev_dec_acts = _snapshot(dec_acts)
                prev_cross_acts = _snapshot(cross_acts)

            # epoch-level ER (means)
            epoch_enc_er = _epoch_means(enc_er_sums, enc_er_counts)
            epoch_dec_er = _epoch_means(dec_er_sums, dec_er_counts)
            epoch_cross_er = _epoch_means(cross_er_sums, cross_er_counts)

            print(f"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}")
            print(f"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}")
            print(f"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}")

            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")

            last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er
    finally:
        # Always detach the hooks, even if training raises or is interrupted.
        remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])

    return model, last_enc_er, last_dec_er, last_cross_er

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer):
    """
    Stage 2: replace high-ER decoder blocks with SkipBlock, then fine-tune for
    5 more epochs, printing dev accuracy after each epoch.

    ``enc_er_scores`` is accepted but unused, because only the decoder is
    pruned. The model is modified in place and returned.
    """
    print("=== Stage 2: Prune (High-ER) & Fine-tuning ===")
    pruned = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4)
    print("Pruned decoder layers (highest ER -> next index):", pruned)

    # Fresh optimizer after the structural change (replaced blocks' parameters
    # no longer belong to the model).
    opt, sched = build_optimizer_and_scheduler(model, len(train_loader) * 2)

    for epoch_no in range(1, 6):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            loss = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                labels=batch['labels'].to(device),
            ).loss
            if not torch.isfinite(loss):
                print("Loss is NaN/Inf — skipping this batch.")
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if sched is not None:
                sched.step()

        dev_acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch_no}] ANLI1 Acc: {dev_acc:.4f}")
    return model

# --- 6. Entrypoint ---
def main():
    """Run Stage 1 (fine-tune + ER) and then Stage 2 (decoder pruning + fine-tune) on the global ``model``."""
    global model  # rebinds the notebook-level model loaded earlier
    model, enc_er, dec_er, _cross_er = full_finetuning(
        model, train_loader, dev_loader, device, tokenizer
    )
    # Decoder-only pruning followed by continued fine-tuning.
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device, enc_er, dec_er, tokenizer
    )


if __name__ == "__main__":
    main()


In [None]:
# ER — ANLI1 experiment with t5-base: fine-tune while estimating conditional ER
# between adjacent layers, replace high-ER decoder blocks with SkipBlock, then
# fine-tune again.

# --- Mount Google Drive if using Colab ---
# The ANLI1 JSON files loaded below live under /content/drive/MyDrive/...
from google.colab import drive
drive.mount('/content/drive')

# --- Standard Imports ---
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# NOTE(review): get_linear_schedule_with_warmup is imported, but this cell's
# optimizer setup (Adafactor with relative_step=True) uses no external scheduler.
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup, Adafactor
)
# NOTE(review): torch.cuda.amp.autocast is deprecated in recent PyTorch releases
# in favour of torch.amp.autocast("cuda", ...) — confirm against the installed version.
from torch.cuda.amp import autocast
from collections import defaultdict
import warnings
import math
import random
import numpy as np

# Hide every UserWarning / FutureWarning for the rest of the session
# (this may also hide deprecation notices from the libraries above).
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ------------- Repro -------------
def set_seed(seed=42):
    """
    Seed Python's ``random``, NumPy's global RNG and PyTorch with the same
    value, and also every CUDA device when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(1234)

# --- 1. Load ANLI1 Dataset ---
# ANLI round-1 splits stored as JSON on Google Drive.
data_files = dict(
    train="/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json",
    validation="/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json",
    test="/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json",
)
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Function ---
def make_t5_nli_prompt(premise, hypothesis):
    """Build the T5 input string ``"nli premise: <premise> hypothesis: <hypothesis>"``."""
    return "nli premise: {} hypothesis: {}".format(premise, hypothesis)

def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8,
                    ignore_pad_token_for_loss=True):
    """
    Tokenize a batch of ANLI examples for text-to-text NLI with T5.

    Args:
        batch: dict of lists with 'premise', 'hypothesis' and 'label' columns.
            A label whose string form is a digit below 3 maps to
            entailment/neutral/contradiction. Any other label is used as its
            stripped, lower-cased string.
        tokenizer: HF tokenizer; labels are tokenized via ``text_target=``.
        max_input_length, max_target_length: fixed padded/truncated lengths.
            Fixed padding keeps hook tensor shapes consistent across steps.
        ignore_pad_token_for_loss: replace pad ids in the labels with -100 so
            the cross-entropy loss ignores padding positions.

    Returns:
        The tokenizer encoding of the prompts with an added 'labels' list.
    """
    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    label_list = ["entailment", "neutral", "contradiction"]

    # robust label -> string
    labels_str = []
    for x in batch['label']:
        sx = str(x)
        if sx.isdigit() and int(sx) < 3:
            labels_str.append(label_list[int(sx)])
        else:
            labels_str.append(sx.strip().lower())

    # Fixed padding to keep hook tensor shapes consistent across steps
    model_inputs = tokenizer(
        inputs, padding="max_length", truncation=True, max_length=max_input_length
    )
    target = tokenizer(
        text_target=labels_str, padding="max_length", truncation=True, max_length=max_target_length
    )
    labels = target["input_ids"]
    pad_id = tokenizer.pad_token_id
    if ignore_pad_token_for_loss and pad_id is not None:
        # padding="max_length" fills the labels with pad ids, and
        # DataCollatorForSeq2Seq only pads with -100; it never rewrites ids
        # already present. Without masking, the loss also trains on padding.
        # Label length stays max_target_length, so decoder shapes stay fixed
        # (HF T5 maps -100 back to pad when deriving decoder_input_ids).
        labels = [[tok if tok != pad_id else -100 for tok in seq] for seq in labels]
    model_inputs["labels"] = labels
    return model_inputs

# Tokenizer
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# Map datasets
train = dataset["train"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True, remove_columns=dataset["train"].column_names
)
dev = dataset["validation"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True, remove_columns=dataset["validation"].column_names
)

# --- Load model before creating the collator (so collator can mask label pads -> -100) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
# Avoid dealing with past_key_value in custom blocks
model.config.use_cache = False

# Collator that converts pad tokens in labels to -100
collator = DataCollatorForSeq2Seq(
    tokenizer, model=model, label_pad_token_id=-100
)

train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)

# --- 3. Conditional ER Hook Utilities ---
def register_conditional_er_hooks(model):
    """
    Register forward hooks that store the latest detached output of:
      * every encoder block                       -> enc_acts[i]
      * every decoder block                       -> dec_acts[i]
      * decoder ``block.layer[1]`` (cross-attn)   -> cross_acts[i]

    ``cross_hooks`` holds None for decoder blocks without a second sub-layer,
    which keeps it index-aligned with the decoder blocks.

    Returns:
        ((enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts))
        where each ``*_acts`` dict starts with None values.
    """
    def capture_into(store, idx):
        def hook(module, inp, out):
            store[idx] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    encoder_blocks = model.encoder.block
    enc_acts = dict.fromkeys(range(len(encoder_blocks)))
    enc_hooks = [blk.register_forward_hook(capture_into(enc_acts, i)) for i, blk in enumerate(encoder_blocks)]

    decoder_blocks = model.decoder.block
    dec_acts = dict.fromkeys(range(len(decoder_blocks)))
    dec_hooks = [blk.register_forward_hook(capture_into(dec_acts, i)) for i, blk in enumerate(decoder_blocks)]

    cross_acts = dict.fromkeys(range(len(decoder_blocks)))
    cross_hooks = [
        blk.layer[1].register_forward_hook(capture_into(cross_acts, i))
        if hasattr(blk, "layer") and len(blk.layer) > 1
        else None
        for i, blk in enumerate(decoder_blocks)
    ]

    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)

def remove_hooks(hook_sets):
    """
    Remove every hook handle in ``hook_sets``, an iterable of
    (handles, acts) pairs. None placeholders in ``handles`` are skipped.
    """
    live_handles = (h for handles, _acts in hook_sets for h in handles if h is not None)
    for handle in live_handles:
        handle.remove()

def compute_conditional_batch_entropy(prev_acts, curr_acts):
    """
    Squared cosine similarity between step-wise activation deltas of adjacent
    layers, averaged over the batch.

    For each adjacent pair (i, i+1):
        dX = curr_acts[i]   - prev_acts[i]      (flattened per sample)
        dY = curr_acts[i+1] - prev_acts[i+1]
        er[i] = mean over samples j of cos(dY_j, dX_j) ** 2

    Args:
        prev_acts: dict layer index -> tensor (or None) from the previous step.
        curr_acts: dict layer index -> tensor (or None) from the current step.
            Keys are expected to be 0..len(curr_acts)-1.

    Returns:
        dict {i: float in [0, 1]} for every pair where all four tensors exist,
        shapes match across steps, the batch is non-empty, the deltas contain
        no NaN, and at least one per-sample score is finite.
    """
    er_scores = {}
    for i in range(len(curr_acts) - 1):
        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]
        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]
        if prev_X is None or prev_Y is None or curr_X is None or curr_Y is None:
            continue
        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:
            continue
        B = curr_X.size(0)
        if B == 0:
            # reshape(0, -1) is ambiguous and raises; there is nothing to score.
            continue
        # float32 keeps the cosine stable if activations arrive in half precision.
        dX = (curr_X - prev_X).reshape(B, -1).float()
        dY = (curr_Y - prev_Y).reshape(B, -1).float()
        if torch.isnan(dX).any() or torch.isnan(dY).any():
            continue
        # One cosine per sample (row), squared exactly once -> cos^2.
        cos_sq = F.cosine_similarity(dY, dX, dim=1, eps=1e-8).pow(2)
        cos_sq = cos_sq[torch.isfinite(cos_sq)]  # drop samples whose cosine blew up (e.g. Inf deltas)
        if cos_sq.numel() == 0:
            continue
        er = cos_sq.mean().item()
        if math.isfinite(er):
            er_scores[i] = er
    return er_scores

# --- 4. Pruning Utilities ---
class SkipBlock(nn.Module):
    """
    Identity stand-in for a pruned T5 block: returns ``hidden_states`` unchanged.

    Accepts the same arguments as HF's ``T5Block.forward``, positionally or by
    keyword; unknown keywords are swallowed by ``**kwargs``. The return tuple
    mirrors the layout ``T5Block`` itself produces, so ``T5Stack`` finds the
    shared relative-position biases in the slots it reads them from and keeps
    forwarding them to the remaining blocks:

        (hidden_states,
         past_key_value                 -- only if use_cache,
         position_bias,
         None (self-attn weights)       -- only if output_attentions,
         encoder_decoder_position_bias,
         None (cross-attn weights)      -- only if output_attentions)

    A fixed 6-tuple with the biases at the end would put ``None`` into the
    slots the stack reads the biases from. Later blocks would then not receive
    the shared relative-position bias.

    NOTE(review): the layout follows T5Block in the transformers 4.x releases
    that pass ``cache_position``. Re-check it if the installed version changes
    T5Block's return tuple.
    """
    def __init__(self):
        super().__init__()

    def forward(
        self,
        hidden_states,
        attention_mask=None,                 # (in decoder this is 'causal_mask' positionally)
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=False,
        cache_position=None,                 # accepted because newer HF versions pass it
        **kwargs,
    ):
        outputs = (hidden_states,)
        if use_cache:
            # This block owns no keys/values; hand the cache back untouched.
            outputs += (past_key_value,)
        outputs += (position_bias,)
        if output_attentions:
            outputs += (None,)  # no self-attention weights
        outputs += (encoder_decoder_position_bias,)
        if output_attentions:
            outputs += (None,)  # no cross-attention weights
        return outputs

def prune_er_layers(blocks, er_scores, num_prune=4):
    """
    Swap in SkipBlock for the block that follows each of the ``num_prune``
    highest-ER adjacent pairs.

    A score keyed by ``i`` belongs to the pair (i, i+1), and the later block
    ``i + 1`` is the one removed. Descending ER is treated as "more
    redundant". Indices past the end of ``blocks`` are ignored. ``blocks`` is
    modified in place.

    Returns:
        Sorted list of replaced block indices.
    """
    by_score = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)
    chosen = set()
    for pair_idx, _score in by_score[:num_prune]:
        nxt = pair_idx + 1
        if nxt < len(blocks):
            chosen.add(nxt)
    prune_idxs = sorted(chosen)
    for idx in prune_idxs:
        blocks[idx] = SkipBlock()
    return prune_idxs

# --- 5. Training/Eval/ER Pipeline ---
def canonicalize_label(s: str):
    """
    Normalize a predicted or gold label to one of the canonical NLI names.

    The input is stripped and lower-cased, and only its first whitespace
    token is kept. That token is mapped through a small synonym table;
    unknown tokens pass through unchanged, and None/blank input gives "".
    """
    text = (s or "").strip().lower()
    head = text.split(maxsplit=1)[0] if text else text
    aliases = {
        "entailment": "entailment",
        "entailed": "entailment",
        "neutral": "neutral",
        "contradiction": "contradiction",
        "contradict": "contradiction",
        "contradictory": "contradiction",
        "contradicted": "contradiction",
    }
    return aliases.get(head, head)

def compute_accuracy(preds, refs):
    """
    Label accuracy after canonicalization: matches over the ``zip`` of
    ``preds``/``refs`` divided by ``len(preds)``. Returns 0 when ``preds`` is
    empty.
    """
    total = len(preds)
    if not total:
        return 0
    matches = [canonicalize_label(p) == canonicalize_label(r) for p, r in zip(preds, refs)]
    return matches.count(True) / total

@torch.no_grad()
def evaluate_model(model, dl, tokenizer, device):
    """
    Accuracy of ``model`` on ``dl``, computed by ``compute_accuracy``.

    Predictions come from ``model.generate`` (at most 4 new tokens,
    ``use_cache=False`` to stay consistent with SkipBlock). References are the
    batch labels with -100 mapped back to the pad id before decoding. Both
    sides are stripped and lower-cased.
    """
    def decode(ids):
        return [text.strip().lower() for text in tokenizer.batch_decode(ids, skip_special_tokens=True)]

    model.eval()
    preds, refs = [], []
    for batch in dl:
        gold_ids = batch["labels"].clone()
        gold_ids[gold_ids == -100] = tokenizer.pad_token_id
        generated = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_new_tokens=4,       # enough for "contradiction"
            use_cache=False,
        )
        refs.extend(decode(gold_ids))
        preds.extend(decode(generated))

    return compute_accuracy(preds, refs)

def build_optimizer_and_scheduler(model, train_steps):
    """
    Return ``(optimizer, scheduler)`` for T5 fine-tuning.

    The optimizer is Adafactor with relative_step=True, scale_parameter=True
    and warmup_init=True (HF's recommended T5 setup). It computes its own
    learning rate, so the scheduler slot is None. If you prefer AdamW, lower
    the LR (e.g. 5e-5) and enable grad clipping.

    ``train_steps`` is currently unused; it is kept so a step-based scheduler
    can be added without changing callers.
    """
    adafactor_kwargs = dict(relative_step=True, scale_parameter=True, warmup_init=True)
    return Adafactor(model.parameters(), **adafactor_kwargs), None

def full_finetuning(model, train_loader, dev_loader, device, tokenizer):
    """
    Stage 1: fine-tune ``model`` for 6 epochs while estimating conditional ER.

    Forward hooks record the output of every encoder block, every decoder
    block and every decoder cross-attention sub-layer. After each optimizer
    step, the current batch's activations are compared with the previous
    batch's via ``compute_conditional_batch_entropy``. The per-pair scores are
    averaged over the epoch and printed together with the dev accuracy.

    The hooks are removed in a ``finally`` clause, so an exception or a
    notebook interrupt does not leave them attached to ``model``.

    Args:
        model: seq2seq model exposing ``encoder.block`` / ``decoder.block``.
        train_loader, dev_loader: loaders yielding dicts with 'input_ids',
            'attention_mask' and 'labels'.
        device: device the batch tensors are moved to.
        tokenizer: passed to ``evaluate_model`` for decoding.

    Returns:
        (model, enc_er, dec_er, cross_er): the trained model and the last
        epoch's mean ER dicts (pair index i -> score for layers (i, i+1)),
        or None for each dict if no epoch ran.
    """
    print("=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===")

    opt, sched = build_optimizer_and_scheduler(model, len(train_loader) * 3)
    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)
    last_enc_er, last_dec_er, last_cross_er = None, None, None

    def _accumulate(sums, counts, prev_acts, curr_acts):
        # Fold one (previous batch, current batch) ER estimate into the epoch totals.
        if prev_acts is None:
            return
        for idx, v in compute_conditional_batch_entropy(prev_acts, curr_acts).items():
            sums[idx] += v
            counts[idx] += 1

    def _snapshot(acts):
        # The hooks overwrite these buffers on every forward pass, so keep copies.
        return {i: (t.clone() if t is not None else None) for i, t in acts.items()}

    def _epoch_means(sums, counts):
        return {idx: sums[idx] / counts[idx] for idx in sums if counts[idx] > 0}

    try:
        for epoch in range(6):
            enc_er_sums, enc_er_counts = defaultdict(float), defaultdict(int)
            dec_er_sums, dec_er_counts = defaultdict(float), defaultdict(int)
            cross_er_sums, cross_er_counts = defaultdict(float), defaultdict(int)

            model.train()
            prev_enc_acts, prev_dec_acts, prev_cross_acts = None, None, None

            for batch in train_loader:
                opt.zero_grad()

                with autocast(enabled=False):  # turn off AMP while debugging stability
                    outputs = model(
                        input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        labels=batch['labels'].to(device)
                    )
                    loss = outputs.loss

                if not torch.isfinite(loss):
                    print("Loss is NaN/Inf — skipping this batch.")
                    continue

                loss.backward()
                # Grad clipping (helps avoid LN/attn blowups)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                # sched may be None (Adafactor with relative_step schedules internally)
                opt.step()
                if sched is not None:
                    sched.step()

                # --- ER accumulations (acts were captured by this batch's forward) ---
                _accumulate(enc_er_sums, enc_er_counts, prev_enc_acts, enc_acts)
                _accumulate(dec_er_sums, dec_er_counts, prev_dec_acts, dec_acts)
                _accumulate(cross_er_sums, cross_er_counts, prev_cross_acts, cross_acts)

                prev_enc_acts = _snapshot(enc_acts)
                prev_dec_acts = _snapshot(dec_acts)
                prev_cross_acts = _snapshot(cross_acts)

            # epoch-level ER (means)
            epoch_enc_er = _epoch_means(enc_er_sums, enc_er_counts)
            epoch_dec_er = _epoch_means(dec_er_sums, dec_er_counts)
            epoch_cross_er = _epoch_means(cross_er_sums, cross_er_counts)

            print(f"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}")
            print(f"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}")
            print(f"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}")

            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")

            last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er
    finally:
        # Always detach the hooks, even if training raises or is interrupted.
        remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])

    return model, last_enc_er, last_dec_er, last_cross_er

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer):
    """
    Stage 2: prune the decoder by ER, then fine-tune for 5 epochs.

    Decoder blocks that follow the 4 highest-ER pairs are replaced with
    SkipBlock, a new optimizer is built, and dev accuracy is printed after
    every epoch. ``enc_er_scores`` is unused because pruning is decoder-only.
    Returns the (in-place modified) model.
    """
    print("=== Stage 2: Prune (High-ER) & Fine-tuning ===")
    removed = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4)
    print("Pruned decoder layers (highest ER -> next index):", removed)

    # Rebuild the optimizer because the module structure just changed.
    opt, sched = build_optimizer_and_scheduler(model, len(train_loader) * 2)

    for epoch_no in range(1, 6):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            step_loss = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                labels=batch['labels'].to(device),
            ).loss
            if not torch.isfinite(step_loss):
                print("Loss is NaN/Inf — skipping this batch.")
                continue
            step_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if sched is not None:
                sched.step()

        epoch_acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch_no}] ANLI1 Acc: {epoch_acc:.4f}")
    return model

# --- 6. Entrypoint ---
def main():
    """Fine-tune with ER tracking, then prune the decoder and fine-tune again (uses the global ``model``)."""
    global model  # rebinds the notebook-level model loaded above
    stage1 = full_finetuning(model, train_loader, dev_loader, device, tokenizer)
    model, enc_er, dec_er, _cross_er = stage1
    # --- PRUNING AND CONTINUED FINETUNING (decoder only) ---
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device, enc_er, dec_er, tokenizer
    )


if __name__ == "__main__":
    main()


In [None]:
# Only prune decoder
# SVAMP experiment with t5-base, reusing the conditional-ER pipeline; per this
# header, only decoder blocks are pruned.
# ===========================
# 0. Google Drive Mount
# ===========================
# The SVAMP JSON files loaded below live under /content/drive/MyDrive/...
from google.colab import drive
drive.mount('/content/drive')

# ===========================
# 1. Imports and Setup
# ===========================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
# autocast/GradScaler drive mixed-precision training in this cell's full_finetuning.
# NOTE(review): torch.cuda.amp.* is deprecated in recent PyTorch in favour of torch.amp.* — confirm.
from torch.cuda.amp import autocast, GradScaler
from collections import defaultdict
import warnings
import math

# Hide every UserWarning / FutureWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ===========================
# 2. Load SVAMP Dataset
# ===========================
# SVAMP ships only train/test JSON files on Drive.
data_files = dict(
    train="/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json",
    test="/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json",
)
dataset = load_dataset("json", data_files=data_files)

# ===========================
# 3. Preprocessing
# ===========================
def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8,
                     ignore_pad_token_for_loss=True):
    """
    Tokenize a batch of SVAMP examples for T5.

    Args:
        batch: dict of lists with an 'input' text column and a 'label' column
            (labels are converted with ``str()``).
        tokenizer: HF tokenizer; inputs and targets are padded/truncated to
            fixed lengths.
        max_input_length, max_target_length: the fixed lengths.
        ignore_pad_token_for_loss: replace pad ids in the labels with -100 so
            the cross-entropy loss skips padding positions.

    Returns:
        The tokenizer output for the inputs with an added 'labels' list.
    """
    model_inputs = tokenizer(
        batch["input"], padding="max_length", truncation=True, max_length=max_input_length
    )
    targets = [str(x) for x in batch["label"]]
    target_encodings = tokenizer(
        targets, padding="max_length", truncation=True, max_length=max_target_length
    )
    labels = target_encodings["input_ids"]
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if ignore_pad_token_for_loss and pad_id is not None:
        # padding="max_length" fills the labels with pad ids; the collator only
        # pads with -100 and never rewrites ids already present, so mask here.
        labels = [[-100 if tok == pad_id else tok for tok in seq] for seq in labels]
    model_inputs["labels"] = labels
    return model_inputs

# Tokenizer, tokenized splits, collator and loaders for SVAMP.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
train = dataset["train"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset["train"].column_names)
# NOTE(review): `dev` is built from the "test" split (data_files has no
# validation file), so the "dev" accuracy reported during training is measured on test data.
dev = dataset["test"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset["test"].column_names)
# padding="max_length", max_length=128 asks the collator to pad inputs to 128
# (they are already 128 long via preprocess_svamp's default max_input_length).
# NOTE(review): per the HF DataCollatorForSeq2Seq docs, labels are padded with
# label_pad_token_id (-100 by default); pad ids already present in the labels
# are not rewritten by the collator.
collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
dev_loader = DataLoader(dev, batch_size=8, shuffle=False, collate_fn=collator)

# ===========================
# 4. Conditional ER Utilities
# ===========================
def register_conditional_er_hooks(model):
    """
    Register forward hooks that keep the latest detached output of every
    encoder block, every decoder block, and the cross-attention sub-layer
    (``block.layer[1]``) of every decoder block.

    Decoder blocks without a second sub-layer (e.g. a block already replaced
    by SkipBlock) get no cross-attention hook; their ``cross_acts`` entry
    stays None.

    Returns:
        ((enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts))
        where ``*_acts`` map layer index -> tensor (None until the hook fires)
        and ``*_hooks`` hold the removable hook handles.
    """
    def _capture(store, idx):
        # Factory rather than a default-arg closure: binds this layer's index.
        def _hook(module, inp, out):
            hs = out[0] if isinstance(out, tuple) else out
            store[idx] = hs.detach()
        return _hook

    enc_layers = model.encoder.block
    enc_acts = {i: None for i in range(len(enc_layers))}
    enc_hooks = [layer.register_forward_hook(_capture(enc_acts, i)) for i, layer in enumerate(enc_layers)]

    dec_layers = model.decoder.block
    dec_acts = {i: None for i in range(len(dec_layers))}
    dec_hooks = [layer.register_forward_hook(_capture(dec_acts, i)) for i, layer in enumerate(dec_layers)]

    cross_acts = {i: None for i in range(len(dec_layers))}
    cross_hooks = []
    for i, block in enumerate(dec_layers):
        sub_layers = getattr(block, "layer", None)
        if sub_layers is not None and len(sub_layers) > 1:
            cross_hooks.append(sub_layers[1].register_forward_hook(_capture(cross_acts, i)))

    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)

def remove_hooks(hook_sets):
    """
    Remove every hook handle in ``hook_sets``, an iterable of
    (handles, acts) pairs. None placeholders are skipped, consistent with
    the ANLI version of this helper.
    """
    for handles, _acts in hook_sets:
        for handle in handles:
            if handle is not None:
                handle.remove()

def compute_conditional_batch_entropy(prev_acts, curr_acts):
    """
    Squared cosine similarity between step-wise activation deltas of adjacent
    layers, averaged over the whole batch.

    For each adjacent pair (i, i+1):
        dX = curr_acts[i]   - prev_acts[i]      (flattened per sample)
        dY = curr_acts[i+1] - prev_acts[i+1]
        er[i] = mean over all samples j of cos(dY_j, dX_j) ** 2

    Every sample, including j = 0, is used, and single-sample batches are
    scored too.

    Args:
        prev_acts: dict layer index -> tensor (or None) from the previous step.
        curr_acts: dict layer index -> tensor (or None) from the current step.
            Keys are expected to be 0..len(curr_acts)-1.

    Returns:
        dict {i: float in [0, 1]} for every pair where all four tensors exist,
        shapes match across steps, the batch is non-empty, the deltas contain
        no NaN, and at least one per-sample score is finite.
    """
    er_scores = {}
    for i in range(len(curr_acts) - 1):
        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]
        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]
        if prev_X is None or prev_Y is None or curr_X is None or curr_Y is None:
            continue
        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:
            continue
        B = curr_X.size(0)
        if B == 0:
            # reshape(0, -1) is ambiguous and raises; there is nothing to score.
            continue
        # float32 keeps the cosine stable if activations arrive in half precision.
        dX = (curr_X - prev_X).reshape(B, -1).float()
        dY = (curr_Y - prev_Y).reshape(B, -1).float()
        if torch.isnan(dX).any() or torch.isnan(dY).any():
            continue
        # One cosine per sample (row), squared -> cos^2.
        cos_sq = F.cosine_similarity(dY, dX, dim=1, eps=1e-8).pow(2)
        cos_sq = cos_sq[torch.isfinite(cos_sq)]  # drop samples whose cosine blew up (e.g. Inf deltas)
        if cos_sq.numel() == 0:
            continue
        er = cos_sq.mean().item()
        if math.isfinite(er):
            er_scores[i] = er
    return er_scores

# ===========================
# 5. Pruning Utilities
# ===========================
class SkipBlock(nn.Module):
    """
    Identity stand-in for a pruned T5 block: returns ``hidden_states`` unchanged.

    Accepts the same arguments as HF's ``T5Block.forward``, positionally or by
    keyword; extra keywords go to ``**kwargs``. The return tuple mirrors
    ``T5Block``'s layout, so ``T5Stack`` keeps forwarding the shared
    relative-position biases to the remaining blocks:

        (hidden_states,
         past_key_value                 -- only if use_cache,
         position_bias,
         None (self-attn weights)       -- only if output_attentions,
         encoder_decoder_position_bias,
         None (cross-attn weights)      -- only if output_attentions)

    Returning all-None entries would hand the stack None biases, and later
    blocks would not receive the shared relative-position bias.

    NOTE(review): the layout follows T5Block in the transformers 4.x releases
    that pass ``cache_position``. Re-check it if the installed version changes
    T5Block's return tuple.

    Args:
        hidden_size: stored for reference only; the block has no parameters.
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=False,
        cache_position=None,
        **kwargs,
    ):
        outputs = (hidden_states,)
        if use_cache:
            # This block owns no keys/values; hand the cache back untouched.
            outputs += (past_key_value,)
        outputs += (position_bias,)
        if output_attentions:
            outputs += (None,)  # no self-attention weights
        outputs += (encoder_decoder_position_bias,)
        if output_attentions:
            outputs += (None,)  # no cross-attention weights
        return outputs

def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):
    """
    Replace block ``i + 1`` for each of the ``num_prune`` highest-ER pairs
    (i, i+1) with ``SkipBlock(hidden_size)``. Indices past the end of
    ``blocks`` are ignored. ``blocks`` is modified in place.

    Returns:
        The replaced indices, in descending-ER order.
    """
    top_pairs = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)[:num_prune]
    prune_idxs = []
    for pair_idx, _score in top_pairs:
        target = pair_idx + 1
        if target < len(blocks):
            prune_idxs.append(target)
    for target in prune_idxs:
        blocks[target] = SkipBlock(hidden_size)
    return prune_idxs

# ===========================
# 6. Eval Helper
# ===========================
def compute_accuracy(preds, refs):
    """
    Exact-match accuracy: the number of positions where the zipped ``preds``
    and ``refs`` are equal, divided by ``len(preds)``. Returns 0 for an empty
    ``preds``.
    """
    n_preds = len(preds)
    if n_preds <= 0:
        return 0
    return sum(1 for p, r in zip(preds, refs) if p == r) / n_preds

def evaluate_model(model, dl, tokenizer, device):
    """
    Exact-match accuracy of ``model`` on ``dl``, computed by ``compute_accuracy``.

    Predictions come from ``model.generate`` with at most 8 new tokens.
    References are the batch labels with -100 mapped back to the pad id.
    Both sides are stripped and lower-cased.

    Generation runs with ``use_cache=False``, matching the ANLI evaluation:
    SkipBlock-pruned decoder blocks keep no key/value cache of their own.
    """
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=8,
                use_cache=False,
            )
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = tokenizer.pad_token_id  # -100 is not decodable
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend([p.strip().lower() for p in pred_texts])
            refs.extend([l.strip().lower() for l in ref_texts])
    return compute_accuracy(preds, refs)

# ===========================
# 7. Training + ER Tracking + Pruning
# ===========================
def _snapshot_acts(acts):
    """Return a cloned copy of a hook activation dict (``None`` entries stay ``None``)."""
    return {i: (a.clone() if a is not None else None) for i, a in acts.items()}


def _mean_er(sums, counts):
    """Average accumulated per-layer ER values, skipping layers with zero observations."""
    return {idx: sums[idx] / counts[idx] for idx in sums if counts[idx] > 0}


def full_finetuning(train_loader, dev_loader, device, tokenizer, num_epochs=6, lr=3e-4):
    """Stage 1: fully fine-tune ``t5-base`` while tracking per-layer conditional ER.

    After each optimizer step, activations captured by the ER hooks are compared
    with those of the previous batch via ``compute_conditional_batch_entropy``.
    This is done separately for encoder, decoder and cross-attention hooks. The
    per-layer values are averaged per epoch, printed, and the model is evaluated
    on ``dev_loader``.

    Args:
        train_loader / dev_loader: iterables of dicts with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        device: torch device to train on.
        tokenizer: tokenizer used by ``evaluate_model`` for decoding.
        num_epochs: number of training epochs (default 6, as before).
        lr: AdamW learning rate (default 3e-4, as before).

    Returns:
        ``(model, enc_er, dec_er, cross_er)``. The ER dicts come from the final
        epoch and are ``None`` if ``num_epochs`` is 0.
    """
    print("=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler()
    # The linear decay must span every step that actually runs. The old horizon
    # (3 epochs for a 6-epoch loop) drove the LR to 0 halfway through, so the
    # final-epoch ER scores used for pruning were measured on frozen weights.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)
    # Hooks overwrite these dicts in place on every forward pass.
    live_acts = {"enc": enc_acts, "dec": dec_acts, "cross": cross_acts}
    labels = {"enc": "Encoder", "dec": "Decoder", "cross": "Cross-Attention"}
    last_er = {name: None for name in live_acts}

    try:
        for epoch in range(num_epochs):
            er_sums = {name: defaultdict(float) for name in live_acts}
            er_counts = {name: defaultdict(int) for name in live_acts}
            prev_acts = {name: None for name in live_acts}
            model.train()
            for batch in train_loader:
                opt.zero_grad()
                with autocast():
                    outputs = model(input_ids=batch['input_ids'].to(device),
                                    attention_mask=batch['attention_mask'].to(device),
                                    labels=batch['labels'].to(device))
                    loss = outputs.loss
                # Backward outside autocast, as recommended by the PyTorch AMP docs.
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()
                # Compare this batch's activations with the previous batch's snapshot.
                for name, acts in live_acts.items():
                    if prev_acts[name] is not None:
                        batch_er = compute_conditional_batch_entropy(prev_acts[name], acts)
                        for idx, v in batch_er.items():
                            er_sums[name][idx] += v
                            er_counts[name][idx] += 1
                # Clone so the next forward pass does not overwrite the snapshot.
                for name, acts in live_acts.items():
                    prev_acts[name] = _snapshot_acts(acts)
            epoch_er = {name: _mean_er(er_sums[name], er_counts[name]) for name in live_acts}
            for name in live_acts:
                print(f"[Epoch {epoch+1}] approx {labels[name]} Conditional ER: {epoch_er[name]}")
            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
            last_er = epoch_er
    finally:
        # Always detach the hooks, even if training fails part-way.
        remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])
    return model, last_er["enc"], last_er["dec"], last_er["cross"]

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer,
                         num_prune=4, num_epochs=5, lr=5e-4):
    """Stage 2: replace the highest-ER decoder blocks with ``SkipBlock`` and fine-tune.

    Args:
        model: T5 model from stage 1; pruned in place.
        train_loader / dev_loader: same batch format as stage 1.
        device: torch device.
        enc_er_scores: accepted for interface compatibility. Encoder pruning is
            currently disabled, so it is unused.
        dec_er_scores: per-layer decoder ER scores passed to ``prune_er_layers``.
        tokenizer: used by ``evaluate_model``.
        num_prune: number of decoder blocks to prune (default 4, as before).
        num_epochs: fine-tuning epochs after pruning (default 5, as before).
        lr: AdamW learning rate (default 5e-4, as before).

    Returns:
        The pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High-ER) & Fine-tuning ===")
    # Encoder pruning is disabled; to re-enable, call prune_er_layers on
    # model.encoder.block with enc_er_scores in the same way as below.
    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=num_prune,
                                     hidden_size=model.config.d_model)
    print("Pruned decoder layers (highest ER):", dec_prune_idxs)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    # The horizon must cover all epochs; it previously covered only 2 of 5,
    # leaving the last 3 epochs at LR 0.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch['labels'].to(device))
            loss = outputs.loss
            loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] SVAMP Acc: {acc:.4f}")
    return model

# ===========================
# 8. Main Entrypoint
# ===========================
def main():
    """Run stage 1 (full fine-tuning + ER tracking) then stage 2 (ER-guided pruning + fine-tuning).

    Reads ``train_loader``, ``dev_loader`` and ``tokenizer`` from module (notebook)
    scope; they are not defined here and are presumably built in an earlier cell.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    stage1_model, enc_scores, dec_scores, _cross_scores = full_finetuning(
        train_loader, dev_loader, device, tokenizer
    )
    prune_and_finetuning(
        stage1_model, train_loader, dev_loader, device,
        enc_scores, dec_scores, tokenizer,
    )

# Inside a Jupyter/IPython kernel ``__name__`` is "__main__", so executing this cell starts training.
if __name__ == "__main__":
    main()
