In [None]:
# Pin `datasets` to 2.18.0: remove whichever version is already installed,
# then install the pinned release. `evaluate` is installed as well, although
# none of the cells shown here import it.
!pip uninstall -y datasets
!pip install datasets==2.18.0
!pip install evaluate


In [None]:
# Mount Google Drive at /content/drive; the experiment cells read their JSON
# datasets from paths under that mount point.
# NOTE(review): `google.colab` is presumably only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
from datasets import load_dataset
import torch
import torch.nn as nn
import numpy as np
import random
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import warnings
from collections import defaultdict

# Silence every UserWarning and FutureWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- SkipFFN for T5 FFN dropout ---
class SkipFFN(nn.Module):
    """Parameter-free stand-in for a T5 `DenseReluDense` module that drops the FFN.

    Hugging Face's `T5LayerFF` computes
    `hidden + dropout(DenseReluDense(layer_norm(hidden)))`, i.e. whatever this
    module returns is *added* to the residual stream. Returning the input
    unchanged would therefore inject `layer_norm(hidden)` into the residual
    instead of skipping the sub-layer, so the module returns zeros: the
    surrounding layer then reduces to the identity on `hidden`.

    Args:
        hidden_size: Model width; stored for introspection only.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, hidden_states):
        """Return a zero tensor with the shape, dtype and device of the input."""
        return torch.zeros_like(hidden_states)

# --- LayerDrop utility for random pruning ---
def layerdrop_prune_t5(model, num_prune_enc=4, num_prune_dec=4, seed=42):
    """Replace the FFN of randomly chosen encoder/decoder blocks with `SkipFFN`.

    The model is modified in place. Encoder blocks have their
    `layer[1].DenseReluDense` replaced, decoder blocks their
    `layer[2].DenseReluDense`.

    Args:
        model: Object exposing `encoder.block`, `decoder.block` and
            `config.d_model` (e.g. a `T5ForConditionalGeneration`).
        num_prune_enc: Number of encoder blocks to prune.
        num_prune_dec: Number of decoder blocks to prune.
        seed: Seed for the NumPy generator that picks the blocks.

    Returns:
        `(enc_idxs, dec_idxs)`: sorted lists of plain Python ints.

    Raises:
        ValueError: If a prune count is negative or exceeds the number of blocks.
    """
    enc_blocks = model.encoder.block
    dec_blocks = model.decoder.block
    total_enc = len(enc_blocks)
    total_dec = len(dec_blocks)
    d_model = model.config.d_model

    # Fail early with a message that names the offending side.
    for side, count, total in (("encoder", num_prune_enc, total_enc),
                               ("decoder", num_prune_dec, total_dec)):
        if not 0 <= count <= total:
            raise ValueError(
                f"Cannot prune {count} {side} layers: model has {total} {side} blocks"
            )

    rng = np.random.default_rng(seed)
    # Draw encoder indices first, then decoder indices, from the same generator.
    # Convert to built-in ints so the returned/printed lists do not contain
    # NumPy scalar objects.
    enc_idxs = sorted(int(i) for i in rng.choice(total_enc, size=num_prune_enc, replace=False))
    dec_idxs = sorted(int(i) for i in rng.choice(total_dec, size=num_prune_dec, replace=False))

    for idx in enc_idxs:
        enc_blocks[idx].layer[1].DenseReluDense = SkipFFN(d_model)
    for idx in dec_idxs:
        dec_blocks[idx].layer[2].DenseReluDense = SkipFFN(d_model)
    print(f"LayerDrop (Encoder FFN): pruned layers {enc_idxs}")
    print(f"LayerDrop (Decoder FFN): pruned layers {dec_idxs}")
    return enc_idxs, dec_idxs

# --- Data/Helper functions ---
def make_t5_nli_prompt(premise, hypothesis):
    """Format one premise/hypothesis pair as the text prompt given to the model."""
    return f"nli premise: {premise} hypothesis: {hypothesis}"

def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batch of NLI examples into model inputs and label ids.

    Args:
        batch: Mapping with parallel lists under 'premise', 'hypothesis' and
            'label'. Integer labels 0/1/2 are mapped to
            "entailment"/"neutral"/"contradiction"; anything else is used as
            text via `str()`.
        tokenizer: Callable tokenizer exposing `pad_token_id`.
        max_input_length: Pad/truncate length for the prompt.
        max_target_length: Pad/truncate length for the label text.

    Returns:
        The tokenizer output for the prompts with an added "labels" entry.
        Padding positions in "labels" are set to -100 so the loss ignores them.
    """
    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    label_list = ["entailment", "neutral", "contradiction"]
    # Require 0 <= x: a negative id such as -1 would otherwise index from the
    # end of the list and silently become "contradiction".
    labels = [
        label_list[x] if isinstance(x, int) and 0 <= x < len(label_list) else str(x)
        for x in batch['label']
    ]
    target = tokenizer(labels, padding="max_length", truncation=True, max_length=max_target_length)
    pad_id = tokenizer.pad_token_id
    # -100 is the index the loss skips; leaving the pad id in place would make
    # the model spend most of its loss on predicting padding.
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

def compute_accuracy(preds, refs):
    """Return the fraction of `preds` equal to the reference at the same position.

    Pairs are formed with `zip`, so extra items in the longer sequence are never
    compared; the denominator is always `len(preds)`. Empty `preds` yields 0.
    """
    total = len(preds)
    if total <= 0:
        return 0
    hits = sum(1 for guess, gold in zip(preds, refs) if guess == gold)
    return hits / total

def evaluate_model(model, dl, tokenizer, device, label_texts, max_new_tokens=8):
    """Generate predictions for every batch in `dl` and return exact-match accuracy.

    Args:
        model: Seq2seq model providing `eval()` and `generate()`.
        dl: Iterable of batches with "input_ids", "attention_mask" and "labels" tensors.
        tokenizer: Tokenizer used to decode generated ids and label ids.
        device: Device the input tensors are moved to.
        label_texts: Unused; kept so existing call sites keep working.
        max_new_tokens: Generation budget per example. Defaults to 8, the
            default `max_target_length` used when the labels are tokenized, so a
            label word that spans several sentencepiece tokens is not cut off
            before it can match its reference.

    Returns:
        Accuracy over stripped, lower-cased decoded strings.
    """
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
            )
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # -100 marks ignored label positions and is not a valid token id;
            # swap it for the pad id (on a copy) before decoding.
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = tokenizer.pad_token_id
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend([p.strip().lower() for p in pred_texts])
            refs.extend([l.strip().lower() for l in ref_texts])
    return compute_accuracy(preds, refs)

# --- Training/Fine-tuning Loop ---
def finetune_t5(train_loader, dev_loader, device, tokenizer, label_texts, epochs=6, lr=3e-4, model=None):
    """Fine-tune a T5 model and print dev accuracy after every epoch.

    Args:
        train_loader: Batches with "input_ids", "attention_mask" and "labels".
        dev_loader: Batches passed to `evaluate_model` after each epoch.
        device: Target device (a `torch.device` or a device string).
        tokenizer: Forwarded to `evaluate_model`.
        label_texts: Forwarded to `evaluate_model`.
        epochs: Number of passes over `train_loader`.
        lr: Peak AdamW learning rate; decays linearly to 0 with no warmup.
        model: Model to continue training; "t5-base" is loaded when None.

    Returns:
        The trained model (the same object as `model` when one was passed in).
    """
    if model is None:
        model = T5ForConditionalGeneration.from_pretrained("t5-base")
    # Also covers a caller-supplied model that is not on `device` yet.
    model = model.to(device)

    # Mixed precision only makes sense on CUDA; on CPU both helpers are disabled
    # explicitly and become pass-throughs.
    use_amp = torch.device(device).type == "cuda"
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler(enabled=use_amp)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)
    for epoch in range(epochs):
        model.train()  # evaluate_model() leaves the model in eval mode
        for batch in train_loader:
            opt.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(input_ids=batch['input_ids'].to(device),
                                attention_mask=batch['attention_mask'].to(device),
                                labels=batch['labels'].to(device))
                loss = outputs.loss
            # Backward runs outside the autocast region, as PyTorch's AMP
            # guidance recommends; only the forward pass and loss belong inside.
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
        print(f"[Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}")
    return model

# --- Entrypoint ---
def main():
    """Run the e-SNLI experiment: fine-tune, prune FFN sub-layers, fine-tune again.

    Loads the JSON splits from Google Drive, subsamples up to 10000 training and
    2000 validation examples, and reports accuracies through `print`.
    """
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    data_files = {
        "train": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json",
        "validation": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json",
        "test": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json"
    }
    raw_datasets = load_dataset("json", data_files=data_files)
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")
    label_texts = ["entailment", "neutral", "contradiction"]

    # Shuffle, then take a prefix. The requested size is capped at the split
    # size so a smaller-than-expected file does not make `select` fail.
    train_ds = raw_datasets["train"].shuffle(seed=seed)
    train_ds = train_ds.select(range(min(10000, len(train_ds))))
    dev_ds = raw_datasets["validation"].shuffle(seed=seed)
    dev_ds = dev_ds.select(range(min(2000, len(dev_ds))))

    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True, remove_columns=train_ds.column_names)
    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                     batched=True, remove_columns=dev_ds.column_names)

    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Stage 1: Full Fine-Tuning ---
    print("\n=== Stage 1: Full Fine-Tuning (No Pruning) ===")
    model = finetune_t5(train_loader, dev_loader, device, tokenizer, label_texts, epochs=6)
    acc_full = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
    print(f"\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}")

    # --- Stage 2: LayerDrop Pruning ---
    # Prunes the Stage 1 model in place; no copy of the unpruned model is kept.
    print("\n=== Stage 2: LayerDrop Pruning  ===")
    enc_pruned, dec_pruned = layerdrop_prune_t5(model, num_prune_enc=2, num_prune_dec=4, seed=seed)

    # --- Stage 3: Fine-tune Pruned Model ---
    print("\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===")
    model = finetune_t5(train_loader, dev_loader, device, tokenizer, label_texts, epochs=5, lr=5e-4, model=model)
    acc_pruned = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
    print(f"\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}")
    print(f"Encoder FFN pruned indices: {enc_pruned}, Decoder FFN pruned indices: {dec_pruned}")

if __name__ == "__main__":
    main()


In [None]:
# Mount Google Drive at /content/drive; the CQA JSON files are read from there.
# NOTE(review): the import is unconditional, so this cell presumably only runs
# inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

from datasets import load_dataset
import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler
import warnings

# Silence every UserWarning and FutureWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- 1. Load CQA Data ---
# Only train and test files are listed; the test split is what this cell later
# tokenizes as `dev` and evaluates on after every epoch.
data_files = {
    "train": "/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json",
    "test":  "/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json"
}
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Functions ---
def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):
    """Tokenize a batch of multiple-choice QA examples into inputs and label ids.

    Args:
        batch: Mapping with parallel lists under 'question', 'choices' (each a
            list of strings) and 'answer'; optionally 'abstractive_explanation'.
        tokenizer: Callable tokenizer exposing `pad_token_id`.
        max_input_length: Pad/truncate length for the prompt.
        max_target_length: Pad/truncate length for the answer text.
        use_cot: When True and the batch has 'abstractive_explanation', the
            explanation is appended to the prompt as "rationale: ...".

    Returns:
        The tokenizer output for the prompts with an added "labels" entry.
        Padding positions in "labels" are set to -100 so the loss ignores them.
    """
    if use_cot and 'abstractive_explanation' in batch:
        inputs = [
            f"question: {q} choices: {', '.join(choices)} rationale: {exp}"
            for q, choices, exp in zip(batch['question'], batch['choices'], batch['abstractive_explanation'])
        ]
    else:
        inputs = [
            f"question: {q} choices: {', '.join(choices)}"
            for q, choices in zip(batch['question'], batch['choices'])
        ]
    targets = [str(ans).strip() for ans in batch['answer']]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(targets, padding="max_length", truncation=True, max_length=max_target_length)
    pad_id = tokenizer.pad_token_id
    # -100 is the index the loss skips; leaving the pad id in place would make
    # the model spend part of its loss on predicting padding.
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

# Module-level objects below (`tokenizer`, `train_loader`, `dev_loader`) are
# read as globals by main().
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
# Whether training prompts include the rationale; the dev prompts never do.
USE_COT = False

train = dataset["train"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=USE_COT),
                            batched=True, remove_columns=dataset["train"].column_names)
# "dev" is the test split: no validation file is loaded for CQA.
dev   = dataset["test"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=False),
                            batched=True, remove_columns=dataset["test"].column_names)

# Inputs were already padded to 128 during preprocessing; the collator is
# configured with the same length.
collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

# --- 3. SkipFFN utility for random dropout ---
class SkipFFN(nn.Module):
    """Parameter-free stand-in for a T5 `DenseReluDense` module that drops the FFN.

    Hugging Face's `T5LayerFF` computes
    `hidden + dropout(DenseReluDense(layer_norm(hidden)))`, i.e. whatever this
    module returns is *added* to the residual stream. Returning the input
    unchanged would therefore inject `layer_norm(hidden)` into the residual
    instead of skipping the sub-layer, so the module returns zeros: the
    surrounding layer then reduces to the identity on `hidden`.

    Args:
        hidden_size: Model width; stored for introspection only.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, hidden_states):
        """Return a zero tensor with the shape, dtype and device of the input."""
        return torch.zeros_like(hidden_states)

def layerdrop_prune_t5(model, num_prune_enc=4, num_prune_dec=4, seed=42):
    """Replace the FFN of randomly chosen encoder/decoder blocks with `SkipFFN`.

    The model is modified in place. Encoder blocks have their
    `layer[1].DenseReluDense` replaced, decoder blocks their
    `layer[2].DenseReluDense`.

    Args:
        model: Object exposing `encoder.block`, `decoder.block` and
            `config.d_model` (e.g. a `T5ForConditionalGeneration`).
        num_prune_enc: Number of encoder blocks to prune.
        num_prune_dec: Number of decoder blocks to prune.
        seed: Seed for the NumPy generator that picks the blocks.

    Returns:
        `(enc_idxs, dec_idxs)`: sorted lists of plain Python ints.

    Raises:
        ValueError: If a prune count is negative or exceeds the number of blocks.
    """
    enc_blocks = model.encoder.block
    dec_blocks = model.decoder.block
    total_enc = len(enc_blocks)
    total_dec = len(dec_blocks)
    d_model = model.config.d_model

    # Fail early with a message that names the offending side.
    for side, count, total in (("encoder", num_prune_enc, total_enc),
                               ("decoder", num_prune_dec, total_dec)):
        if not 0 <= count <= total:
            raise ValueError(
                f"Cannot prune {count} {side} layers: model has {total} {side} blocks"
            )

    rng = np.random.default_rng(seed)
    # Draw encoder indices first, then decoder indices, from the same generator.
    # Convert to built-in ints so the returned/printed lists do not contain
    # NumPy scalar objects.
    enc_idxs = sorted(int(i) for i in rng.choice(total_enc, size=num_prune_enc, replace=False))
    dec_idxs = sorted(int(i) for i in rng.choice(total_dec, size=num_prune_dec, replace=False))

    for idx in enc_idxs:
        enc_blocks[idx].layer[1].DenseReluDense = SkipFFN(d_model)
    for idx in dec_idxs:
        dec_blocks[idx].layer[2].DenseReluDense = SkipFFN(d_model)
    print(f"LayerDrop (Encoder FFN): pruned layers {enc_idxs}")
    print(f"LayerDrop (Decoder FFN): pruned layers {dec_idxs}")
    return enc_idxs, dec_idxs

# --- 4. Training/Evaluation ---
def compute_accuracy(preds, refs):
    """Return the fraction of `preds` equal to the reference at the same position.

    Pairs are formed with `zip`, so extra items in the longer sequence are never
    compared; the denominator is always `len(preds)`. Empty `preds` yields 0.
    """
    total = len(preds)
    if total <= 0:
        return 0
    hits = sum(1 for guess, gold in zip(preds, refs) if guess == gold)
    return hits / total

def evaluate_model(model, dl, tokenizer, device, max_new_tokens=8):
    """Generate predictions for every batch in `dl` and return exact-match accuracy.

    Args:
        model: Seq2seq model providing `eval()` and `generate()`.
        dl: Iterable of batches with "input_ids", "attention_mask" and "labels" tensors.
        tokenizer: Tokenizer used to decode generated ids and label ids.
        device: Device the input tensors are moved to.
        max_new_tokens: Generation budget per example. Defaults to 8, the
            default `max_target_length` of `preprocess_cqa`, so an answer that
            needs more than a few sentencepiece tokens is not cut off before it
            can match its reference.

    Returns:
        Accuracy over stripped, lower-cased decoded strings.
    """
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
            )
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # -100 marks ignored label positions and is not a valid token id;
            # swap it for the pad id (on a copy) before decoding.
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = tokenizer.pad_token_id
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend([p.strip().lower() for p in pred_texts])
            refs.extend([l.strip().lower() for l in ref_texts])
    return compute_accuracy(preds, refs)

def finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6, lr=3e-4, model=None):
    """Fine-tune a T5 model and print dev accuracy after every epoch.

    Args:
        train_loader: Batches with "input_ids", "attention_mask" and "labels".
        dev_loader: Batches passed to `evaluate_model` after each epoch.
        device: Target device (a `torch.device` or a device string).
        tokenizer: Forwarded to `evaluate_model`.
        epochs: Number of passes over `train_loader`.
        lr: Peak AdamW learning rate; decays linearly to 0 with no warmup.
        model: Model to continue training; "t5-base" is loaded when None.

    Returns:
        The trained model (the same object as `model` when one was passed in).
    """
    if model is None:
        model = T5ForConditionalGeneration.from_pretrained("t5-base")
    # Also covers a caller-supplied model that is not on `device` yet.
    model = model.to(device)

    # Mixed precision only makes sense on CUDA; on CPU both helpers are disabled
    # explicitly and become pass-throughs.
    use_amp = torch.device(device).type == "cuda"
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler(enabled=use_amp)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)
    for epoch in range(epochs):
        model.train()  # evaluate_model() leaves the model in eval mode
        for batch in train_loader:
            opt.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(input_ids=batch['input_ids'].to(device),
                                attention_mask=batch['attention_mask'].to(device),
                                labels=batch['labels'].to(device))
                loss = outputs.loss
            # Backward runs outside the autocast region, as PyTorch's AMP
            # guidance recommends; only the forward pass and loss belong inside.
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Epoch {epoch+1}] CQA Acc: {acc:.4f}")
    return model

# --- 5. Entrypoint ---
def main():
    """Run the three-stage CQA experiment and print the resulting accuracies.

    Stage 1 fine-tunes a fresh model, Stage 2 prunes FFN sub-layers of that same
    model in place, Stage 3 fine-tunes the pruned model. Uses the module-level
    `train_loader`, `dev_loader` and `tokenizer`.
    """
    seed = 42
    # Seed the Python, NumPy and PyTorch generators, in that order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stage 1: baseline without pruning.
    print("\n=== Stage 1: Full Fine-Tuning (No Pruning) ===")
    full_model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6)
    baseline_acc = evaluate_model(full_model, dev_loader, tokenizer, device)
    print(f"\nAccuracy after 6-epoch full fine-tuning: {baseline_acc:.4f}")

    # Stage 2: swap selected FFN modules (mutates `full_model`).
    print("\n=== Stage 2: LayerDrop Pruning ===")
    pruned_enc, pruned_dec = layerdrop_prune_t5(
        full_model, num_prune_enc=2, num_prune_dec=4, seed=seed
    )

    # Stage 3: recover accuracy on the pruned model.
    print("\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===")
    pruned_model = finetune_t5(
        train_loader, dev_loader, device, tokenizer, epochs=5, lr=5e-4, model=full_model
    )
    final_acc = evaluate_model(pruned_model, dev_loader, tokenizer, device)
    print(f"\nAccuracy after 5-epoch post-pruning fine-tuning: {final_acc:.4f}")
    print(f"Encoder FFN pruned indices: {pruned_enc}, Decoder FFN pruned indices: {pruned_dec}")

if __name__ == "__main__":
    main()


In [None]:
# --- Mount Google Drive (the ANLI1 JSON files are read from /content/drive) ---
# NOTE(review): the import is unconditional, so this cell presumably only runs
# inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# --- Standard Imports ---
from datasets import load_dataset
import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler
import warnings

# Silence every UserWarning and FutureWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- 1. Load ANLI1 Dataset ---
# Three JSON splits on the mounted Drive. This cell tokenizes only "train" and
# "validation"; the "test" split is loaded but not used below.
data_files = {
    "train":      "/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json",
    "validation": "/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json",
    "test":       "/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json"
}
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Function ---
def make_t5_nli_prompt(premise, hypothesis):
    """Format one premise/hypothesis pair as the text prompt given to the model."""
    return f"nli premise: {premise} hypothesis: {hypothesis}"

def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batch of NLI examples into model inputs and label ids.

    Args:
        batch: Mapping with parallel lists under 'premise', 'hypothesis' and
            'label'. Labels whose string form is a digit below 3 are mapped to
            "entailment"/"neutral"/"contradiction"; anything else is used as
            text via `str()`.
        tokenizer: Callable tokenizer exposing `pad_token_id`.
        max_input_length: Pad/truncate length for the prompt.
        max_target_length: Pad/truncate length for the label text.

    Returns:
        The tokenizer output for the prompts with an added "labels" entry.
        Padding positions in "labels" are set to -100 so the loss ignores them.
    """
    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    label_list = ["entailment", "neutral", "contradiction"]
    labels = [label_list[int(x)] if isinstance(x, (int, float, str)) and str(x).isdigit() and int(x)<3 else str(x) for x in batch['label']]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(labels, padding="max_length", truncation=True, max_length=max_target_length)
    pad_id = tokenizer.pad_token_id
    # -100 is the index the loss skips; leaving the pad id in place would make
    # the model spend most of its loss on predicting padding.
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

# Module-level objects below (`tokenizer`, `train_loader`, `dev_loader`) are
# read as globals by main().
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
train = dataset["train"].map(lambda ex: preprocess_anli(ex, tokenizer), batched=True, remove_columns=dataset["train"].column_names)
dev   = dataset["validation"].map(lambda ex: preprocess_anli(ex, tokenizer), batched=True, remove_columns=dataset["validation"].column_names)
# Inputs were already padded to 128 during preprocessing; the collator is
# configured with the same length.
collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

# --- 3. SkipFFN for random dropout ---
class SkipFFN(nn.Module):
    """Parameter-free stand-in for a T5 `DenseReluDense` module that drops the FFN.

    Hugging Face's `T5LayerFF` computes
    `hidden + dropout(DenseReluDense(layer_norm(hidden)))`, i.e. whatever this
    module returns is *added* to the residual stream. Returning the input
    unchanged would therefore inject `layer_norm(hidden)` into the residual
    instead of skipping the sub-layer, so the module returns zeros: the
    surrounding layer then reduces to the identity on `hidden`.

    Args:
        hidden_size: Model width; stored for introspection only.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, hidden_states):
        """Return a zero tensor with the shape, dtype and device of the input."""
        return torch.zeros_like(hidden_states)

def layerdrop_prune_t5(model, num_prune_enc=4, num_prune_dec=4, seed=42):
    """Replace the FFN of randomly chosen encoder/decoder blocks with `SkipFFN`.

    The model is modified in place. Encoder blocks have their
    `layer[1].DenseReluDense` replaced, decoder blocks their
    `layer[2].DenseReluDense`.

    Args:
        model: Object exposing `encoder.block`, `decoder.block` and
            `config.d_model` (e.g. a `T5ForConditionalGeneration`).
        num_prune_enc: Number of encoder blocks to prune.
        num_prune_dec: Number of decoder blocks to prune.
        seed: Seed for the NumPy generator that picks the blocks.

    Returns:
        `(enc_idxs, dec_idxs)`: sorted lists of plain Python ints.

    Raises:
        ValueError: If a prune count is negative or exceeds the number of blocks.
    """
    enc_blocks = model.encoder.block
    dec_blocks = model.decoder.block
    total_enc = len(enc_blocks)
    total_dec = len(dec_blocks)
    d_model = model.config.d_model

    # Fail early with a message that names the offending side.
    for side, count, total in (("encoder", num_prune_enc, total_enc),
                               ("decoder", num_prune_dec, total_dec)):
        if not 0 <= count <= total:
            raise ValueError(
                f"Cannot prune {count} {side} layers: model has {total} {side} blocks"
            )

    rng = np.random.default_rng(seed)
    # Draw encoder indices first, then decoder indices, from the same generator.
    # Convert to built-in ints so the returned/printed lists do not contain
    # NumPy scalar objects.
    enc_idxs = sorted(int(i) for i in rng.choice(total_enc, size=num_prune_enc, replace=False))
    dec_idxs = sorted(int(i) for i in rng.choice(total_dec, size=num_prune_dec, replace=False))

    for idx in enc_idxs:
        enc_blocks[idx].layer[1].DenseReluDense = SkipFFN(d_model)
    for idx in dec_idxs:
        dec_blocks[idx].layer[2].DenseReluDense = SkipFFN(d_model)
    print(f"LayerDrop (Encoder FFN): pruned layers {enc_idxs}")
    print(f"LayerDrop (Decoder FFN): pruned layers {dec_idxs}")
    return enc_idxs, dec_idxs

# --- 4. Training/Evaluation ---
def compute_accuracy(preds, refs):
    """Return the fraction of `preds` equal to the reference at the same position.

    Pairs are formed with `zip`, so extra items in the longer sequence are never
    compared; the denominator is always `len(preds)`. Empty `preds` yields 0.
    """
    total = len(preds)
    if total <= 0:
        return 0
    hits = sum(1 for guess, gold in zip(preds, refs) if guess == gold)
    return hits / total

def evaluate_model(model, dl, tokenizer, device, max_new_tokens=8):
    """Generate predictions for every batch in `dl` and return exact-match accuracy.

    Args:
        model: Seq2seq model providing `eval()` and `generate()`.
        dl: Iterable of batches with "input_ids", "attention_mask" and "labels" tensors.
        tokenizer: Tokenizer used to decode generated ids and label ids.
        device: Device the input tensors are moved to.
        max_new_tokens: Generation budget per example. Defaults to 8, the
            default `max_target_length` of `preprocess_anli`, so a label word
            that spans several sentencepiece tokens is not cut off before it
            can match its reference.

    Returns:
        Accuracy over stripped, lower-cased decoded strings.
    """
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
            )
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # -100 marks ignored label positions and is not a valid token id;
            # swap it for the pad id (on a copy) before decoding.
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = tokenizer.pad_token_id
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend([p.strip().lower() for p in pred_texts])
            refs.extend([l.strip().lower() for l in ref_texts])
    return compute_accuracy(preds, refs)

def finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6, lr=3e-4, model=None):
    """Fine-tune a T5 model and print dev accuracy after every epoch.

    Args:
        train_loader: Batches with "input_ids", "attention_mask" and "labels".
        dev_loader: Batches passed to `evaluate_model` after each epoch.
        device: Target device (a `torch.device` or a device string).
        tokenizer: Forwarded to `evaluate_model`.
        epochs: Number of passes over `train_loader`.
        lr: Peak AdamW learning rate; decays linearly to 0 with no warmup.
        model: Model to continue training; "t5-base" is loaded when None.

    Returns:
        The trained model (the same object as `model` when one was passed in).
    """
    if model is None:
        model = T5ForConditionalGeneration.from_pretrained("t5-base")
    # Also covers a caller-supplied model that is not on `device` yet.
    model = model.to(device)

    # Mixed precision only makes sense on CUDA; on CPU both helpers are disabled
    # explicitly and become pass-throughs.
    use_amp = torch.device(device).type == "cuda"
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler(enabled=use_amp)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)
    for epoch in range(epochs):
        model.train()  # evaluate_model() leaves the model in eval mode
        for batch in train_loader:
            opt.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(input_ids=batch['input_ids'].to(device),
                                attention_mask=batch['attention_mask'].to(device),
                                labels=batch['labels'].to(device))
                loss = outputs.loss
            # Backward runs outside the autocast region, as PyTorch's AMP
            # guidance recommends; only the forward pass and loss belong inside.
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Epoch {epoch+1}] ANLI1 Acc: {acc:.4f}")
    return model

# --- 5. Entrypoint ---
def main():
    """Run the three-stage ANLI1 experiment and print the resulting accuracies.

    Stage 1 fine-tunes a fresh model, Stage 2 prunes FFN sub-layers of that same
    model in place, Stage 3 fine-tunes the pruned model. Uses the module-level
    `train_loader`, `dev_loader` and `tokenizer`.
    """
    seed = 42
    # Seed the Python, NumPy and PyTorch generators, in that order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stage 1: baseline without pruning.
    print("\n=== Stage 1: Full Fine-Tuning (No Pruning) ===")
    full_model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6)
    baseline_acc = evaluate_model(full_model, dev_loader, tokenizer, device)
    print(f"\nAccuracy after 6-epoch full fine-tuning: {baseline_acc:.4f}")

    # Stage 2: swap selected FFN modules (mutates `full_model`).
    print("\n=== Stage 2: LayerDrop Pruning ===")
    pruned_enc, pruned_dec = layerdrop_prune_t5(
        full_model, num_prune_enc=2, num_prune_dec=4, seed=seed
    )

    # Stage 3: recover accuracy on the pruned model.
    print("\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===")
    pruned_model = finetune_t5(
        train_loader, dev_loader, device, tokenizer, epochs=5, lr=5e-4, model=full_model
    )
    final_acc = evaluate_model(pruned_model, dev_loader, tokenizer, device)
    print(f"\nAccuracy after 5-epoch post-pruning fine-tuning: {final_acc:.4f}")
    print(f"Encoder FFN pruned indices: {pruned_enc}, Decoder FFN pruned indices: {pruned_dec}")

if __name__ == "__main__":
    main()


In [None]:
# ===========================
# 0. Google Drive Mount
# ===========================
# The SVAMP JSON files are read from paths under /content/drive.
# NOTE(review): the import is unconditional, so this cell presumably only runs
# inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# ===========================
# 1. Imports and Setup
# ===========================
from datasets import load_dataset
import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler
import warnings

# Silence every UserWarning and FutureWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ===========================
# 2. Load SVAMP Dataset
# ===========================
# Only train and test files are listed; the test split is what this cell later
# tokenizes as `dev` and evaluates on after every epoch.
data_files = {
    "train": "/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json",
    "test": "/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json"
}
dataset = load_dataset("json", data_files=data_files)

# ===========================
# 3. Preprocessing
# ===========================
def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batch of SVAMP examples into model inputs and label ids.

    Args:
        batch: Mapping with parallel lists under "input" (problem text) and
            "label" (answer; converted to text with `str()`).
        tokenizer: Callable tokenizer exposing `pad_token_id`.
        max_input_length: Pad/truncate length for the problem text.
        max_target_length: Pad/truncate length for the answer text.

    Returns:
        The tokenizer output for the inputs with an added "labels" entry.
        Padding positions in "labels" are set to -100 so the loss ignores them.
    """
    model_inputs = tokenizer(
        batch["input"], padding="max_length", truncation=True, max_length=max_input_length
    )
    targets = [str(x) for x in batch["label"]]
    target_encodings = tokenizer(
        targets, padding="max_length", truncation=True, max_length=max_target_length
    )
    pad_id = tokenizer.pad_token_id
    # -100 is the index the loss skips; leaving the pad id in place would make
    # the model spend part of its loss on predicting padding.
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in target_encodings["input_ids"]
    ]
    return model_inputs

# Module-level objects below (`tokenizer`, `train_loader`, `dev_loader`) are
# read as globals by main().
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
train = dataset["train"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset["train"].column_names)
# "dev" is the test split: no validation file is loaded for SVAMP.
dev = dataset["test"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset["test"].column_names)
# Inputs were already padded to 128 during preprocessing; the collator is
# configured with the same length.
collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

# ===========================
# 4. SkipFFN & LayerDropout Utilities
# ===========================
class SkipFFN(nn.Module):
    """Parameter-free stand-in for a T5 `DenseReluDense` module that drops the FFN.

    Hugging Face's `T5LayerFF` computes
    `hidden + dropout(DenseReluDense(layer_norm(hidden)))`, i.e. whatever this
    module returns is *added* to the residual stream. Returning the input
    unchanged would therefore inject `layer_norm(hidden)` into the residual
    instead of skipping the sub-layer, so the module returns zeros: the
    surrounding layer then reduces to the identity on `hidden`.

    Args:
        hidden_size: Model width; stored for introspection only.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, hidden_states):
        """Return a zero tensor with the shape, dtype and device of the input."""
        return torch.zeros_like(hidden_states)

def layerdrop_prune_t5(model, num_prune_enc=4, num_prune_dec=4, seed=42):
    """Replace the FFN of randomly chosen encoder/decoder blocks with `SkipFFN`.

    The model is modified in place. Encoder blocks have their
    `layer[1].DenseReluDense` replaced, decoder blocks their
    `layer[2].DenseReluDense`.

    Args:
        model: Object exposing `encoder.block`, `decoder.block` and
            `config.d_model` (e.g. a `T5ForConditionalGeneration`).
        num_prune_enc: Number of encoder blocks to prune.
        num_prune_dec: Number of decoder blocks to prune.
        seed: Seed for the NumPy generator that picks the blocks.

    Returns:
        `(enc_idxs, dec_idxs)`: sorted lists of plain Python ints.

    Raises:
        ValueError: If a prune count is negative or exceeds the number of blocks.
    """
    enc_blocks = model.encoder.block
    dec_blocks = model.decoder.block
    total_enc = len(enc_blocks)
    total_dec = len(dec_blocks)
    d_model = model.config.d_model

    # Fail early with a message that names the offending side.
    for side, count, total in (("encoder", num_prune_enc, total_enc),
                               ("decoder", num_prune_dec, total_dec)):
        if not 0 <= count <= total:
            raise ValueError(
                f"Cannot prune {count} {side} layers: model has {total} {side} blocks"
            )

    rng = np.random.default_rng(seed)
    # Draw encoder indices first, then decoder indices, from the same generator.
    # Convert to built-in ints so the returned/printed lists do not contain
    # NumPy scalar objects.
    enc_idxs = sorted(int(i) for i in rng.choice(total_enc, size=num_prune_enc, replace=False))
    dec_idxs = sorted(int(i) for i in rng.choice(total_dec, size=num_prune_dec, replace=False))

    for idx in enc_idxs:
        enc_blocks[idx].layer[1].DenseReluDense = SkipFFN(d_model)
    for idx in dec_idxs:
        dec_blocks[idx].layer[2].DenseReluDense = SkipFFN(d_model)
    print(f"LayerDrop (Encoder FFN): pruned layers {enc_idxs}")
    print(f"LayerDrop (Decoder FFN): pruned layers {dec_idxs}")
    return enc_idxs, dec_idxs

# ===========================
# 5. Training/Eval
# ===========================
def compute_accuracy(preds, refs):
    """Return the fraction of `preds` equal to the reference at the same position.

    Pairs are formed with `zip`, so extra items in the longer sequence are never
    compared; the denominator is always `len(preds)`. Empty `preds` yields 0.
    """
    total = len(preds)
    if total <= 0:
        return 0
    hits = sum(1 for guess, gold in zip(preds, refs) if guess == gold)
    return hits / total

def evaluate_model(model, dl, tokenizer, device, max_new_tokens=8):
    """Generate predictions for every batch in `dl` and return exact-match accuracy.

    Args:
        model: Seq2seq model providing `eval()` and `generate()`.
        dl: Iterable of batches with "input_ids", "attention_mask" and "labels" tensors.
        tokenizer: Tokenizer used to decode generated ids and label ids.
        device: Device the input tensors are moved to.
        max_new_tokens: Generation budget per example. The default of 8 equals
            the default `max_target_length` of `preprocess_svamp`; pass a
            larger value if targets are tokenized with a longer limit.

    Returns:
        Accuracy over stripped, lower-cased decoded strings.
    """
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
            )
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # -100 marks ignored label positions and is not a valid token id;
            # swap it for the pad id (on a copy) before decoding.
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = tokenizer.pad_token_id
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend([p.strip().lower() for p in pred_texts])
            refs.extend([l.strip().lower() for l in ref_texts])
    return compute_accuracy(preds, refs)

def finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6, lr=3e-4, model=None):
    """Fine-tune a T5 model and print dev accuracy after every epoch.

    Args:
        train_loader: Batches with "input_ids", "attention_mask" and "labels".
        dev_loader: Batches passed to `evaluate_model` after each epoch.
        device: Target device (a `torch.device` or a device string).
        tokenizer: Forwarded to `evaluate_model`.
        epochs: Number of passes over `train_loader`.
        lr: Peak AdamW learning rate; decays linearly to 0 with no warmup.
        model: Model to continue training; "t5-base" is loaded when None.

    Returns:
        The trained model (the same object as `model` when one was passed in).
    """
    if model is None:
        model = T5ForConditionalGeneration.from_pretrained("t5-base")
    # Also covers a caller-supplied model that is not on `device` yet.
    model = model.to(device)

    # Mixed precision only makes sense on CUDA; on CPU both helpers are disabled
    # explicitly and become pass-throughs.
    use_amp = torch.device(device).type == "cuda"
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler(enabled=use_amp)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * epochs)
    for epoch in range(epochs):
        model.train()  # evaluate_model() leaves the model in eval mode
        for batch in train_loader:
            opt.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(input_ids=batch['input_ids'].to(device),
                                attention_mask=batch['attention_mask'].to(device),
                                labels=batch['labels'].to(device))
                loss = outputs.loss
            # Backward runs outside the autocast region, as PyTorch's AMP
            # guidance recommends; only the forward pass and loss belong inside.
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Epoch {epoch+1}] SVAMP Acc: {acc:.4f}")
    return model

# ===========================
# 6. Main Entrypoint
# ===========================
def main():
    """Run the three-stage SVAMP experiment and print the resulting accuracies.

    Stage 1 fine-tunes a fresh model, Stage 2 prunes FFN sub-layers of that same
    model in place, Stage 3 fine-tunes the pruned model. Uses the module-level
    `train_loader`, `dev_loader` and `tokenizer`.
    """
    seed = 42
    # Seed the Python, NumPy and PyTorch generators, in that order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stage 1: baseline without pruning.
    print("\n=== Stage 1: Full Fine-Tuning (No Pruning) ===")
    full_model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6)
    baseline_acc = evaluate_model(full_model, dev_loader, tokenizer, device)
    print(f"\nAccuracy after 6-epoch full fine-tuning: {baseline_acc:.4f}")

    # Stage 2: swap selected FFN modules (mutates `full_model`).
    print("\n=== Stage 2: LayerDrop Pruning ===")
    pruned_enc, pruned_dec = layerdrop_prune_t5(
        full_model, num_prune_enc=2, num_prune_dec=4, seed=seed
    )

    # Stage 3: recover accuracy on the pruned model.
    print("\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===")
    pruned_model = finetune_t5(
        train_loader, dev_loader, device, tokenizer, epochs=5, lr=5e-4, model=full_model
    )
    final_acc = evaluate_model(pruned_model, dev_loader, tokenizer, device)
    print(f"\nAccuracy after 5-epoch post-pruning fine-tuning: {final_acc:.4f}")
    print(f"Encoder FFN pruned indices: {pruned_enc}, Decoder FFN pruned indices: {pruned_dec}")

if __name__ == "__main__":
    main()
