In [None]:
# Replace the runtime's `datasets` with the pinned 2.18.0 release and install
# `evaluate`. NOTE(review): the reason for the pin is not recorded here, and
# `evaluate` is not imported by any cell shown — confirm both are still needed.
!pip uninstall -y datasets
!pip install datasets==2.18.0
!pip install evaluate

In [None]:
# Mount Google Drive: the e-SNLI JSON files used below are read from
# /content/drive/MyDrive/... (requires a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')


In [None]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from collections import defaultdict
import numpy as np
import random
import warnings

# Silence every UserWarning and FutureWarning for the rest of the session.
# NOTE(review): this also hides warnings that may flag real problems
# (deprecations, AMP/precision issues); consider narrowing with module=/message=.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ---- 1. Knowledge Entropy Hook Utilities ----
from functools import partial

def register_ke_hooks_t5(model):
    """Attach forward hooks that record the input of every T5 FFN sub-module.

    A hook is registered on ``DenseReluDense`` of each encoder block
    (sub-layer 1) and each decoder block (sub-layer 2).  Whenever that module
    runs, the hook stores a detached view of its first positional input in a
    per-stack dict keyed by block index.

    NOTE(review): this captures the *input* to ``DenseReluDense`` rather than
    the activations inside it; confirm that is the quantity intended for
    Knowledge Entropy.

    Returns:
        ``((enc_hooks, enc_acts), (dec_hooks, dec_acts))``: lists of hook
        handles and dicts ``block_idx -> tensor | None`` (all ``None`` until
        a forward pass runs).
    """
    def _recorder(store, slot):
        # Bind store/slot now so every hook writes to its own entry.
        def _hook(module, inputs, output):
            store[slot] = inputs[0].detach()
        return _hook

    def _attach(blocks, ff_pos):
        store = {slot: None for slot in range(len(blocks))}
        handles = [
            block.layer[ff_pos].DenseReluDense.register_forward_hook(_recorder(store, slot))
            for slot, block in enumerate(blocks)
        ]
        return handles, store

    encoder_side = _attach(model.encoder.block, 1)
    decoder_side = _attach(model.decoder.block, 2)
    return encoder_side, decoder_side



def compute_ke_batch(acts, act_fn=F.relu, eps=1e-8):
    """Reduce captured per-layer activations to one Knowledge Entropy value each.

    For every slot of ``acts`` holding a tensor, ``act_fn`` (ReLU by default)
    is applied, each vector along the last dimension is normalised to a
    probability distribution (all-zero vectors are divided by 1 instead),
    probabilities are floored at 1e-8, and the Shannon entropy averaged over
    all leading positions is returned.

    Every visited slot is reset to ``None``.  Slots that are empty, have no
    elements, are all zero, or contain non-finite values are skipped, and so is
    any layer whose entropy comes out non-finite, so a single fp16 overflow
    under autocast cannot turn the epoch average into NaN/inf.

    Args:
        acts: dict ``layer_idx -> tensor | None``; mutated in place.
        act_fn: non-negative map applied before normalisation.
        eps: added to the normaliser to guard the division.

    Returns:
        dict ``layer_idx -> float`` for the layers that produced a finite KE.
    """
    ke = {}
    for idx, a in acts.items():
        acts[idx] = None  # reset first so skipped slots are cleared too
        if a is None or a.numel() == 0:
            continue
        if not torch.isfinite(a).all() or a.abs().sum() == 0:
            continue
        act = act_fn(a.float())  # compute in fp32 regardless of capture dtype
        denom = act.sum(dim=-1, keepdim=True)
        # Avoid dividing by zero for vectors that act_fn maps to all zeros.
        denom = torch.where(denom == 0, torch.ones_like(denom), denom)
        probs = torch.clamp(act / (denom + eps), min=1e-8)  # floor avoids log(0)
        entropy = -torch.sum(probs * torch.log(probs), dim=-1).mean()
        if not torch.isfinite(entropy):
            continue
        ke[idx] = entropy.item()
    return ke

def remove_hooks(hook_sets):
    """Unregister every hook handle in ``hook_sets``.

    ``hook_sets`` is an iterable of ``(handles, acts)`` pairs, the shape
    returned by ``register_ke_hooks_t5``; only the handles are used.
    """
    handles = [handle for group, _acts in hook_sets for handle in group]
    for handle in handles:
        handle.remove()

# ---- 2. Pruning Utilities ----
class SkipFFN(nn.Module):
    """Parameter-free stand-in for a pruned T5 ``DenseReluDense`` sub-module.

    In Hugging Face's ``T5LayerFF`` the sub-module output is *added* to the
    residual stream (``h + dropout(ffn(layer_norm(h)))``).  Returning the input
    unchanged would still inject ``layer_norm(h)`` into the residual; returning
    zeros turns the feed-forward sub-layer into an identity, i.e. the FFN is
    really skipped.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size  # kept for introspection only

    def forward(self, hidden_states):
        return torch.zeros_like(hidden_states)


def prune_high_ke_ffn(blocks, ke_scores, num_prune=2, hidden_size=768):
    """Replace the FFN of the ``num_prune`` highest-KE blocks with ``SkipFFN``.

    Args:
        blocks: sequence of T5 blocks (``model.encoder.block`` or
            ``model.decoder.block``); modified in place.
        ke_scores: dict ``layer_idx -> KE``.
        num_prune: number of layers to prune; values <= 0 prune nothing.
        hidden_size: forwarded to ``SkipFFN``.

    Returns:
        list of pruned layer indices, highest KE first.

    Raises:
        AttributeError: if the last sub-layer of a selected block has no
            ``DenseReluDense`` (block layout is not T5-like).
    """
    ranked = sorted(ke_scores.items(), key=lambda kv: kv[1], reverse=True)
    prune_idxs = [idx for idx, _ in ranked[:max(num_prune, 0)]]
    for idx in prune_idxs:
        # The FFN is the *last* sub-layer of a T5 block: layer[1] in the
        # encoder but layer[2] in the decoder, where layer[1] is
        # cross-attention.  Indexing layer[1] silently "pruned" nothing there.
        ff_layer = blocks[idx].layer[-1]
        if not hasattr(ff_layer, "DenseReluDense"):
            raise AttributeError(f"block {idx}: last sub-layer has no DenseReluDense")
        ff_layer.DenseReluDense = SkipFFN(hidden_size)
    return prune_idxs

# ---- 3. Data/Helper functions ----
def make_t5_nli_prompt(premise, hypothesis):
    """Format a premise/hypothesis pair as a T5 text-to-text NLI prompt."""
    return f"nli premise: {premise} hypothesis: {hypothesis}"

def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batched e-SNLI slice into T5 inputs and text labels.

    Integer class ids 0/1/2 become "entailment"/"neutral"/"contradiction".
    Any other value, including negative ids such as -1 that previously indexed
    ``label_list[-1]`` and became "contradiction", passes through as ``str(x)``.
    Padding positions in the labels are set to -100 so the seq2seq loss
    ignores them instead of training the model to predict pad tokens.

    Returns:
        the tokenizer's encoding of the prompts with an added ``"labels"`` entry.
    """
    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    label_list = ["entailment", "neutral", "contradiction"]
    labels = [
        label_list[x] if isinstance(x, int) and 0 <= x < len(label_list) else str(x)
        for x in batch['label']
    ]
    target = tokenizer(labels, padding="max_length", truncation=True, max_length=max_target_length)
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if tok == pad_id else tok for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

def compute_accuracy(preds, refs):
    """Exact-match accuracy of ``preds`` vs ``refs`` (0 when ``preds`` is empty)."""
    correct = sum(1 for p, r in zip(preds, refs) if p == r)
    return correct / len(preds) if len(preds) > 0 else 0

def evaluate_model(model, dl, tokenizer, device, label_texts, max_new_tokens=None):
    """Generate for every batch of ``dl`` and score exact-match accuracy.

    Predictions and decoded references are stripped and lower-cased before
    comparison.

    Args:
        model: seq2seq model exposing ``eval()`` and ``generate()``.
        dl: iterable of dicts with ``input_ids``, ``attention_mask``, ``labels``.
        tokenizer: provides ``pad_token_id`` and ``batch_decode``.
        device: device the inputs are moved to.
        label_texts: unused; kept for call compatibility.
        max_new_tokens: generation budget.  ``None`` (default) uses the
            length of the longest reference in each batch, so multi-token
            labels are never cut short.

    Returns:
        accuracy as a float in [0, 1] (0 for an empty loader).
    """
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # -100 marks loss-ignored positions; decode them as padding.
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = tokenizer.pad_token_id
            if max_new_tokens is None:
                # A fixed budget of 2 truncates any label that needs more tokens
                # (possibly "entailment" under the T5 vocabulary), so it could
                # never match.  Longest reference (incl. EOS) is enough.
                budget = max(1, int((label_ids != tokenizer.pad_token_id).sum(dim=-1).max()))
            else:
                budget = max_new_tokens
            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=budget)
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend(p.strip().lower() for p in pred_texts)
            refs.extend(r.strip().lower() for r in ref_texts)
    return compute_accuracy(preds, refs)

# ---- 4. Training/Fine-tuning Loops ----
def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts,
                    num_epochs=6, lr=3e-4, model_name="t5-base"):
    """Stage 1: fine-tune a pretrained T5 while tracking per-layer Knowledge Entropy.

    Forward hooks (``register_ke_hooks_t5``) record each encoder/decoder FFN
    input; after every optimizer step the captures are reduced with
    ``compute_ke_batch`` and averaged per epoch.  Per-epoch KE and dev
    accuracy are printed.

    Args:
        train_loader: yields dicts with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: scored with ``evaluate_model`` after every epoch.
        device: device the model is moved to.
        tokenizer: used by ``evaluate_model`` to decode generations.
        label_texts: forwarded unchanged to ``evaluate_model``.
        num_epochs: training epochs (default 6, as before).
        lr: AdamW learning rate (default 3e-4, as before).
        model_name: checkpoint for ``from_pretrained`` (default ``"t5-base"``).

    Returns:
        ``(model, last_enc_ke, last_dec_ke)``: the trained model and the final
        epoch's mean KE per layer index (``None`` when ``num_epochs`` is 0).
    """
    print("=== Stage 1: Full Fine-Tuning & Knowledge Entropy Estimation ===")
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler()
    # The linear decay must span every optimizer step of the run.  Sizing it
    # for 3 epochs while looping 6 drove the LR to 0 halfway, so the later
    # epochs left the weights unchanged.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    (enc_hooks, enc_acts), (dec_hooks, dec_acts) = register_ke_hooks_t5(model)
    last_enc_ke, last_dec_ke = None, None

    for epoch in range(num_epochs):
        enc_ke_sum, enc_ke_count = defaultdict(float), defaultdict(int)
        dec_ke_sum, dec_ke_count = defaultdict(float), defaultdict(int)
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            with autocast():
                outputs = model(input_ids=batch['input_ids'].to(device),
                                attention_mask=batch['attention_mask'].to(device),
                                labels=batch['labels'].to(device))
                loss = outputs.loss
            # Only the forward pass runs under autocast; backward runs outside
            # it, as the PyTorch AMP documentation recommends.
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
            # KE estimation from this step's captured FFN inputs.
            for idx, v in compute_ke_batch(enc_acts).items():
                enc_ke_sum[idx] += v
                enc_ke_count[idx] += 1
            for idx, v in compute_ke_batch(dec_acts).items():
                dec_ke_sum[idx] += v
                dec_ke_count[idx] += 1

        epoch_enc_ke = {idx: enc_ke_sum[idx] / enc_ke_count[idx] for idx in enc_ke_sum if enc_ke_count[idx] > 0}
        epoch_dec_ke = {idx: dec_ke_sum[idx] / dec_ke_count[idx] for idx in dec_ke_sum if dec_ke_count[idx] > 0}
        print(f"[Epoch {epoch+1}] approx Encoder KE: {epoch_enc_ke}")
        print(f"[Epoch {epoch+1}] approx Decoder KE: {epoch_dec_ke}")
        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
        print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
        last_enc_ke, last_dec_ke = epoch_enc_ke, epoch_dec_ke

    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts)])
    return model, last_enc_ke, last_dec_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer, label_texts,
                         num_prune=4, num_epochs=5, lr=5e-4, prune_encoder=False):
    """Stage 2: replace the highest-KE FFNs with ``SkipFFN`` and fine-tune again.

    Args:
        model: model returned by Stage 1; modified in place.
        train_loader, dev_loader, device, tokenizer, label_texts: as in Stage 1.
        enc_ke, dec_ke: per-layer KE dicts from Stage 1 (``enc_ke`` is only
            used when ``prune_encoder`` is True).
        num_prune: FFN layers to prune per stack (default 4, as before).
        num_epochs: fine-tuning epochs (default 5, as before).
        lr: AdamW learning rate (default 5e-4, as before).
        prune_encoder: also prune encoder FFNs.  Default False keeps the
            previous decoder-only behaviour.

    Returns:
        the pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High-KE) & Fine-tuning ===")
    d_model = model.config.d_model
    if prune_encoder:
        enc_prune_idxs = prune_high_ke_ffn(model.encoder.block, enc_ke, num_prune=num_prune, hidden_size=d_model)
        print("Pruned encoder layers (highest KE):", enc_prune_idxs)
    dec_prune_idxs = prune_high_ke_ffn(model.decoder.block, dec_ke, num_prune=num_prune, hidden_size=d_model)
    print("Pruned decoder layers (highest KE):", dec_prune_idxs)

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    # Decay across the whole run; sizing it for 2 epochs while looping 5 left
    # epochs 3-5 training at LR 0.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch['labels'].to(device))
            outputs.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)
        print(f"[Prune FT Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}")
    return model

# ---- 5. Entrypoint ----
def main():
    """Run the e-SNLI experiment: seed RNGs, build loaders, then Stage 1 and Stage 2."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    data_files = {
        "train": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json",
        "validation": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json",
        "test": "/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json"
    }
    raw_datasets = load_dataset("json", data_files=data_files)
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")
    label_texts = ["entailment", "neutral", "contradiction"]

    def subsample(split, size):
        # Deterministic random subset of one split.
        return raw_datasets[split].shuffle(seed=seed).select(range(size))

    def encode(ds):
        return ds.map(lambda ex: preprocess_function(ex, tokenizer),
                      batched=True, remove_columns=ds.column_names)

    train_ds, dev_ds = subsample("train", 10000), subsample("validation", 2000)
    train, dev = encode(train_ds), encode(dev_ds)

    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
    train_loader, dev_loader = (
        DataLoader(ds, batch_size=16, shuffle=flag, collate_fn=collator)
        for ds, flag in ((train, True), (dev, False))
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, enc_ke, dec_ke = full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer, label_texts)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Recorded per-epoch Knowledge Entropy of the e-SNLI Stage 1 run, presumably
# copied from the "[Epoch k] approx Encoder/Decoder KE" printouts.  One dict per
# epoch (6 epochs); keys are 0-based block indices 0-11.
# NOTE(review): the last epochs barely change; this is consistent with the
# learning rate having reached 0 during recording (a linear schedule sized
# for 3 epochs while 6 were run) — confirm before interpreting late epochs.
enc_ke_epochs = [
    {0: 5.120778652954102, 1: 5.403757721710205, 2: 5.443886557769775, 3: 5.44828851776123, 4: 5.453934421539307, 5: 5.452359628295898, 6: 5.457024213409424, 7: 5.43810365524292, 8: 5.402327305603027, 9: 5.370745602416992, 10: 5.31574527053833, 11: 5.390125755310058},
    {0: 5.1113538009643555, 1: 5.392516478729248, 2: 5.438380133056641, 3: 5.431306089782715, 4: 5.4405796981811525, 5: 5.4482001785278324, 6: 5.4542674026489255, 7: 5.456199799346924, 8: 5.428633112335205, 9: 5.4060304084777835, 10: 5.376359208679199, 11: 5.4560540000915525},
    {0: 5.1189261985778804, 1: 5.415700047302246, 2: 5.461237496948242, 3: 5.445635594177246, 4: 5.448134069061279, 5: 5.454518992614746, 6: 5.457215517425537, 7: 5.458963891601562, 8: 5.441248722839355, 9: 5.416132154083252, 10: 5.3972449882507325, 11: 5.488359162902832},
    {0: 5.116845191955567, 1: 5.412102639007569, 2: 5.457377126312256, 3: 5.440801442718506, 4: 5.446412419891358, 5: 5.4529715980529785, 6: 5.4542711631774905, 7: 5.458785958862305, 8: 5.443267332458496, 9: 5.41692006149292, 10: 5.398646075439453, 11: 5.48984780960083},
    {0: 5.116689331817627, 1: 5.412071422576904, 2: 5.457381925201416, 3: 5.440740663146973, 4: 5.446404431915283, 5: 5.453024647521973, 6: 5.454191885375977, 7: 5.45875881652832, 8: 5.443202964019775, 9: 5.416829935455322, 10: 5.398244989776611, 11: 5.489683660888672},
    {0: 5.116260417938232, 1: 5.411794747924804, 2: 5.457123635864257, 3: 5.440602420806885, 4: 5.44625763092041, 5: 5.452800505065918, 6: 5.454110127258301, 7: 5.458805889892578, 8: 5.443212223052979, 9: 5.4169437850952145, 10: 5.3985387763977055, 11: 5.489740788269043}
]
dec_ke_epochs = [
    {0: 5.473455963897705, 1: 5.18053843383789, 2: 5.033648477172852, 3: 4.812303190612793, 4: 5.261005407714844, 5: 5.501210649108887, 6: 5.472148556518555, 7: 5.517318391418457, 8: 5.553192390441895, 9: 5.512592239379883, 10: 5.584075408172607, 11: 5.50434167175293},
    {0: 5.472483442687988, 1: 5.169460816192627, 2: 5.043915207672119, 3: 4.82118962020874, 4: 5.270837689971924, 5: 5.517090723419189, 6: 5.498110526275635, 7: 5.5361996917724605, 8: 5.564452503204346, 9: 5.528676497650147, 10: 5.596128713989258, 11: 5.526001943969726},
    {0: 5.469658322906494, 1: 5.168160199737549, 2: 5.06294663772583, 3: 4.844052698516846, 4: 5.287757330322266, 5: 5.529096306610107, 6: 5.515416881561279, 7: 5.5526224166870115, 8: 5.578667069244385, 9: 5.550772470855713, 10: 5.610965402221679, 11: 5.548971197509766},
    {0: 5.468696701049804, 1: 5.168712305450439, 2: 5.070296308898926, 3: 4.852590578460694, 4: 5.291568259429932, 5: 5.531682807922364, 6: 5.5185763488769535, 7: 5.554903052520752, 8: 5.579686510467529, 9: 5.552240438842773, 10: 5.613069982147217, 11: 5.553303944396973},
    {0: 5.468258437347412, 1: 5.169210543060303, 2: 5.070472750091553, 3: 4.852741773223877, 4: 5.2915089881896975, 5: 5.531434827423095, 6: 5.518463079833984, 7: 5.554849634552002, 8: 5.5797720138549805, 9: 5.552255963897705, 10: 5.61289744644165, 11: 5.553331785583496},
    {0: 5.468696504974365, 1: 5.168731469726563, 2: 5.070022554016114, 3: 4.852040791320801, 4: 5.291398749542236, 5: 5.531332643890381, 6: 5.51825998916626, 7: 5.554569575500488, 8: 5.579519750976562, 9: 5.552110796356201, 10: 5.613056344604492, 11: 5.553367820739746}
]

def _plot_ke_curves(layer_axis, epoch_labels, rows, ylabel):
    """Draw one KE-vs-layer line per epoch in a new figure and show it."""
    plt.figure(figsize=(10, 6))
    for epoch_label, values in zip(epoch_labels, rows):
        plt.plot(layer_axis, values, marker='o', label=epoch_label)
    plt.xlabel("Layer", fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    plt.xticks(layer_axis, fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(True)
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.show()


# Derive the axes from the recorded data instead of hard-coding 12 layers / 6 epochs.
layers = list(range(1, len(enc_ke_epochs[0]) + 1))  # 1-based layers for x-axis
epochs = [f"Epoch {i+1}" for i in range(len(enc_ke_epochs))]

# Convert dicts (0-based layer keys) to per-epoch lists ordered by layer.
enc_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in enc_ke_epochs]
dec_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in dec_ke_epochs]

_plot_ke_curves(layers, epochs, enc_ke_list, "Knowledge Entropy")  # encoder
_plot_ke_curves(layers, epochs, dec_ke_list, "Knowledge Entropy")  # decoder




In [None]:
# Mount Google Drive if on Colab: the CQA JSON files below are read from
# /content/drive/MyDrive/... (requires a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler
from collections import defaultdict
import numpy as np
import random
import warnings

# Silence every UserWarning and FutureWarning for the rest of the session.
# NOTE(review): this also hides warnings that may flag real problems; consider
# narrowing the filters with module=/message=.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- 1. Load CQA Data ---
# Split name -> JSON file; the resulting DatasetDict is indexed by these names.
data_files = dict(
    train="/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json",
    test="/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json",
)
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Functions ---
def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):
    """Tokenize a batched CQA slice into T5 inputs and answer labels.

    The prompt lists the question and its choices.  When ``use_cot`` is True
    and the batch has an ``abstractive_explanation`` column, the rationale is
    appended to the prompt.  Answers are stringified and stripped.  Padding
    positions in the labels are set to -100 so the seq2seq loss ignores them.

    Returns:
        the tokenizer's encoding of the prompts with an added ``"labels"`` entry.
    """
    if use_cot and 'abstractive_explanation' in batch:
        inputs = [
            f"question: {q} choices: {', '.join(choices)} rationale: {exp}"
            for q, choices, exp in zip(batch['question'], batch['choices'], batch['abstractive_explanation'])
        ]
    else:
        inputs = [
            f"question: {q} choices: {', '.join(choices)}"
            for q, choices in zip(batch['question'], batch['choices'])
        ]
    targets = [str(ans).strip() for ans in batch['answer']]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(targets, padding="max_length", truncation=True, max_length=max_target_length)
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if tok == pad_id else tok for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

tokenizer = T5TokenizerFast.from_pretrained("t5-base")
USE_COT = False

# Rationales (when USE_COT is on) are only fed during training; the held-out
# split is always encoded without them.
train = dataset["train"].map(
    lambda ex: preprocess_cqa(ex, tokenizer, use_cot=USE_COT),
    batched=True,
    remove_columns=dataset["train"].column_names,
)
dev = dataset["test"].map(
    lambda ex: preprocess_cqa(ex, tokenizer, use_cot=False),
    batched=True,
    remove_columns=dataset["test"].column_names,
)

collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

# --- 3. Knowledge Entropy Hook Utilities ---
from functools import partial

def register_ke_hooks_t5(model):
    """Attach forward hooks that record the input of every T5 FFN sub-module.

    A hook is registered on ``DenseReluDense`` of each encoder block
    (sub-layer 1) and each decoder block (sub-layer 2).  Whenever that module
    runs, the hook stores a detached view of its first positional input in a
    per-stack dict keyed by block index.

    NOTE(review): this captures the *input* to ``DenseReluDense`` rather than
    the activations inside it; confirm that is the quantity intended for
    Knowledge Entropy.

    Returns:
        ``((enc_hooks, enc_acts), (dec_hooks, dec_acts))``: lists of hook
        handles and dicts ``block_idx -> tensor | None``.
    """
    def _recorder(store, slot):
        # Bind store/slot now so every hook writes to its own entry.
        def _hook(module, inputs, output):
            store[slot] = inputs[0].detach()
        return _hook

    def _attach(blocks, ff_pos):
        store = {slot: None for slot in range(len(blocks))}
        handles = [
            block.layer[ff_pos].DenseReluDense.register_forward_hook(_recorder(store, slot))
            for slot, block in enumerate(blocks)
        ]
        return handles, store

    encoder_side = _attach(model.encoder.block, 1)
    decoder_side = _attach(model.decoder.block, 2)
    return encoder_side, decoder_side


def compute_ke_batch(acts, act_fn=F.relu, eps=1e-8):
    """Turn captured per-layer activations into Knowledge Entropy scores.

    Each slot of ``acts`` is cleared to ``None`` as it is visited.  A slot is
    skipped when it is empty, contains non-finite values, has no elements, or
    is all zero.  Otherwise ``act_fn`` is applied, each vector along the last
    dimension is normalised (all-zero vectors are divided by 1 instead),
    probabilities are floored at 1e-8, and the Shannon entropy averaged over
    all leading positions is reported.  Non-finite results are dropped.

    Returns:
        dict ``layer_idx -> float`` for the layers that produced a score.
    """
    scores = {}
    for layer_idx in list(acts):
        captured = acts[layer_idx]
        acts[layer_idx] = None  # clear unconditionally
        if captured is None or not torch.isfinite(captured).all():
            continue
        if captured.numel() == 0 or captured.abs().sum() == 0:
            continue
        weights = act_fn(captured)
        totals = weights.sum(dim=-1, keepdim=True)
        totals = torch.where(totals == 0, torch.ones_like(totals), totals)
        dist = torch.clamp(weights / (totals + eps), min=1e-8)
        if not torch.isfinite(dist).all():
            continue
        value = -torch.sum(dist * torch.log(dist), dim=-1).mean()
        if torch.isfinite(value):
            scores[layer_idx] = value.item()
    return scores



def remove_hooks(hook_sets):
    """Unregister every hook handle in ``hook_sets``.

    ``hook_sets`` is an iterable of ``(handles, acts)`` pairs, the shape
    returned by ``register_ke_hooks_t5``; only the handles are used.
    """
    handles = [handle for group, _acts in hook_sets for handle in group]
    for handle in handles:
        handle.remove()

# --- 4. Pruning Utilities ---
class SkipFFN(nn.Module):
    """Parameter-free stand-in for a pruned T5 ``DenseReluDense`` sub-module.

    In Hugging Face's ``T5LayerFF`` the sub-module output is *added* to the
    residual stream (``h + dropout(ffn(layer_norm(h)))``).  Returning the input
    unchanged would still inject ``layer_norm(h)`` into the residual; returning
    zeros turns the feed-forward sub-layer into an identity, i.e. the FFN is
    really skipped.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size  # kept for introspection only

    def forward(self, hidden_states):
        return torch.zeros_like(hidden_states)


def prune_high_ke_ffn(blocks, ke_scores, num_prune=4, hidden_size=768):
    """Replace the FFN of the ``num_prune`` highest-KE blocks with ``SkipFFN``.

    Args:
        blocks: sequence of T5 blocks (``model.encoder.block`` or
            ``model.decoder.block``); modified in place.
        ke_scores: dict ``layer_idx -> KE``.
        num_prune: number of layers to prune; values <= 0 prune nothing.
        hidden_size: forwarded to ``SkipFFN``.

    Returns:
        list of pruned layer indices, highest KE first.

    Raises:
        AttributeError: if the last sub-layer of a selected block has no
            ``DenseReluDense`` (block layout is not T5-like).
    """
    ranked = sorted(ke_scores.items(), key=lambda kv: kv[1], reverse=True)
    prune_idxs = [idx for idx, _ in ranked[:max(num_prune, 0)]]
    for idx in prune_idxs:
        # The FFN is the *last* sub-layer of a T5 block: layer[1] in the
        # encoder but layer[2] in the decoder, where layer[1] is
        # cross-attention.  Indexing layer[1] silently "pruned" nothing there.
        ff_layer = blocks[idx].layer[-1]
        if not hasattr(ff_layer, "DenseReluDense"):
            raise AttributeError(f"block {idx}: last sub-layer has no DenseReluDense")
        ff_layer.DenseReluDense = SkipFFN(hidden_size)
    return prune_idxs

# --- 5. Training/Eval/KE Pipeline ---
def compute_accuracy(preds, refs):
    """Exact-match accuracy of ``preds`` vs ``refs`` (0 when ``preds`` is empty)."""
    correct = sum(1 for p, r in zip(preds, refs) if p == r)
    return correct / len(preds) if len(preds) > 0 else 0

def evaluate_model(model, dl, tokenizer, device, max_new_tokens=None):
    """Generate for every batch of ``dl`` and score exact-match accuracy.

    Predictions and decoded references are stripped and lower-cased before
    comparison.

    Args:
        model: seq2seq model exposing ``eval()`` and ``generate()``.
        dl: iterable of dicts with ``input_ids``, ``attention_mask``, ``labels``.
        tokenizer: provides ``pad_token_id`` and ``batch_decode``.
        device: device the inputs are moved to.
        max_new_tokens: generation budget.  ``None`` (default) uses the
            length of the longest reference in each batch, so multi-word
            answers are never cut short.

    Returns:
        accuracy as a float in [0, 1] (0 for an empty loader).
    """
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # -100 marks loss-ignored positions; decode them as padding.
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = tokenizer.pad_token_id
            if max_new_tokens is None:
                # A fixed budget of 4 truncates longer answers so they could
                # never match; the longest reference (incl. EOS) is enough.
                budget = max(1, int((label_ids != tokenizer.pad_token_id).sum(dim=-1).max()))
            else:
                budget = max_new_tokens
            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=budget)
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            preds.extend(p.strip().lower() for p in pred_texts)
            refs.extend(r.strip().lower() for r in ref_texts)
    return compute_accuracy(preds, refs)

def full_finetuning(train_loader, dev_loader, device, tokenizer,
                    num_epochs=6, lr=3e-4, model_name="t5-base"):
    """Stage 1: fine-tune a pretrained T5 while tracking per-layer Knowledge Entropy.

    Forward hooks (``register_ke_hooks_t5``) record each encoder/decoder FFN
    input; after every optimizer step the captures are reduced with
    ``compute_ke_batch`` and averaged per epoch.  Per-epoch KE and dev
    accuracy are printed.

    Args:
        train_loader: yields dicts with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: scored with ``evaluate_model`` after every epoch.
        device: device the model is moved to.
        tokenizer: used by ``evaluate_model`` to decode generations.
        num_epochs: training epochs (default 6, as before).
        lr: AdamW learning rate (default 3e-4, as before).
        model_name: checkpoint for ``from_pretrained`` (default ``"t5-base"``).

    Returns:
        ``(model, last_enc_ke, last_dec_ke)``: the trained model and the final
        epoch's mean KE per layer index (``None`` when ``num_epochs`` is 0).
    """
    print("=== Stage 1: Full Fine-Tuning & Knowledge Entropy Estimation ===")
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler()
    # The linear decay must span every optimizer step of the run.  Sizing it
    # for 3 epochs while looping 6 drove the LR to 0 halfway, so the later
    # epochs left the weights unchanged.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    (enc_hooks, enc_acts), (dec_hooks, dec_acts) = register_ke_hooks_t5(model)
    last_enc_ke, last_dec_ke = None, None

    for epoch in range(num_epochs):
        enc_ke_sum, enc_ke_count = defaultdict(float), defaultdict(int)
        dec_ke_sum, dec_ke_count = defaultdict(float), defaultdict(int)
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            with autocast():
                outputs = model(input_ids=batch['input_ids'].to(device),
                                attention_mask=batch['attention_mask'].to(device),
                                labels=batch['labels'].to(device))
                loss = outputs.loss
            # Only the forward pass runs under autocast; backward runs outside
            # it, as the PyTorch AMP documentation recommends.
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
            # KE estimation from this step's captured FFN inputs.
            for idx, v in compute_ke_batch(enc_acts).items():
                enc_ke_sum[idx] += v
                enc_ke_count[idx] += 1
            for idx, v in compute_ke_batch(dec_acts).items():
                dec_ke_sum[idx] += v
                dec_ke_count[idx] += 1

        epoch_enc_ke = {idx: enc_ke_sum[idx] / enc_ke_count[idx] for idx in enc_ke_sum if enc_ke_count[idx] > 0}
        epoch_dec_ke = {idx: dec_ke_sum[idx] / dec_ke_count[idx] for idx in dec_ke_sum if dec_ke_count[idx] > 0}
        print(f"[Epoch {epoch+1}] approx Encoder KE: {epoch_enc_ke}")
        print(f"[Epoch {epoch+1}] approx Decoder KE: {epoch_dec_ke}")
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
        last_enc_ke, last_dec_ke = epoch_enc_ke, epoch_dec_ke

    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts)])
    return model, last_enc_ke, last_dec_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer,
                         num_prune=4, num_epochs=5, lr=5e-4, prune_encoder=False):
    """Stage 2: replace the highest-KE FFNs with ``SkipFFN`` and fine-tune again.

    Args:
        model: model returned by Stage 1; modified in place.
        train_loader, dev_loader, device, tokenizer: as in Stage 1.
        enc_ke, dec_ke: per-layer KE dicts from Stage 1 (``enc_ke`` is only
            used when ``prune_encoder`` is True).
        num_prune: FFN layers to prune per stack (default 4, as before).
        num_epochs: fine-tuning epochs (default 5, as before).
        lr: AdamW learning rate (default 5e-4, as before).
        prune_encoder: also prune encoder FFNs.  Default False keeps the
            previous decoder-only behaviour.

    Returns:
        the pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High-KE) & Fine-tuning ===")
    d_model = model.config.d_model
    if prune_encoder:
        enc_prune_idxs = prune_high_ke_ffn(model.encoder.block, enc_ke, num_prune=num_prune, hidden_size=d_model)
        print("Pruned encoder layers (highest KE):", enc_prune_idxs)
    dec_prune_idxs = prune_high_ke_ffn(model.decoder.block, dec_ke, num_prune=num_prune, hidden_size=d_model)
    print("Pruned decoder layers (highest KE):", dec_prune_idxs)

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    # Decay across the whole run; sizing it for 2 epochs while looping 5 left
    # epochs 3-5 training at LR 0.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch['labels'].to(device))
            outputs.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] CQA Acc: {acc:.4f}")
    return model

# --- 6. Entrypoint ---
def main(seed=42):
    """Run Stage 1 and Stage 2 on the module-level CQA ``train_loader``/``dev_loader``.

    Args:
        seed: seeds ``random``, NumPy and torch before training so DataLoader
            shuffling and dropout are reproducible.  This cell imports
            ``random``/``numpy`` but previously never seeded them.  Pass
            ``None`` to skip seeding.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, enc_ke, dec_ke = full_finetuning(train_loader, dev_loader, device, tokenizer)
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device,
        enc_ke, dec_ke, tokenizer
    )


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# --- Data ---
# Recorded per-epoch Knowledge Entropy of the CQA Stage 1 run, presumably
# copied from the "[Epoch k] approx Encoder/Decoder KE" printouts.  One dict per
# epoch (6 epochs); keys are 0-based block indices 0-11.
# NOTE(review): the last epochs barely change, consistent with the learning
# rate having reached 0 during recording — confirm before interpreting them.
enc_ke_epochs = [
    {0: 5.139904764485476, 1: 5.401092627952838, 2: 5.439678460702129, 3: 5.432176997117417, 4: 5.442642464817842, 5: 5.4454231058631235, 6: 5.463585686018119, 7: 5.456152987989103, 8: 5.42602445652528, 9: 5.3984725964676175, 10: 5.365565290591987, 11: 5.397532289749698},
    {0: 5.139600602668299, 1: 5.4082865292215585, 2: 5.435035413513434, 3: 5.431703077357, 4: 5.4357455140851405, 5: 5.447143929736759, 6: 5.4573999727496565, 7: 5.448999722798665, 8: 5.4132760177887915, 9: 5.375146550302239, 10: 5.331545272288456, 11: 5.359056867131291},
    {0: 5.132425843788485, 1: 5.400260408523635, 2: 5.425296262763013, 3: 5.425646370071887, 4: 5.4349209725954655, 5: 5.449994925794931, 6: 5.459187375305126, 7: 5.451061400678162, 8: 5.413698513519588, 9: 5.36510266693942, 10: 5.306283406631896, 11: 5.328084006954734},
    {0: 5.132674681533538, 1: 5.40307938993858, 2: 5.426535778640722, 3: 5.425935121984121, 4: 5.435641975434152, 5: 5.449571485785624, 6: 5.459163357275852, 7: 5.450666630405119, 8: 5.419207827406759, 9: 5.381928584062798, 10: 5.340362237983541, 11: 5.369139033193854},
    {0: 5.13272574069269, 1: 5.403113040235047, 2: 5.426666823513989, 3: 5.425940881613244, 4: 5.435654991757498, 5: 5.449515212737085, 6: 5.459123445457621, 7: 5.450610509255445, 8: 5.419158941027762, 9: 5.3816367816455255, 10: 5.33981858255045, 11: 5.36847594375485},
    {0: 5.132623068021828, 1: 5.402830285196038, 2: 5.426464314922715, 3: 5.425696338534551, 4: 5.435371085732245, 5: 5.4492548917510435, 6: 5.4588225204956355, 7: 5.450303058123158, 8: 5.419105227553394, 9: 5.3820987663832796, 10: 5.341227367789483, 11: 5.37036663649098}
]
dec_ke_epochs = [
    {0: 5.4599811254363315, 1: 5.3001703635642405, 2: 5.1401081963589315, 3: 5.0207898828544115, 4: 5.296078369021416, 5: 5.529929007354536, 6: 5.507063817821051, 7: 5.562702250323798, 8: 5.5851923988053676, 9: 5.551125566426077, 10: 5.5910814561341935, 11: 5.5818385175968475},
    {0: 5.456549143280582, 1: 5.250347980361008, 2: 5.085265902750182, 3: 4.978233239992447, 4: 5.264991768895107, 5: 5.5131214478066966, 6: 5.495728558724166, 7: 5.558811952097608, 8: 5.581217587485148, 9: 5.551561930623439, 10: 5.592714675765061, 11: 5.592012528729007},
    {0: 5.453973231929363, 1: 5.239954027012236, 2: 5.083703805117717, 3: 4.982208940455623, 4: 5.268532490966344, 5: 5.5198255513760905, 6: 5.506296765292832, 7: 5.569196505121665, 8: 5.591176272225459, 9: 5.5623457416056015, 10: 5.598043118373002, 11: 5.60101840440983},
    {0: 5.452516633301533, 1: 5.234018758795727, 2: 5.084256347568556, 3: 4.9811194565495835, 4: 5.2666547043961645, 5: 5.518136622674751, 6: 5.504595130143691, 7: 5.568047467906682, 8: 5.590607944576219, 9: 5.562076567037548, 10: 5.598141102563767, 11: 5.602112366060905},
    {0: 5.452368998566676, 1: 5.233688806273863, 2: 5.083870511141121, 3: 4.980399314993121, 4: 5.265982374181888, 5: 5.517764004580493, 6: 5.503982042052671, 7: 5.567702337243091, 8: 5.589939683528956, 9: 5.561352815925586, 10: 5.59754676850167, 11: 5.601407691567206},
    {0: 5.4520845942149885, 1: 5.233260017357125, 2: 5.082963084542988, 3: 4.979661354165993, 4: 5.265588265381112, 5: 5.517923472732898, 6: 5.504299057240518, 7: 5.567891064858594, 8: 5.590223931318877, 9: 5.561733258481057, 10: 5.597703276091064, 11: 5.601514090765391}
]

def _plot_ke_curves(layer_axis, epoch_labels, rows, ylabel):
    """Draw one KE-vs-layer line per epoch in a new figure and show it."""
    plt.figure(figsize=(10, 6))
    for epoch_label, values in zip(epoch_labels, rows):
        plt.plot(layer_axis, values, marker='o', label=epoch_label)
    plt.xlabel("Layer", fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    plt.xticks(layer_axis, fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(True)
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.show()


# Derive the axes from the recorded data instead of hard-coding 12 layers / 6 epochs.
layers = list(range(1, len(enc_ke_epochs[0]) + 1))  # 1-based layers for x-axis
epochs = [f"Epoch {i+1}" for i in range(len(enc_ke_epochs))]

# Convert dicts (0-based layer keys) to per-epoch lists ordered by layer.
enc_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in enc_ke_epochs]
dec_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in dec_ke_epochs]

_plot_ke_curves(layers, epochs, enc_ke_list, "Knowledge Entropy")  # encoder
_plot_ke_curves(layers, epochs, dec_ke_list, "Decoder Knowledge Entropy")  # decoder


In [None]:
# --- Mount Google Drive if using Colab ---
# The ANLI1 JSON files below are read from /content/drive/MyDrive/...
# (requires a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

# --- Standard Imports ---
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler
from collections import defaultdict
import warnings

# Silence every UserWarning and FutureWarning for the rest of the session.
# NOTE(review): this also hides warnings that may flag real problems; consider
# narrowing the filters with module=/message=.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- 1. Load ANLI1 Dataset ---
# Split name -> JSON file; the resulting DatasetDict is indexed by these names.
data_files = dict(
    train="/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json",
    validation="/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json",
    test="/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json",
)
dataset = load_dataset("json", data_files=data_files)

# --- 2. Preprocessing Function ---
def make_t5_nli_prompt(premise, hypothesis):
    """Format a premise/hypothesis pair as a T5 text-to-text NLI prompt."""
    return f"nli premise: {premise} hypothesis: {hypothesis}"

def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batched ANLI slice into T5 inputs and text labels.

    Class ids 0/1/2 become "entailment"/"neutral"/"contradiction" whether they
    arrive as ints, integral floats (e.g. 1.0, which previously became the
    text "1.0"), or decimal-digit strings.  Any other value passes through as
    ``str(x)``.  Padding positions in the labels are set to -100 so the
    seq2seq loss ignores them.

    Returns:
        the tokenizer's encoding of the prompts with an added ``"labels"`` entry.
    """
    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
    label_list = ["entailment", "neutral", "contradiction"]

    def to_label_text(x):
        if isinstance(x, float) and x.is_integer():
            x = int(x)
        if isinstance(x, int) and not isinstance(x, bool):
            return label_list[x] if 0 <= x < len(label_list) else str(x)
        if isinstance(x, str) and x.isdecimal() and int(x) < len(label_list):
            return label_list[int(x)]
        return str(x)

    labels = [to_label_text(x) for x in batch['label']]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    target = tokenizer(labels, padding="max_length", truncation=True, max_length=max_target_length)
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if tok == pad_id else tok for tok in seq] for seq in target["input_ids"]
    ]
    return model_inputs

tokenizer = T5TokenizerFast.from_pretrained("t5-base")

train = dataset["train"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True,
    remove_columns=dataset["train"].column_names,
)
dev = dataset["validation"].map(
    lambda ex: preprocess_anli(ex, tokenizer),
    batched=True,
    remove_columns=dataset["validation"].column_names,
)

collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding="max_length", max_length=128)
train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

# --- 3. Knowledge Entropy Hook Utilities ---
from functools import partial

def register_ke_hooks_t5(model):
    """Attach forward hooks that record the input of every T5 FFN sub-module.

    A hook is registered on ``DenseReluDense`` of each encoder block
    (sub-layer 1) and each decoder block (sub-layer 2).  Whenever that module
    runs, the hook stores a detached view of its first positional input in a
    per-stack dict keyed by block index.

    NOTE(review): this captures the *input* to ``DenseReluDense`` rather than
    the activations inside it; confirm that is the quantity intended for
    Knowledge Entropy.

    Returns:
        ``((enc_hooks, enc_acts), (dec_hooks, dec_acts))``: lists of hook
        handles and dicts ``block_idx -> tensor | None``.
    """
    def _recorder(store, slot):
        # Bind store/slot now so every hook writes to its own entry.
        def _hook(module, inputs, output):
            store[slot] = inputs[0].detach()
        return _hook

    def _attach(blocks, ff_pos):
        store = {slot: None for slot in range(len(blocks))}
        handles = [
            block.layer[ff_pos].DenseReluDense.register_forward_hook(_recorder(store, slot))
            for slot, block in enumerate(blocks)
        ]
        return handles, store

    encoder_side = _attach(model.encoder.block, 1)
    decoder_side = _attach(model.decoder.block, 2)
    return encoder_side, decoder_side




def compute_ke_batch(acts, act_fn=F.relu, eps=1e-8):
    """Compute per-layer knowledge entropy (KE) for one batch and clear the buffers.

    Each captured FFN input ``a`` is passed through ``act_fn``. The result is
    normalised over its last dimension into a distribution, and the Shannon
    entropy of that distribution is averaged over all remaining positions.

    Args:
        acts: dict ``{layer_idx: tensor | None}`` filled by forward hooks. Every
            entry is reset to ``None`` (consumed), even when it is skipped.
        act_fn: activation applied before normalising (ReLU by default).
        eps: small constant guarding the division and the ``log``.

    Returns:
        dict ``{layer_idx: float}`` holding only the layers with a finite,
        non-degenerate entropy.
    """
    ke = {}
    for idx, a in acts.items():
        acts[idx] = None  # Always reset, even if skipping
        if a is None or a.numel() == 0:
            continue
        # Work in float32. In fp16 (e.g. a capture taken under autocast) eps=1e-8
        # underflows to 0, so zero probabilities would hit log(0) and the
        # entropy would silently become NaN and be dropped.
        a = a.float()
        if not torch.isfinite(a).all() or a.abs().sum() == 0:
            continue
        act = act_fn(a)
        denom = act.sum(dim=-1, keepdim=True)
        # Rows that are entirely zero after the activation: avoid 0/0.
        denom = torch.where(denom == 0, torch.ones_like(denom), denom)
        probs = torch.clamp(act / (denom + eps), min=eps)
        if not torch.isfinite(probs).all():
            continue
        entropy = -torch.sum(probs * torch.log(probs), dim=-1).mean()
        if not torch.isfinite(entropy):
            continue
        ke[idx] = entropy.item()
    return ke



def remove_hooks(hook_sets):
    """Detach every hook handle contained in ``hook_sets``.

    ``hook_sets`` is an iterable of ``(handles, activations)`` pairs, as returned
    by ``register_ke_hooks_t5``. Only the handles are used.
    """
    handles = (handle for group, _unused in hook_sets for handle in group)
    for handle in handles:
        handle.remove()

# --- 4. Pruning Utilities ---
class SkipFFN(nn.Module):
    """Drop-in replacement for a T5 ``DenseReluDense`` that disables the FFN.

    In Hugging Face's ``T5LayerFF`` the feed-forward sub-layer computes
    ``h + dropout(DenseReluDense(layer_norm(h)))``. Returning zeros makes that
    sub-layer an exact identity on ``h``, i.e. the FFN is pruned. Returning the
    (layer-normed) input instead would add ``layer_norm(h)`` to the residual stream.

    Args:
        hidden_size: model width; stored for reference only.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, hidden_states):
        # Zero contribution to the residual branch => FFN effectively removed.
        return torch.zeros_like(hidden_states)

def prune_high_ke_ffn(blocks, ke_scores, num_prune=4, hidden_size=768):
    """Replace the FFN of the ``num_prune`` highest-KE T5 blocks with ``SkipFFN``.

    Args:
        blocks: ``model.encoder.block`` or ``model.decoder.block`` of a T5 model.
        ke_scores: ``{block_idx: knowledge_entropy}``; blocks missing from it are
            never pruned. ``None`` or empty means nothing is pruned.
        num_prune: how many blocks to prune (fewer if ``ke_scores`` is smaller).
        hidden_size: forwarded to ``SkipFFN``.

    Returns:
        The pruned block indices, highest KE first.
    """
    if not ke_scores:
        return []
    ranked = sorted(ke_scores.items(), key=lambda kv: kv[1], reverse=True)
    prune_idxs = [idx for idx, _ in ranked[:num_prune]]
    for idx in prune_idxs:
        # The feed-forward sub-layer is the LAST entry of ``block.layer``:
        # index 1 in encoder blocks, index 2 in decoder blocks (after self- and
        # cross-attention; cf. the hook positions in register_ke_hooks_t5).
        # A fixed ``layer[1]`` would only attach an unused attribute to the
        # decoder's cross-attention and prune nothing.
        blocks[idx].layer[-1].DenseReluDense = SkipFFN(hidden_size)
    return prune_idxs

# --- 5. Training/Eval/KE Pipeline ---
def compute_accuracy(preds, refs):
    """Return the fraction of predictions that exactly equal their reference.

    Pairs are formed with ``zip`` (extra items in the longer list are ignored),
    but the denominator is always ``len(preds)``. Returns ``0`` when ``preds``
    is empty.
    """
    if len(preds) == 0:
        return 0
    hits = sum(1 for pred, ref in zip(preds, refs) if pred == ref)
    return hits / len(preds)

def evaluate_model(model, dl, tokenizer, device):
    """Generate for every batch of ``dl`` and return exact-match accuracy.

    Generation is capped at 2 new tokens. Predictions and references are decoded
    with special tokens skipped, then stripped and lower-cased before comparison.
    ``-100`` label positions are mapped to the pad token so they can be decoded.
    Leaves ``model`` in eval mode.
    """
    model.eval()
    preds, refs = [], []

    def normalise(texts):
        return [text.strip().lower() for text in texts]

    with torch.no_grad():
        for batch in dl:
            generated = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                max_new_tokens=2,
            )
            gold = batch["labels"].clone()
            gold[gold == -100] = tokenizer.pad_token_id
            preds += normalise(tokenizer.batch_decode(generated, skip_special_tokens=True))
            refs += normalise(tokenizer.batch_decode(gold, skip_special_tokens=True))
    return compute_accuracy(preds, refs)

def full_finetuning(train_loader, dev_loader, device, tokenizer):
    """Stage 1: fully fine-tune ``t5-base`` while tracking per-layer knowledge entropy.

    Forward hooks on every encoder/decoder ``DenseReluDense`` capture its input.
    After each optimizer step, ``compute_ke_batch`` turns those captures into
    per-layer KE. The KE values are averaged per epoch and printed together with
    the dev accuracy.

    Returns:
        ``(model, last_enc_ke, last_dec_ke)``: the fine-tuned model and the
        ``{layer_idx: mean KE}`` dicts of the final epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Fine-Tuning & Knowledge Entropy Estimation ===")
    num_epochs = 6
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scaler = GradScaler()
    # The linear schedule reaches LR 0 at num_training_steps and stays there.
    # The horizon must therefore cover every epoch, or the last epochs train
    # with LR 0.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    (enc_hooks, enc_acts), (dec_hooks, dec_acts) = register_ke_hooks_t5(model)
    last_enc_ke, last_dec_ke = None, None

    def _accumulate(batch_ke, ke_sum, ke_count):
        # Running per-layer sum/count for the epoch average.
        for idx, v in batch_ke.items():
            ke_sum[idx] += v
            ke_count[idx] += 1

    try:
        for epoch in range(num_epochs):
            enc_ke_sum, enc_ke_count = defaultdict(float), defaultdict(int)
            dec_ke_sum, dec_ke_count = defaultdict(float), defaultdict(int)
            model.train()
            for batch in train_loader:
                opt.zero_grad()
                # Only the forward pass and loss run under autocast. Backward runs
                # outside it, as the PyTorch AMP documentation recommends.
                with autocast():
                    outputs = model(input_ids=batch['input_ids'].to(device),
                                    attention_mask=batch['attention_mask'].to(device),
                                    labels=batch['labels'].to(device))
                    loss = outputs.loss
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()
                _accumulate(compute_ke_batch(enc_acts), enc_ke_sum, enc_ke_count)
                _accumulate(compute_ke_batch(dec_acts), dec_ke_sum, dec_ke_count)

            epoch_enc_ke = {idx: enc_ke_sum[idx] / enc_ke_count[idx]
                            for idx in enc_ke_sum if enc_ke_count[idx] > 0}
            epoch_dec_ke = {idx: dec_ke_sum[idx] / dec_ke_count[idx]
                            for idx in dec_ke_sum if dec_ke_count[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Encoder KE: {epoch_enc_ke}")
            print(f"[Epoch {epoch+1}] approx Decoder KE: {epoch_dec_ke}")
            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
            last_enc_ke, last_dec_ke = epoch_enc_ke, epoch_dec_ke
    finally:
        # Always detach the hooks, even if training is interrupted.
        remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts)])
    return model, last_enc_ke, last_dec_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer):
    """Stage 2: prune the highest-KE decoder FFNs of ``model`` and keep fine-tuning.

    Args:
        model: T5 model from stage 1; modified in place.
        train_loader, dev_loader: training / evaluation batches.
        device: device the model lives on.
        enc_ke: per-layer encoder KE. Currently unused because encoder pruning
            is disabled; kept so the signature stays symmetric.
        dec_ke: per-layer decoder KE used to choose the 4 blocks to prune.
        tokenizer: used by ``evaluate_model`` for decoding.

    Returns:
        The pruned, further fine-tuned model.
    """
    print("=== Stage 2: Prune (High-KE) & Fine-tuning ===")
    num_epochs = 5
    # Encoder pruning is disabled. Enabling it would be:
    #   prune_high_ke_ffn(model.encoder.block, enc_ke, num_prune=4, hidden_size=model.config.d_model)
    dec_prune_idxs = prune_high_ke_ffn(model.decoder.block, dec_ke, num_prune=4,
                                       hidden_size=model.config.d_model)
    print("Pruned decoder layers (highest KE):", dec_prune_idxs)

    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
    # Cover every step that is taken. Past num_training_steps the linear schedule
    # holds the LR at 0, which would freeze the model for the remaining epochs.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch['labels'].to(device))
            outputs.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] ANLI1 Acc: {acc:.4f}")
    return model

# --- 6. Entrypoint ---
def main():
    """Run stage 1 (full fine-tuning + KE tracking), then stage 2 (prune + fine-tune).

    Uses the module-level ``train_loader``, ``dev_loader`` and ``tokenizer``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    stage1_model, enc_ke, dec_ke = full_finetuning(
        train_loader, dev_loader, device, tokenizer
    )
    # --- PRUNING AND CONTINUED FINETUNING ---
    prune_and_finetuning(
        stage1_model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer
    )


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Hard-coded knowledge-entropy (KE) results: one dict per epoch, each mapping a
# 0-based block index (0-11) to that block's mean KE, for the encoder and the
# decoder stacks.
# NOTE(review): these look like they were copied from the "[Epoch k] approx
# Encoder/Decoder KE" printouts of the T5 training cell above (presumably the
# ANLI run) -- confirm which run they came from before reusing the figures.
enc_ke_epochs = [
    {0: 5.480030705793848, 1: 5.5496080641476615, 2: 5.53872952911089, 3: 5.511840214819278, 4: 5.516190859956561, 5: 5.521920965302665, 6: 5.521400574918063, 7: 5.5144495568185485, 8: 5.49060498093659, 9: 5.450951964816921, 10: 5.423030107613378, 11: 5.466182296822421},
    {0: 5.4871524711824815, 1: 5.544603646476314, 2: 5.538442740800246, 3: 5.51321722111612, 4: 5.517586176350432, 5: 5.520688862170813, 6: 5.521329364236796, 7: 5.51392604315056, 8: 5.4896568541256885, 9: 5.451860777928001, 10: 5.422290218796926, 11: 5.473095425204849},
    {0: 5.4869607552042545, 1: 5.544635340402711, 2: 5.538483706060446, 3: 5.5132214298788105, 4: 5.517588706286448, 5: 5.520692724551795, 6: 5.521303620878256, 7: 5.51382915883694, 8: 5.489492775359244, 9: 5.451637857365158, 10: 5.422221080787288, 11: 5.473087342549913},
    {0: 5.486878904306663, 1: 5.5445804946827435, 2: 5.538352051770912, 3: 5.513102901656673, 4: 5.517401096955785, 5: 5.5205126964821005, 6: 5.521155353312222, 7: 5.513736206630491, 8: 5.489534128387019, 9: 5.451726898142929, 10: 5.4222523313800535, 11: 5.473024711809085},
    {0: 5.486847182489791, 1: 5.544575284112175, 2: 5.5384362558148945, 3: 5.513173838381498, 4: 5.517494451324895, 5: 5.520579959311576, 6: 5.5212360184147675, 7: 5.513783328938034, 8: 5.489483414056166, 9: 5.451590445806396, 10: 5.422033144178845, 11: 5.472951911290487},
    {0: 5.486859069680268, 1: 5.5446095003272005, 2: 5.538401910943805, 3: 5.513121429479347, 4: 5.517466274297463, 5: 5.520576953438093, 6: 5.521189136325188, 7: 5.513781391449695, 8: 5.489508598705508, 9: 5.451672037592474, 10: 5.422226219177246, 11: 5.473129866463798}
]
dec_ke_epochs = [
    {0: 5.465231130587175, 1: 5.264995521004111, 2: 5.129532522346007, 3: 4.904289741644123, 4: 5.284679993001292, 5: 5.50792251817332, 6: 5.485444439223263, 7: 5.51603090317343, 8: 5.534246008096697, 9: 5.510181266989484, 10: 5.566446723553958, 11: 5.486786779712739},
    {0: 5.469713579256041, 1: 5.189854344148881, 2: 4.998035216808773, 3: 4.745502117364263, 4: 5.174289006977791, 5: 5.4802458447655455, 6: 5.4423357679005235, 7: 5.488046237465538, 8: 5.5072504567236304, 9: 5.504863533096386, 10: 5.56821925942391, 11: 5.49581682125197},
    {0: 5.469381145927364, 1: 5.188713194414278, 2: 4.996736016985805, 3: 4.743715989033911, 4: 5.17287224979201, 5: 5.479581630989895, 6: 5.441277936341987, 7: 5.487883858630818, 8: 5.507143537392512, 9: 5.505116776213206, 10: 5.5682468636618925, 11: 5.495751358235483},
    {0: 5.469358630307758, 1: 5.188910964336104, 2: 4.996535207024057, 3: 4.743591746301141, 4: 5.172841481580079, 5: 5.4794120374526685, 6: 5.441188570197302, 7: 5.48781734431973, 8: 5.5070663813416285, 9: 5.504971349967345, 10: 5.568174341252742, 11: 5.495965877107081},
    {0: 5.469419138318017, 1: 5.188836002803984, 2: 4.996781077612014, 3: 4.7440564200991675, 4: 5.173168115615844, 5: 5.4795689723605205, 6: 5.441124076389131, 7: 5.487770483834403, 8: 5.506965150378999, 9: 5.504970055534726, 10: 5.568170601981027, 11: 5.495966437657674},
    {0: 5.469227241334461, 1: 5.1887324437640965, 2: 4.996399641491118, 3: 4.743360573450724, 4: 5.172784906568982, 5: 5.479544408434913, 6: 5.441251537232172, 7: 5.487974353971936, 8: 5.507155140922183, 9: 5.505052519298735, 10: 5.56826036453247, 11: 5.495957667032878}
]
# Derive the axes from the data rather than hard-coding 12 layers / 6 epochs,
# so the cell still works if the pasted results come from a different run.
layers = list(range(1, len(enc_ke_epochs[0]) + 1))  # 1-based layers for x-axis
epochs = [f"Epoch {i+1}" for i in range(len(enc_ke_epochs))]

# Convert each epoch's {0-based layer: KE} dict into a list ordered by layer.
enc_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in enc_ke_epochs]
dec_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in dec_ke_epochs]

# One figure for the encoder, then one for the decoder, with identical styling.
for ke_list in (enc_ke_list, dec_ke_list):
    plt.figure(figsize=(10, 6))
    for i, epoch in enumerate(epochs):
        plt.plot(layers, ke_list[i], marker='o', label=epoch)
    plt.xlabel("Layer", fontsize=16)
    plt.ylabel("Knowledge Entropy", fontsize=16)
    plt.xticks(layers, fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(True)
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.show()


In [None]:
# ===========================
# 0. Google Drive Mount
# ===========================
# Mount Google Drive (Colab) so the /content/drive/MyDrive/... SVAMP JSON paths
# used below resolve.
from google.colab import drive
drive.mount('/content/drive')

# ===========================
# 1. Imports and Setup
# ===========================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration, T5TokenizerFast,
    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup
)
from torch.cuda.amp import autocast, GradScaler
from collections import defaultdict
import warnings

# Globally suppress every UserWarning and FutureWarning for the rest of the process.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ===========================
# 2. Load SVAMP Dataset
# ===========================
# Local SVAMP splits (JSON) on the mounted Drive; the "test" split serves as dev.
data_files = dict(
    train="/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json",
    test="/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json",
)
dataset = load_dataset("json", data_files=data_files)

# ===========================
# 3. Preprocessing
# ===========================
def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8):
    """Tokenize a batch of SVAMP examples for T5 seq2seq training.

    Args:
        batch: dict with lists under ``"input"`` (question text) and ``"label"``
            (answers; converted with ``str``).
        tokenizer: a callable HF-style tokenizer exposing ``pad_token_id``.
        max_input_length: padded/truncated length of the encoder inputs.
        max_target_length: padded/truncated length of the label sequences.

    Returns:
        The tokenizer output for the inputs, plus ``"labels"``. Label padding
        positions are set to ``-100`` so the seq2seq loss ignores them.
    """
    model_inputs = tokenizer(
        batch["input"], padding="max_length", truncation=True, max_length=max_input_length
    )
    targets = [str(x) for x in batch["label"]]
    target_encodings = tokenizer(
        targets, padding="max_length", truncation=True, max_length=max_target_length
    )
    pad_id = tokenizer.pad_token_id
    # Without this the model is also trained to emit <pad> after every answer.
    # evaluate_model maps -100 back to the pad token before decoding.
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in target_encodings["input_ids"]
    ]
    return model_inputs

# Tokenize the SVAMP splits (dropping the raw columns) and build fixed-length
# loaders. The train split is mapped first, then the test split (used as dev).
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
train, dev = (
    dataset[split_name].map(
        lambda ex: preprocess_svamp(ex, tokenizer),
        batched=True,
        remove_columns=dataset[split_name].column_names,
    )
    for split_name in ("train", "test")
)
collator = DataCollatorForSeq2Seq(
    tokenizer, model=None, padding="max_length", max_length=128
)
train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)
dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

# ===========================
# 4. Knowledge Entropy Utilities
# ===========================
from functools import partial

def register_ke_hooks_t5(model):
    """Attach forward hooks that record the input of every T5 FFN sub-layer.

    Encoder blocks are hooked at ``layer[1].DenseReluDense`` and decoder blocks
    at ``layer[2].DenseReluDense``. Each hook stores a detached copy of the first
    positional input in a per-stack dict keyed by block index, overwriting the
    previous capture on every forward call.

    Returns:
        ``((enc_hooks, enc_acts), (dec_hooks, dec_acts))``: lists of removable
        hook handles and dicts that start as ``{block_idx: None}``.
    """
    def make_recorder(store, key):
        def record(module, inp, out):
            store[key] = inp[0].detach()
        return record

    def attach(blocks, ffn_position):
        store = dict.fromkeys(range(len(blocks)))
        handles = [
            block.layer[ffn_position].DenseReluDense.register_forward_hook(
                make_recorder(store, position)
            )
            for position, block in enumerate(blocks)
        ]
        return handles, store

    encoder_side = attach(model.encoder.block, 1)
    decoder_side = attach(model.decoder.block, 2)
    return encoder_side, decoder_side




def compute_ke_batch(acts, act_fn=F.relu, eps=1e-8):
    """Compute per-layer knowledge entropy (KE) for one batch and clear the buffers.

    Each captured FFN input ``a`` is passed through ``act_fn``. The result is
    normalised over its last dimension into a distribution, and the Shannon
    entropy of that distribution is averaged over all remaining positions.

    Args:
        acts: dict ``{layer_idx: tensor | None}`` filled by forward hooks. Every
            entry is reset to ``None`` (consumed), even when it is skipped.
        act_fn: activation applied before normalising (ReLU by default).
        eps: small constant guarding the division and the ``log``.

    Returns:
        dict ``{layer_idx: float}`` holding only the layers with a finite,
        non-degenerate entropy.
    """
    ke = {}
    for idx, a in acts.items():
        acts[idx] = None  # Always reset, even if skipping
        if a is None or a.numel() == 0:
            continue
        # Work in float32. In fp16 (e.g. a capture taken under autocast) eps=1e-8
        # underflows to 0, so zero probabilities would hit log(0) and the
        # entropy would silently become NaN and be dropped.
        a = a.float()
        if not torch.isfinite(a).all() or a.abs().sum() == 0:
            continue
        act = act_fn(a)
        denom = act.sum(dim=-1, keepdim=True)
        # Rows that are entirely zero after the activation: avoid 0/0.
        denom = torch.where(denom == 0, torch.ones_like(denom), denom)
        probs = torch.clamp(act / (denom + eps), min=eps)
        if not torch.isfinite(probs).all():
            continue
        entropy = -torch.sum(probs * torch.log(probs), dim=-1).mean()
        if not torch.isfinite(entropy):
            continue
        ke[idx] = entropy.item()
    return ke



def remove_hooks(hook_sets):
    """Detach every hook handle contained in ``hook_sets``.

    ``hook_sets`` is an iterable of ``(handles, activations)`` pairs, as returned
    by ``register_ke_hooks_t5``. Only the handles are used.
    """
    handles = (handle for group, _unused in hook_sets for handle in group)
    for handle in handles:
        handle.remove()

# ===========================
# 5. Pruning Utilities
# ===========================
class SkipFFN(nn.Module):
    """Drop-in replacement for a T5 ``DenseReluDense`` that disables the FFN.

    In Hugging Face's ``T5LayerFF`` the feed-forward sub-layer computes
    ``h + dropout(DenseReluDense(layer_norm(h)))``. Returning zeros makes that
    sub-layer an exact identity on ``h``, i.e. the FFN is pruned. Returning the
    (layer-normed) input instead would add ``layer_norm(h)`` to the residual stream.

    Args:
        hidden_size: model width; stored for reference only.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, hidden_states):
        # Zero contribution to the residual branch => FFN effectively removed.
        return torch.zeros_like(hidden_states)

def prune_high_ke_ffn(blocks, ke_scores, num_prune=4, hidden_size=768):
    """Replace the FFN of the ``num_prune`` highest-KE T5 blocks with ``SkipFFN``.

    Args:
        blocks: ``model.encoder.block`` or ``model.decoder.block`` of a T5 model.
        ke_scores: ``{block_idx: knowledge_entropy}``; blocks missing from it are
            never pruned. ``None`` or empty means nothing is pruned.
        num_prune: how many blocks to prune (fewer if ``ke_scores`` is smaller).
        hidden_size: forwarded to ``SkipFFN``.

    Returns:
        The pruned block indices, highest KE first.
    """
    if not ke_scores:
        return []
    ranked = sorted(ke_scores.items(), key=lambda kv: kv[1], reverse=True)
    prune_idxs = [idx for idx, _ in ranked[:num_prune]]
    for idx in prune_idxs:
        # The feed-forward sub-layer is the LAST entry of ``block.layer``:
        # index 1 in encoder blocks, index 2 in decoder blocks (after self- and
        # cross-attention; cf. the hook positions in register_ke_hooks_t5).
        # A fixed ``layer[1]`` would only attach an unused attribute to the
        # decoder's cross-attention and prune nothing.
        blocks[idx].layer[-1].DenseReluDense = SkipFFN(hidden_size)
    return prune_idxs

# ===========================
# 6. Eval Helper
# ===========================
def compute_accuracy(preds, refs):
    """Return the fraction of predictions that exactly equal their reference.

    Pairs are formed with ``zip`` (extra items in the longer list are ignored),
    but the denominator is always ``len(preds)``. Returns ``0`` when ``preds``
    is empty.
    """
    if len(preds) == 0:
        return 0
    hits = sum(1 for pred, ref in zip(preds, refs) if pred == ref)
    return hits / len(preds)

def evaluate_model(model, dl, tokenizer, device):
    """Generate for every batch of ``dl`` and return exact-match accuracy.

    Generation is capped at 8 new tokens. Predictions and references are decoded
    with special tokens skipped, then stripped and lower-cased before comparison.
    ``-100`` label positions are mapped to the pad token so they can be decoded.
    Leaves ``model`` in eval mode.
    """
    model.eval()
    preds, refs = [], []

    def normalise(texts):
        return [text.strip().lower() for text in texts]

    with torch.no_grad():
        for batch in dl:
            generated = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                max_new_tokens=8,
            )
            gold = batch["labels"].clone()
            gold[gold == -100] = tokenizer.pad_token_id
            preds += normalise(tokenizer.batch_decode(generated, skip_special_tokens=True))
            refs += normalise(tokenizer.batch_decode(gold, skip_special_tokens=True))
    return compute_accuracy(preds, refs)

# ===========================
# 7. Training + KE Tracking + Pruning
# ===========================
def full_finetuning(train_loader, dev_loader, device, tokenizer):
    """Stage 1: fully fine-tune ``t5-base`` while tracking per-layer knowledge entropy.

    Forward hooks on every encoder/decoder ``DenseReluDense`` capture its input.
    After each optimizer step, ``compute_ke_batch`` turns those captures into
    per-layer KE. The KE values are averaged per epoch and printed together with
    the dev accuracy.

    Returns:
        ``(model, last_enc_ke, last_dec_ke)``: the fine-tuned model and the
        ``{layer_idx: mean KE}`` dicts of the final epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Fine-Tuning & Knowledge Entropy Estimation ===")
    num_epochs = 6
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scaler = GradScaler()
    # The linear schedule reaches LR 0 at num_training_steps and stays there.
    # The horizon must therefore cover every epoch, or the last epochs train
    # with LR 0.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    (enc_hooks, enc_acts), (dec_hooks, dec_acts) = register_ke_hooks_t5(model)
    last_enc_ke, last_dec_ke = None, None

    def _accumulate(batch_ke, ke_sum, ke_count):
        # Running per-layer sum/count for the epoch average.
        for idx, v in batch_ke.items():
            ke_sum[idx] += v
            ke_count[idx] += 1

    try:
        for epoch in range(num_epochs):
            enc_ke_sum, enc_ke_count = defaultdict(float), defaultdict(int)
            dec_ke_sum, dec_ke_count = defaultdict(float), defaultdict(int)
            model.train()
            for batch in train_loader:
                opt.zero_grad()
                # Only the forward pass and loss run under autocast. Backward runs
                # outside it, as the PyTorch AMP documentation recommends.
                with autocast():
                    outputs = model(input_ids=batch['input_ids'].to(device),
                                    attention_mask=batch['attention_mask'].to(device),
                                    labels=batch['labels'].to(device))
                    loss = outputs.loss
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()
                _accumulate(compute_ke_batch(enc_acts), enc_ke_sum, enc_ke_count)
                _accumulate(compute_ke_batch(dec_acts), dec_ke_sum, dec_ke_count)

            epoch_enc_ke = {idx: enc_ke_sum[idx] / enc_ke_count[idx]
                            for idx in enc_ke_sum if enc_ke_count[idx] > 0}
            epoch_dec_ke = {idx: dec_ke_sum[idx] / dec_ke_count[idx]
                            for idx in dec_ke_sum if dec_ke_count[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Encoder KE: {epoch_enc_ke}")
            print(f"[Epoch {epoch+1}] approx Decoder KE: {epoch_dec_ke}")
            acc = evaluate_model(model, dev_loader, tokenizer, device)
            print(f"[Epoch {epoch+1}] Dev Acc: {acc:.4f}")
            last_enc_ke, last_dec_ke = epoch_enc_ke, epoch_dec_ke
    finally:
        # Always detach the hooks, even if training is interrupted.
        remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts)])
    return model, last_enc_ke, last_dec_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer):
    """Stage 2: prune the highest-KE decoder FFNs of ``model`` and keep fine-tuning.

    Args:
        model: T5 model from stage 1; modified in place.
        train_loader, dev_loader: training / evaluation batches.
        device: device the model lives on.
        enc_ke: per-layer encoder KE. Currently unused because encoder pruning
            is disabled; kept so the signature stays symmetric.
        dec_ke: per-layer decoder KE used to choose the 4 blocks to prune.
        tokenizer: used by ``evaluate_model`` for decoding.

    Returns:
        The pruned, further fine-tuned model.
    """
    print("=== Stage 2: Prune (High-KE) & Fine-tuning ===")
    num_epochs = 5
    # Encoder pruning is disabled. Enabling it would be:
    #   prune_high_ke_ffn(model.encoder.block, enc_ke, num_prune=4, hidden_size=model.config.d_model)
    dec_prune_idxs = prune_high_ke_ffn(model.decoder.block, dec_ke, num_prune=4,
                                       hidden_size=model.config.d_model)
    print("Pruned decoder layers (highest KE):", dec_prune_idxs)

    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
    # Cover every step that is taken. Past num_training_steps the linear schedule
    # holds the LR at 0, which would freeze the model for the remaining epochs.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch['labels'].to(device))
            outputs.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, tokenizer, device)
        print(f"[Prune FT Epoch {epoch+1}] SVAMP Acc: {acc:.4f}")
    return model

# ===========================
# 8. Main Entrypoint
# ===========================
def main():
    """Run stage 1 (full fine-tuning + KE tracking), then stage 2 (prune + fine-tune).

    Uses the module-level ``train_loader``, ``dev_loader`` and ``tokenizer``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    stage1_model, enc_ke, dec_ke = full_finetuning(
        train_loader, dev_loader, device, tokenizer
    )
    prune_and_finetuning(
        stage1_model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer
    )


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Hard-coded knowledge-entropy (KE) results: one dict per epoch, each mapping a
# 0-based block index (0-11) to that block's mean KE, for the encoder and the
# decoder stacks.
# NOTE(review): these look like they were copied from the "[Epoch k] approx
# Encoder/Decoder KE" printouts of the T5 training cell above (presumably the
# SVAMP run) -- confirm which run they came from before reusing the figures.
enc_ke_epochs = [
    {0: 5.154252872467041, 1: 5.415111303329468, 2: 5.469181814193726, 3: 5.475447826385498, 4: 5.4965128898620605, 5: 5.519717464447021, 6: 5.528017282485962, 7: 5.509926748275757, 8: 5.499344720840454, 9: 5.4804475879669186, 10: 5.460796346664429, 11: 5.484862985610962},
    {0: 5.151100530624389, 1: 5.413294105529785, 2: 5.476608180999756, 3: 5.480278253555298, 4: 5.508586235046387, 5: 5.5349556255340575, 6: 5.54053879737854, 7: 5.51798532485962, 8: 5.511692543029785, 9: 5.501022787094116, 10: 5.478860750198364, 11: 5.49954044342041},
    {0: 5.148409032821656, 1: 5.411761407852173, 2: 5.475491733551025, 3: 5.479352645874023, 4: 5.506951446533203, 5: 5.533953561782837, 6: 5.5362600994110105, 7: 5.5146334266662596, 8: 5.508514547348023, 9: 5.494693784713745, 10: 5.481513671875, 11: 5.50757402420044},
    {0: 5.149230527877807, 1: 5.4128351974487305, 2: 5.476775159835816, 3: 5.480456304550171, 4: 5.5078802585601805, 5: 5.5339813232421875, 6: 5.535842866897583, 7: 5.514435157775879, 8: 5.508073215484619, 9: 5.49350022315979, 10: 5.481070413589477, 11: 5.507593402862549},
    {0: 5.1491956615448, 1: 5.4127289581298825, 2: 5.4760276985168455, 3: 5.479673271179199, 4: 5.507365741729736, 5: 5.533624258041382, 6: 5.535600652694702, 7: 5.5140838718414305, 8: 5.508068170547485, 9: 5.494150772094726, 10: 5.482475533777354, 11: 5.509099026115573},
    {0: 5.148624906539917, 1: 5.412059764862061, 2: 5.476006727218628, 3: 5.479845180511474, 4: 5.507486362457275, 5: 5.534078559875488, 6: 5.535918989181519, 7: 5.514234838485717, 8: 5.507962064743042, 9: 5.493731985092163, 10: 5.482590412606998, 11: 5.5096063711205305}
]
dec_ke_epochs = [
    {0: 5.348595600128174, 1: 4.263790817260742, 2: 4.106602759361267, 3: 3.998788385391235, 4: 4.680020771026611, 5: 5.226834297180176, 6: 5.21035964012146, 7: 5.420144653320312, 8: 5.525527219772339, 9: 5.513277044296265, 10: 5.62511775970459, 11: 5.568444547653198},
    {0: 5.303920812606812, 1: 4.223071842193604, 2: 4.090608291625976, 3: 3.994982166290283, 4: 4.663872623443604, 5: 5.226346616744995, 6: 5.2035440254211425, 7: 5.403556118011474, 8: 5.528957233428955, 9: 5.523081560134887, 10: 5.630525960922241, 11: 5.582965478897095},
    {0: 5.310703840255737, 1: 4.263773174285888, 2: 4.116586332321167, 3: 4.017633848190307, 4: 4.676742868423462, 5: 5.224496612548828, 6: 5.196131429672241, 7: 5.402983207702636, 8: 5.525129270553589, 9: 5.520591449737549, 10: 5.6297976016998295, 11: 5.5853519535064695},
    {0: 5.309150876998902, 1: 4.262697019577026, 2: 4.116720676422119, 3: 4.019005084037781, 4: 4.678860769271851, 5: 5.226104078292846, 6: 5.197585697174072, 7: 5.40531442642212, 8: 5.5262416648864745, 9: 5.5210045337677, 10: 5.630594682693482, 11: 5.586866397857666},
    {0: 5.308763708387103, 1: 4.265620251091159, 2: 4.118220309821927, 3: 4.019376754760742, 4: 4.678985566509013, 5: 5.22458165032523, 6: 5.196559224809919, 7: 5.404032882379026, 8: 5.525934618346545, 9: 5.520692465256672, 10: 5.630939473911208, 11: 5.587003124003508},
    {0: 5.309705539625519, 1: 4.265844345092773, 2: 4.117541984635956, 3: 4.019918616937131, 4: 4.680267382641228, 5: 5.225742593103526, 6: 5.197149743839186, 7: 5.405138190911741, 8: 5.52513566309092, 9: 5.520209808738864, 10: 5.630923173865494, 11: 5.586297920772007}
]

# Derive the axes from the data rather than hard-coding 12 layers / 6 epochs,
# so the cell still works if the pasted results come from a different run.
layers = list(range(1, len(enc_ke_epochs[0]) + 1))  # 1-based layers for x-axis
epochs = [f"Epoch {i+1}" for i in range(len(enc_ke_epochs))]

# Convert each epoch's {0-based layer: KE} dict into a list ordered by layer.
enc_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in enc_ke_epochs]
dec_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in dec_ke_epochs]

# One figure for the encoder, then one for the decoder, with identical styling
# (no titles, for paper-style figures).
for ke_list in (enc_ke_list, dec_ke_list):
    plt.figure(figsize=(10, 6))
    for i, epoch in enumerate(epochs):
        plt.plot(layers, ke_list[i], marker='o', label=epoch)
    plt.xlabel("Layer", fontsize=16)
    plt.ylabel("Knowledge Entropy", fontsize=16)
    plt.xticks(layers, fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(True)
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

# Globally suppress every UserWarning and FutureWarning for the rest of the process.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Hook every RoBERTa ``intermediate.dense`` to record its input.

    Returns:
        ``(hooks, activations)``: the removable hook handles, and a dict where
        ``activations[i]['pre_act']`` holds a detached copy of the latest first
        positional input to layer ``i``'s intermediate projection (``None`` until
        that layer runs).
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {pos: {'pre_act': None} for pos in range(len(encoder_layers))}

    def recorder_for(pos):
        def record(module, inputs, output):
            activations[pos]['pre_act'] = inputs[0].detach()
        return record

    hooks = [
        lyr.intermediate.dense.register_forward_hook(recorder_for(pos))
        for pos, lyr in enumerate(encoder_layers)
    ]
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Turn the captured FFN inputs into per-layer knowledge entropy and clear them.

    For each layer, ``activation_fn(pre_act)`` is normalised over its LAST
    (feature) dimension into one distribution per token. The Shannon entropy of
    those distributions is averaged over all tokens.

    Args:
        activations: ``{layer_idx: {'pre_act': tensor | None}}`` filled by
            ``register_ke_hooks``. Consumed captures are reset to ``None``.
        activation_fn: activation applied before normalising (ReLU by default).
        eps: small constant guarding the division and the ``log``.

    Returns:
        ``{layer_idx: float}`` for every layer with a finite entropy.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        buf['pre_act'] = None  # consume the capture
        # float32 so eps/log stay meaningful even if the capture was fp16.
        act = activation_fn(pre_act.float())
        # Normalise across features (dim=-1). For a (batch, seq, features)
        # capture, dim=1 would normalise across the sequence axis instead.
        probs = act / (act.sum(dim=-1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
        if torch.isfinite(entropy):
            ke_scores[idx] = entropy.item()
    return ke_scores

def remove_hooks(hooks):
    """Detach each hook handle in ``hooks``."""
    for handle in hooks:
        handle.remove()


# ========================================================
# 3) Pruning Utilities
# ========================================================
class SkipFF(nn.Module):
    """Replacement for a RoBERTa ``layer.output`` that bypasses the FFN.

    It ignores ``hidden_states`` and returns ``input_tensor`` unchanged
    (``None`` when it is not given).
    """

    def forward(self, hidden_states, input_tensor=None):
        del hidden_states  # intentionally unused
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Bypass the FFN of the ``num_prune`` RoBERTa layers with the highest KE.

    For each chosen layer, ``intermediate.dense`` becomes ``nn.Identity`` and
    ``output`` becomes ``SkipFF``, so the layer returns its attention output.

    Returns:
        The pruned layer indices, highest KE first (ties keep dict order).
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    for idx in prune_idxs:
        target = model.roberta.encoder.layer[idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules
# ========================================================
class LoRA(nn.Module):
    """Low-rank update ``W0 + (alpha / r) * (B @ A)`` around a frozen weight.

    ``W0`` is copied into a buffer (not a parameter). ``B`` starts as small
    random noise and ``A`` as zeros, so the initial effective weight equals
    ``W0``. Calling the module with no arguments returns the effective weight.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        low_rank = self.B @ self.A
        return self.W0 + self.scaling * low_rank

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap each remaining RoBERTa ``layer.output.dense`` with a LoRA weight.

    Layers whose ``output`` has no ``dense`` attribute (e.g. replaced by
    ``SkipFF`` during pruning) are skipped. For the others, the dense module's
    ``forward`` is replaced by a closure computing ``F.linear(x, lora(), bias)``.
    The original weight is read once, as ``W0``.

    Returns:
        ``{layer_idx: LoRA}`` for every wrapped layer. The LoRA modules are only
        held in this dict; they are not registered as submodules of ``model``.
    """
    def make_forward(owner_layer, adapter):
        def forward(x):
            return F.linear(x, adapter(), owner_layer.output.dense.bias)
        return forward

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            W0 = layer.output.dense.weight.data
            adapter = LoRA(W0, r, alpha).to(W0.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[idx] = adapter
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tokenizer, max_length=128):
    """Tokenize premise/hypothesis pairs, truncated and padded to ``max_length``."""
    premises, hypotheses = examples["premise"], examples["hypothesis"]
    return tokenizer(
        premises,
        hypotheses,
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )

def evaluate_model(model, dl, device):
    """Return the classification accuracy (argmax over logits) of ``model`` on ``dl``.

    Uses the ``evaluate`` library's "accuracy" metric and leaves the model in
    eval mode.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    return metric.compute(predictions=predictions, references=references)["accuracy"]


# ========================================================
# 6) Training Stages
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fully fine-tune ``roberta-base`` (3 labels) while tracking KE.

    Hooks on every ``intermediate.dense`` capture its input. After each optimizer
    step, ``compute_batch_knowledge_entropy`` turns the captures into per-layer
    KE, which is averaged and printed per epoch. Dev accuracy is evaluated once,
    after training.

    Returns:
        ``(model, last_ke)``: the fine-tuned model (gradient checkpointing left
        enabled) and the final epoch's ``{layer_idx: mean KE}`` (``None`` if no
        epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    num_epochs = 6
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=3
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                # Only the forward pass and loss run under autocast. Backward runs
                # outside it, as the PyTorch AMP documentation recommends.
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                for idx, v in compute_batch_knowledge_entropy(activations).items():
                    ke_sums[idx] += v
                    ke_counts[idx] += 1

            epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke
    finally:
        # Always detach the hooks, even if training is interrupted.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MNLI Acc: {acc:.4f}")
    return model, last_ke


def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):
    """Stage 2: bypass the 4 highest-KE RoBERTa FFNs, then keep fine-tuning.

    Args:
        model: model from stage 1; modified in place.
        train_loader, dev_loader: training / evaluation batches.
        device: device the model lives on.
        ke_scores: ``{layer_idx: KE}`` used to choose the layers to prune.

    Returns:
        The pruned, further fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    num_epochs = 5
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Cover every step that is taken. Past num_training_steps the linear schedule
    # holds the LR at 0, which would freeze the model for the remaining epochs.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] MNLI Acc: {acc:.4f}")
    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: train only the classifier head and LoRA adapters.

    A LoRA adapter (rank ``r``, scaling ``alpha / r``) is attached to every
    encoder layer that still has an ``output.dense`` projection. All
    ``model.roberta`` parameters are frozen. MNLI dev accuracy is printed after
    each epoch.

    Args:
        model: model from stage 2 (modified in place).
        train_loader / dev_loader: training and evaluation batches.
        device: device the batch tensors are moved to.
        r, alpha: LoRA rank and scaling numerator.
        num_epochs: number of training epochs (default 6).
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for adapter in loras.values():
        adapter.A.requires_grad = True
        adapter.B.requires_grad = True

    # When gradient checkpointing is on (HF's default checkpoint is reentrant),
    # the checkpointed encoder layers only backpropagate if one of their inputs
    # requires grad. The embeddings are frozen above, so without this the LoRA
    # factors would never receive gradients and only the classifier would train.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters()) +
        [p for adapter in loras.values() for p in (adapter.A, adapter.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # AMP guidance: run backward outside the autocast region.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MNLI Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the MNLI pipeline: full fine-tuning with KE, KE pruning, then LoRA."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = load_dataset("glue", "mnli")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    # A 10k-example training subset keeps the run short.
    train_ds = dataset["train"].shuffle(seed).select(range(10000))
    dev_ds = dataset["validation_matched"]

    def _tokenize(ds):
        # Apply preprocess_function, drop the raw text columns, rename label -> labels.
        encoded = ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True,
                         remove_columns=["premise", "hypothesis"])
        return encoded.rename_column("label", "labels")

    train = _tokenize(train_ds)
    dev = _tokenize(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    loader_kwargs = {"collate_fn": collator}
    train_loader = DataLoader(train, batch_size=8, shuffle=True, **loader_kwargs)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, **loader_kwargs)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Recorded per-epoch knowledge entropy, keyed by 0-based encoder layer index.
ke_epochs = {
    1: {0: 3.7403, 1: 3.8396, 2: 3.8184, 3: 3.7880, 4: 3.7816, 5: 3.7467, 6: 3.7648, 7: 3.7296, 8: 3.7192, 9: 3.6667, 10: 3.5540, 11: 3.5193},
    2: {0: 3.7418, 1: 3.8355, 2: 3.8190, 3: 3.7905, 4: 3.7855, 5: 3.7613, 6: 3.7775, 7: 3.7418, 8: 3.7250, 9: 3.6640, 10: 3.5458, 11: 3.5028},
    3: {0: 3.7434, 1: 3.8356, 2: 3.8193, 3: 3.7932, 4: 3.7921, 5: 3.7653, 6: 3.7784, 7: 3.7541, 8: 3.7352, 9: 3.6658, 10: 3.5442, 11: 3.5088},
    4: {0: 3.7441, 1: 3.8372, 2: 3.8191, 3: 3.7979, 4: 3.7927, 5: 3.7678, 6: 3.7775, 7: 3.7542, 8: 3.7398, 9: 3.6725, 10: 3.5492, 11: 3.5134},
    5: {0: 3.7451, 1: 3.8349, 2: 3.8234, 3: 3.8037, 4: 3.7976, 5: 3.7695, 6: 3.7790, 7: 3.7539, 8: 3.7457, 9: 3.6797, 10: 3.5499, 11: 3.5118},
    6: {0: 3.7452, 1: 3.8372, 2: 3.8257, 3: 3.8062, 4: 3.8008, 5: 3.7704, 6: 3.7808, 7: 3.7537, 8: 3.7452, 9: 3.6729, 10: 3.5400, 11: 3.5013}
}

# One curve per epoch; the x-axis shows layers 1-based.
plt.figure(figsize=(10, 6))
for epoch in ke_epochs:
    per_layer = ke_epochs[epoch]
    xs = [layer_idx + 1 for layer_idx in per_layer]
    ys = [per_layer[layer_idx] for layer_idx in per_layer]
    plt.plot(xs, ys, marker='o', label=f"Epoch {epoch}")

# No title on purpose (figure is captioned externally).
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Capture the input of every encoder layer's ``intermediate.dense`` module.

    A forward hook is placed on ``model.roberta.encoder.layer[i].intermediate.dense``.
    Each time that module runs, the detached first positional input is stored
    in ``activations[i]['pre_act']``.

    Returns:
        ``(hooks, activations)``: the hook handles (for ``remove_hooks``) and a
        dict ``{layer_idx: {'pre_act': None}}`` filled in by the hooks.
    """
    def _make_hook(layer_idx):
        # A factory gives each hook its own index (no late-binding closure bug).
        def _capture(_module, inputs, _output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return _capture

    activations = {}
    hooks = []
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = layer.intermediate.dense.register_forward_hook(_make_hook(layer_idx))
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Return the mean knowledge entropy of each layer for the latest batch.

    For each layer whose buffer holds a captured tensor:
    1. ``activation_fn`` is applied.
    2. The result is normalised into a distribution over the LAST (hidden)
       dimension.
    3. The Shannon entropy of that distribution is averaged over all other
       positions.

    Consumed buffers are reset to ``None`` so stale captures are never reused.
    No attention mask is available here, so padded positions are included in
    the average.

    Args:
        activations: ``{layer_idx: {'pre_act': Tensor | None}}`` as filled by
            ``register_ke_hooks``.
        activation_fn: non-negative activation applied before normalisation.
        eps: guards the division and the logarithm against zeros.

    Returns:
        ``{layer_idx: float}`` for every layer that had a captured tensor.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in float32: a half-precision capture would underflow eps=1e-8 to 0,
        # and 0 * log(0) would then yield NaN.
        act = activation_fn(pre_act.float())
        # Normalise over the hidden dimension (dim=-1), as the T5 KE helper in this
        # file does. For (batch, seq, hidden) captures, dim=1 would instead build a
        # distribution across token positions.
        probs = act / (act.sum(dim=-1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Unregister every hook handle in ``hooks``, in iteration order."""
    for handle in hooks:
        handle.remove()


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's ``output`` module that skips the FFN.

    It ignores ``hidden_states`` and returns ``input_tensor`` unchanged. It is
    presumably called as ``output(intermediate_output, attention_output)`` by
    the HF RoBERTa layer (confirm), so the layer output becomes the attention
    output.
    """

    def forward(self, hidden_states, input_tensor=None):
        del hidden_states  # intentionally unused: the FFN branch is bypassed
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the FFN of the ``num_prune`` layers with the highest KE score.

    Each selected layer gets ``intermediate.dense`` replaced by ``nn.Identity``
    and ``output`` replaced by ``SkipFF``. Ties keep the order of ``ke_scores``.

    Returns:
        The pruned layer indices, highest KE first.
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    pruned = ranked[:num_prune]
    encoder_layers = model.roberta.encoder.layer
    for layer_idx in pruned:
        target = encoder_layers[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return pruned


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """A frozen weight ``W0`` plus a trainable low-rank update ``B @ A``.

    Calling the module returns the effective weight
    ``W0 + (alpha / r) * (B @ A)``. ``B`` is initialised with small Gaussian
    noise and ``A`` with zeros, so the initial effective weight equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        # Buffer, not parameter: follows .to()/state_dict but is never optimised.
        self.register_buffer("W0", W0.clone().detach())
        out_features, in_features = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(out_features, r))
        self.A = nn.Parameter(torch.zeros(r, in_features))
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route each encoder layer's ``output.dense`` through a LoRA-adapted weight.

    Layers whose ``output`` has no ``dense`` attribute (e.g. already pruned) are
    skipped. For the others, ``output.dense.forward`` is overridden to compute
    ``F.linear(x, lora(), output.dense.bias)``. The bias is looked up at call
    time.

    Returns:
        ``{layer_idx: LoRA}`` for every adapted layer.
    """
    def _lora_forward(layer, adapter):
        # Bind this layer/adapter pair explicitly for the replacement forward.
        def forward(x):
            return F.linear(x, adapter(), layer.output.dense.bias)
        return forward

    adapters = {}
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = _lora_forward(layer, adapter)
            adapters[layer_idx] = adapter
    return adapters


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenise MRPC sentence pairs, truncated and padded to ``max_length``."""
    first, second = examples['sentence1'], examples['sentence2']
    options = {"truncation": True, "padding": 'max_length', "max_length": max_length}
    return tok(first, second, **options)

def evaluate_model(model, dl, device):
    """Return the classification accuracy of ``model`` on loader ``dl``.

    The model is switched to eval mode and is not switched back; training loops
    call ``model.train()`` themselves. Every batch runs without gradients, and
    the arg-max of ``out.logits`` is compared with ``b['labels']``.

    Accuracy is counted directly rather than via ``evaluate.load("accuracy")``.
    That call re-loads the metric module on every invocation, which can require
    the Hugging Face Hub.

    Returns:
        float: fraction of correct predictions, or ``nan`` for an empty loader.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device))
            preds = torch.argmax(out.logits, -1).cpu()
            labels = b['labels'].cpu()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total if total else float("nan")


# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base and estimate per-layer KE.

    Forward hooks capture the FFN input of every encoder layer. After each
    optimisation step the batch KE is computed, and it is averaged per epoch.
    The hooks are always removed, even if training raises.

    Args:
        train_loader / dev_loader: training and evaluation batches.
        device: device the model and batch tensors are placed on.
        num_epochs: number of training epochs (default 6).

    Returns:
        ``(model, last_ke)`` where ``last_ke`` is ``{layer_idx: mean KE}`` for
        the final epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None
    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # AMP guidance: run backward outside the autocast region.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # batch-level KE
                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx]   += v
                    ke_counts[idx] += 1

            # epoch-level KE
            epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke

        acc = evaluate_model(model, dev_loader, device)
        print(f"-> Full Finetune MRPC Acc: {acc:.4f}")
    finally:
        # Never leave capture hooks attached to a model handed to later stages.
        remove_hooks(hooks)
    return model, last_ke


def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):
    """Stage 2: disable the FFN of the highest-KE layers, then fine-tune the model.

    Args:
        model: model returned by stage 1; it is pruned and trained in place.
        train_loader: training batches (dicts with input_ids/attention_mask/labels).
        dev_loader: evaluation batches, scored after every epoch.
        device: device the batch tensors are moved to.
        ke_scores: ``{layer_idx: knowledge_entropy}``; the 4 highest are pruned.
        num_epochs: number of fine-tuning epochs (default 5).

    Returns:
        The pruned and fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The linear decay must cover every step actually taken; a shorter horizon
    # drives the LR to 0 and the remaining epochs make no parameter updates.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: train only the classifier head and LoRA adapters.

    A LoRA adapter (rank ``r``, scaling ``alpha / r``) is attached to every
    encoder layer that still has an ``output.dense`` projection. All
    ``model.roberta`` parameters are frozen. MRPC dev accuracy is printed after
    each epoch.

    Args:
        model: model from stage 2 (modified in place).
        train_loader / dev_loader: training and evaluation batches.
        device: device the batch tensors are moved to.
        r, alpha: LoRA rank and scaling numerator.
        num_epochs: number of training epochs (default 6).
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for adapter in loras.values():
        adapter.A.requires_grad = True
        adapter.B.requires_grad = True

    # Gradient checkpointing is enabled in stage 1. With HF's default
    # (reentrant) checkpoint, the checkpointed encoder layers only backpropagate
    # if one of their inputs requires grad. The embeddings are frozen above, so
    # without this the LoRA factors would never receive gradients and only the
    # classifier would train.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for adapter in loras.values() for p in (adapter.A, adapter.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # AMP guidance: run backward outside the autocast region.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the MRPC pipeline: full fine-tuning with KE, KE pruning, then LoRA."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1000 shuffled training pairs; the full validation split for evaluation.
    train_ds = load_dataset("glue", "mrpc", split="train").shuffle(seed).select(range(1000))
    dev_ds = load_dataset("glue", "mrpc", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def _tokenize(ds):
        # Apply preprocess_function, drop raw columns, rename label -> labels.
        encoded = ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True,
                         remove_columns=["sentence1", "sentence2", "idx"])
        return encoded.rename_column("label", "labels")

    train = _tokenize(train_ds)
    dev = _tokenize(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    loader_kwargs = {"collate_fn": collator}
    train_loader = DataLoader(train, batch_size=8, shuffle=True, **loader_kwargs)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, **loader_kwargs)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Recorded per-epoch knowledge entropy (list position = epoch - 1),
# each keyed by 0-based encoder layer index.
ke_epochs = [
    {0: 3.239791645050049, 1: 3.2986808376312258, 2: 3.253356773376465, 3: 3.228304588317871, 4: 3.204816005706787, 5: 3.1601281089782716, 6: 3.167900978088379, 7: 3.14154211807251, 8: 3.136232151031494, 9: 3.108181224822998, 10: 3.070255741119385, 11: 3.150283061981201},
    {0: 3.2390173606872557, 1: 3.2963816051483152, 2: 3.2457099170684813, 3: 3.224952474594116, 4: 3.199368993759155, 5: 3.1539372539520265, 6: 3.149820531845093, 7: 3.1081827564239504, 8: 3.0931982421875, 9: 3.0528202323913574, 10: 3.028363660812378, 11: 3.050001382827759},
    {0: 3.239009340286255, 1: 3.2983984470367433, 2: 3.2463859825134276, 3: 3.2241379623413087, 4: 3.1976428546905518, 5: 3.1447392463684083, 6: 3.1279000110626223, 7: 3.0747772026062012, 8: 3.0223362197875976, 9: 2.9473570098876953, 10: 2.87524440574646, 11: 2.8674370727539062},
    {0: 3.238827512741089, 1: 3.2974642314910887, 2: 3.2431242713928223, 3: 3.2216903343200682, 4: 3.199512315750122, 5: 3.1563696422576903, 6: 3.134328540802002, 7: 3.0738258819580078, 8: 2.99842280960083, 9: 2.912938259124756, 10: 2.8102419834136962, 11: 2.771777828216553},
    {0: 3.2383913173675536, 1: 3.296243072509766, 2: 3.242586082458496, 3: 3.221456657409668, 4: 3.2004083824157714, 5: 3.152312059402466, 6: 3.1258499088287355, 7: 3.064895736694336, 8: 2.980252359390259, 9: 2.860691904067993, 10: 2.749246000289917, 11: 2.710493507385254},
    {0: 3.238114908218384, 1: 3.2962895565032957, 2: 3.243076961517334, 3: 3.221760944366455, 4: 3.2005157108306883, 5: 3.1507289390563966, 6: 3.123993368148804, 7: 3.0618533477783205, 8: 2.9700901947021485, 9: 2.843407918930054, 10: 2.727255252838135, 11: 2.6856814765930177}
]

# One curve per epoch; the x-axis shows layers 1-based.
plt.figure(figsize=(10, 6))
for epoch_no, per_layer in enumerate(ke_epochs, start=1):
    xs = [layer_idx + 1 for layer_idx in per_layer]
    ys = [per_layer[layer_idx] for layer_idx in per_layer]
    plt.plot(xs, ys, label=f"Epoch {epoch_no}", marker='o')

# No title on purpose (figure is captioned externally).
plt.xlabel("Layer", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Capture the input of every encoder layer's ``intermediate.dense`` module.

    A forward hook is placed on ``model.roberta.encoder.layer[i].intermediate.dense``.
    Each time that module runs, the detached first positional input is stored
    in ``activations[i]['pre_act']``.

    Returns:
        ``(hooks, activations)``: the hook handles (for ``remove_hooks``) and a
        dict ``{layer_idx: {'pre_act': None}}`` filled in by the hooks.
    """
    def _make_hook(layer_idx):
        # A factory gives each hook its own index (no late-binding closure bug).
        def _capture(_module, inputs, _output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return _capture

    activations = {}
    hooks = []
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = layer.intermediate.dense.register_forward_hook(_make_hook(layer_idx))
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Return the mean knowledge entropy of each layer for the latest batch.

    For each layer whose buffer holds a captured tensor:
    1. ``activation_fn`` is applied.
    2. The result is normalised into a distribution over the LAST (hidden)
       dimension.
    3. The Shannon entropy of that distribution is averaged over all other
       positions.

    Consumed buffers are reset to ``None`` so stale captures are never reused.
    No attention mask is available here, so padded positions are included in
    the average.

    Args:
        activations: ``{layer_idx: {'pre_act': Tensor | None}}`` as filled by
            ``register_ke_hooks``.
        activation_fn: non-negative activation applied before normalisation.
        eps: guards the division and the logarithm against zeros.

    Returns:
        ``{layer_idx: float}`` for every layer that had a captured tensor.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in float32: a half-precision capture would underflow eps=1e-8 to 0,
        # and 0 * log(0) would then yield NaN.
        act = activation_fn(pre_act.float())
        # Normalise over the hidden dimension (dim=-1), as the T5 KE helper in this
        # file does. For (batch, seq, hidden) captures, dim=1 would instead build a
        # distribution across token positions.
        probs = act / (act.sum(dim=-1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Unregister every hook handle in ``hooks``, in iteration order."""
    for handle in hooks:
        handle.remove()


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's ``output`` module that skips the FFN.

    It ignores ``hidden_states`` and returns ``input_tensor`` unchanged. It is
    presumably called as ``output(intermediate_output, attention_output)`` by
    the HF RoBERTa layer (confirm), so the layer output becomes the attention
    output.
    """

    def forward(self, hidden_states, input_tensor=None):
        del hidden_states  # intentionally unused: the FFN branch is bypassed
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the FFN of the ``num_prune`` layers with the highest KE score.

    Each selected layer gets ``intermediate.dense`` replaced by ``nn.Identity``
    and ``output`` replaced by ``SkipFF``. Ties keep the order of ``ke_scores``.

    Returns:
        The pruned layer indices, highest KE first.
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    pruned = ranked[:num_prune]
    encoder_layers = model.roberta.encoder.layer
    for layer_idx in pruned:
        target = encoder_layers[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return pruned


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """A frozen weight ``W0`` plus a trainable low-rank update ``B @ A``.

    Calling the module returns the effective weight
    ``W0 + (alpha / r) * (B @ A)``. ``B`` is initialised with small Gaussian
    noise and ``A`` with zeros, so the initial effective weight equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        # Buffer, not parameter: follows .to()/state_dict but is never optimised.
        self.register_buffer("W0", W0.clone().detach())
        out_features, in_features = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(out_features, r))
        self.A = nn.Parameter(torch.zeros(r, in_features))
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route each encoder layer's ``output.dense`` through a LoRA-adapted weight.

    Layers whose ``output`` has no ``dense`` attribute (e.g. already pruned) are
    skipped. For the others, ``output.dense.forward`` is overridden to compute
    ``F.linear(x, lora(), output.dense.bias)``. The bias is looked up at call
    time.

    Returns:
        ``{layer_idx: LoRA}`` for every adapted layer.
    """
    def _lora_forward(layer, adapter):
        # Bind this layer/adapter pair explicitly for the replacement forward.
        def forward(x):
            return F.linear(x, adapter(), layer.output.dense.bias)
        return forward

    adapters = {}
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = _lora_forward(layer, adapter)
            adapters[layer_idx] = adapter
    return adapters


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenise the ``sentence`` column, truncated and padded to ``max_length``."""
    sentences = examples['sentence']
    options = {"truncation": True, "padding": 'max_length', "max_length": max_length}
    return tok(sentences, **options)

def evaluate_model(model, dl, device):
    """Return the classification accuracy of ``model`` on loader ``dl``.

    The model is switched to eval mode and is not switched back; training loops
    call ``model.train()`` themselves. Every batch runs without gradients, and
    the arg-max of ``out.logits`` is compared with ``b['labels']``.

    Accuracy is counted directly rather than via ``evaluate.load("accuracy")``.
    That call re-loads the metric module on every invocation, which can require
    the Hugging Face Hub.

    Returns:
        float: fraction of correct predictions, or ``nan`` for an empty loader.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device))
            preds = torch.argmax(out.logits, -1).cpu()
            labels = b['labels'].cpu()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total if total else float("nan")


# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base and estimate per-layer KE.

    Forward hooks capture the FFN input of every encoder layer. After each
    optimisation step the batch KE is computed, and it is averaged per epoch.
    The hooks are always removed, even if training raises.

    Args:
        train_loader / dev_loader: training and evaluation batches.
        device: device the model and batch tensors are placed on.
        num_epochs: number of training epochs (default 6).

    Returns:
        ``(model, last_ke)`` where ``last_ke`` is ``{layer_idx: mean KE}`` for
        the final epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None
    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # AMP guidance: run backward outside the autocast region.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx]   += v
                    ke_counts[idx] += 1

            epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke

        acc = evaluate_model(model, dev_loader, device)
        print(f"-> Full Finetune SST-2 Acc: {acc:.4f}")
    finally:
        # Never leave capture hooks attached to a model handed to later stages.
        remove_hooks(hooks)
    return model, last_ke


def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):
    """Stage 2: disable the FFN of the highest-KE layers, then fine-tune the model.

    Args:
        model: model returned by stage 1; it is pruned and trained in place.
        train_loader: training batches (dicts with input_ids/attention_mask/labels).
        dev_loader: evaluation batches, scored after every epoch.
        device: device the batch tensors are moved to.
        ke_scores: ``{layer_idx: knowledge_entropy}``; the 4 highest are pruned.
        num_epochs: number of fine-tuning epochs (default 5).

    Returns:
        The pruned and fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The linear decay must cover every step actually taken; a shorter horizon
    # drives the LR to 0 and the remaining epochs make no parameter updates.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] SST-2 Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: train only the classifier head and LoRA adapters.

    A LoRA adapter (rank ``r``, scaling ``alpha / r``) is attached to every
    encoder layer that still has an ``output.dense`` projection. All
    ``model.roberta`` parameters are frozen. SST-2 dev accuracy is printed after
    each epoch.

    Args:
        model: model from stage 2 (modified in place).
        train_loader / dev_loader: training and evaluation batches.
        device: device the batch tensors are moved to.
        r, alpha: LoRA rank and scaling numerator.
        num_epochs: number of training epochs (default 6).
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for adapter in loras.values():
        adapter.A.requires_grad = True
        adapter.B.requires_grad = True

    # Gradient checkpointing is enabled in stage 1. With HF's default
    # (reentrant) checkpoint, the checkpointed encoder layers only backpropagate
    # if one of their inputs requires grad. The embeddings are frozen above, so
    # without this the LoRA factors would never receive gradients and only the
    # classifier would train.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for adapter in loras.values() for p in (adapter.A, adapter.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # AMP guidance: run backward outside the autocast region.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] SST-2 Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the SST-2 pipeline: full fine-tuning with KE, KE pruning, then LoRA."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 5000 shuffled training sentences; the full validation split for evaluation.
    train_ds = load_dataset("glue", "sst2", split="train").shuffle(seed).select(range(5000))
    dev_ds = load_dataset("glue", "sst2", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def _tokenize(ds):
        # Apply preprocess_function, drop raw columns, rename label -> labels.
        encoded = ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True,
                         remove_columns=["sentence", "idx"])
        return encoded.rename_column("label", "labels")

    train = _tokenize(train_ds)
    dev = _tokenize(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    loader_kwargs = {"collate_fn": collator}
    train_loader = DataLoader(train, batch_size=8, shuffle=True, **loader_kwargs)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, **loader_kwargs)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Recorded per-epoch knowledge entropy; list position i is encoder layer i + 1.
ke_by_epoch = {
    1: [2.9651833793640137, 3.089046128463745, 3.0924095928192137, 3.0681774433135987, 3.0340894985198976, 2.9792041194915773, 2.9694000019073488, 2.9185178718566895, 2.9200110771179197, 2.82025255355835, 2.667954217147827, 2.6612180709838866],
    2: [2.965893549346924, 3.085789897155762, 3.092639580154419, 3.078977996826172, 3.0649560813903807, 3.0321216072082517, 3.018477940368652, 2.9576354961395266, 2.911316421508789, 2.782052672576904, 2.6040575157165526, 2.5998096366882324],
    3: [2.9650946979522703, 3.084409861755371, 3.0894264568328857, 3.077181778717041, 3.066806471252441, 3.0330316158294677, 3.026228702163696, 2.9687501598358152, 2.918505425262451, 2.781764380264282, 2.6073847118377684, 2.5956354915618896],
    4: [2.9647007961273193, 3.0866590961456297, 3.088959024810791, 3.076308888244629, 3.0617130882263184, 3.0239853103637695, 3.0255811374664305, 2.9724936283111574, 2.91553058013916, 2.779527521133423, 2.608109250640869, 2.599828709793091],
    5: [2.9640205013275147, 3.089625425720215, 3.0891447685241697, 3.0789716785430907, 3.066036195373535, 3.0240665218353273, 3.0201424140930175, 2.9597047756195067, 2.888725273895264, 2.7510012699127198, 2.584951690673828, 2.5864554821014405],
    6: [2.963668116760254, 3.090019404220581, 3.089617932128906, 3.0806209617614746, 3.0694873622894288, 3.0306819889068604, 3.0209427997589113, 2.956924412536621, 2.8712899379730223, 2.7296667186737062, 2.57259878578186, 2.5810291679382322]
}

# One curve per epoch over layers 1..12.
plt.figure(figsize=(10, 6))
for epoch in ke_by_epoch:
    per_layer = ke_by_epoch[epoch]
    plt.plot(range(1, len(per_layer) + 1), per_layer, marker='o', label=f'Epoch {epoch}')

# No title on purpose (figure is captioned externally).
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Approx. Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Capture the input of every encoder layer's ``intermediate.dense`` module.

    A forward hook is placed on ``model.roberta.encoder.layer[i].intermediate.dense``.
    Each time that module runs, the detached first positional input is stored
    in ``activations[i]['pre_act']``.

    Returns:
        ``(hooks, activations)``: the hook handles (for ``remove_hooks``) and a
        dict ``{layer_idx: {'pre_act': None}}`` filled in by the hooks.
    """
    def _make_hook(layer_idx):
        # A factory gives each hook its own index (no late-binding closure bug).
        def _capture(_module, inputs, _output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return _capture

    activations = {}
    hooks = []
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = layer.intermediate.dense.register_forward_hook(_make_hook(layer_idx))
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Return the mean knowledge entropy of each layer for the latest batch.

    For each layer whose buffer holds a captured tensor:
    1. ``activation_fn`` is applied.
    2. The result is normalised into a distribution over the LAST (hidden)
       dimension.
    3. The Shannon entropy of that distribution is averaged over all other
       positions.

    Consumed buffers are reset to ``None`` so stale captures are never reused.
    No attention mask is available here, so padded positions are included in
    the average.

    Args:
        activations: ``{layer_idx: {'pre_act': Tensor | None}}`` as filled by
            ``register_ke_hooks``.
        activation_fn: non-negative activation applied before normalisation.
        eps: guards the division and the logarithm against zeros.

    Returns:
        ``{layer_idx: float}`` for every layer that had a captured tensor.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in float32: a half-precision capture would underflow eps=1e-8 to 0,
        # and 0 * log(0) would then yield NaN.
        act = activation_fn(pre_act.float())
        # Normalise over the hidden dimension (dim=-1), as the T5 KE helper in this
        # file does. For (batch, seq, hidden) captures, dim=1 would instead build a
        # distribution across token positions.
        probs = act / (act.sum(dim=-1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Unregister every hook handle in ``hooks``, in iteration order."""
    for handle in hooks:
        handle.remove()

# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's ``output`` module that skips the FFN.

    It ignores ``hidden_states`` and returns ``input_tensor`` unchanged. It is
    presumably called as ``output(intermediate_output, attention_output)`` by
    the HF RoBERTa layer (confirm), so the layer output becomes the attention
    output.
    """

    def forward(self, hidden_states, input_tensor=None):
        del hidden_states  # intentionally unused: the FFN branch is bypassed
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the FFN of the ``num_prune`` layers with the highest KE score.

    Each selected layer gets ``intermediate.dense`` replaced by ``nn.Identity``
    and ``output`` replaced by ``SkipFF``. Ties keep the order of ``ke_scores``.

    Returns:
        The pruned layer indices, highest KE first.
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    pruned = ranked[:num_prune]
    encoder_layers = model.roberta.encoder.layer
    for layer_idx in pruned:
        target = encoder_layers[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return pruned

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update around a frozen copy of a weight matrix.

    Calling the module returns ``W0 + (alpha / r) * (B @ A)``. ``A`` starts
    at zero, so the returned weight initially equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        out_features, in_features = W0.shape
        # Buffer, not a parameter: moves with .to() but is not trained.
        self.register_buffer("W0", W0.clone().detach())
        self.B = nn.Parameter(torch.randn(out_features, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_features))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route every encoder layer's ``output.dense`` through a fresh LoRA weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. Examples
    are layers whose output was replaced by ``SkipFF`` in ``prune_ke_layers``.

    For each remaining layer, the ``forward`` of ``output.dense`` is replaced
    by ``F.linear(x, lora(), output.dense.bias)``. The bias is looked up at
    call time.

    Returns:
        ``{layer_idx: LoRA}``. The adapters are not registered as sub-modules
        of ``model``.
    """
    adapters = {}

    def bind(enc_layer, adapter):
        # Closure so each patched forward uses its own layer/adapter pair.
        def patched_forward(x):
            return F.linear(x, adapter(), enc_layer.output.dense.bias)
        return patched_forward

    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(enc_layer.output, 'dense'):
            continue
        base_weight = enc_layer.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
        enc_layer.output.dense.forward = bind(enc_layer, adapter)
        adapters[layer_idx] = adapter
    return adapters

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize a batch of CoLA sentences to fixed-length encodings.

    ``examples['sentence']`` is passed to ``tok`` with truncation and padding
    to exactly ``max_length`` tokens. Whatever ``tok`` returns is returned
    unchanged.
    """
    encode_kwargs = dict(truncation=True, padding='max_length', max_length=max_length)
    return tok(examples['sentence'], **encode_kwargs)


from sklearn.metrics import matthews_corrcoef

def evaluate_model(model, dl, device):
    """Score *model* on every batch of *dl*.

    Predictions are the argmax of the logits. The model is switched to eval
    mode and left there.

    Returns:
        ``(accuracy, matthews_correlation)`` as computed by the
        ``evaluate`` metrics.
    """
    model.eval()
    acc_metric = evaluate.load("accuracy")
    mcc_metric = evaluate.load("matthews_correlation")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=batch['input_ids'].to(device),
                           attention_mask=batch['attention_mask'].to(device)).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    scores = dict(predictions=predictions, references=references)
    acc = acc_metric.compute(**scores)["accuracy"]
    mcc = mcc_metric.compute(**scores)["matthews_correlation"]
    return acc, mcc



# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base on CoLA while estimating knowledge entropy.

    Hooks from ``register_ke_hooks`` capture each encoder layer's FFN input.
    After every optimizer step, that batch's per-layer KE is computed and
    averaged over the epoch. The model is evaluated on ``dev_loader`` after
    each epoch.

    Args:
        train_loader: yields dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        dev_loader: same format, used for evaluation.
        device: torch device to train on.
        num_epochs: number of training epochs; it also sizes the LR schedule.

    Returns:
        ``(model, last_ke)``. ``last_ke`` maps layer index to the mean KE over
        the final epoch, or is ``None`` when ``num_epochs`` is 0.
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    for epoch in range(num_epochs):
        ke_sums, ke_counts = defaultdict(float), defaultdict(int)
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # Only the forward pass and loss run under autocast; the PyTorch
            # AMP docs advise against running backward inside the context.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

            # Batch-level KE from the activations the hooks just captured.
            batch_ke = compute_batch_knowledge_entropy(activations)
            for idx, v in batch_ke.items():
                ke_sums[idx] += v
                ke_counts[idx] += 1

        # Epoch-level KE: mean of the batch-level values for each layer.
        epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                    for idx in ke_sums if ke_counts[idx] > 0}
        print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
        last_ke = epoch_ke

        acc, mcc = evaluate_model(model, dev_loader, device)
        print(f"-> Full Finetune CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}")

    remove_hooks(hooks)
    return model, last_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):
    """Stage 2: drop the FFN of the 4 highest-KE layers, then fine-tune the model.

    Args:
        model: the model returned by Stage 1; it is modified in place.
        train_loader / dev_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.
        device: torch device.
        ke_scores: layer index -> KE, used to choose the layers to prune.
        num_epochs: training epochs; the linear LR schedule spans exactly these.

    Returns:
        The pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the linear decay to the epochs actually run. A shorter horizon
    # drives the LR to zero early and wastes the remaining epochs.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()

        acc, mcc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA adapters plus the classifier.

    LoRA adapters of rank ``r`` are attached to ``layer.output.dense`` of
    every encoder layer that still has one. All ``model.roberta`` parameters
    are frozen, and the classifier plus the adapters' ``A``/``B`` factors are
    trained with AMP. The model is evaluated after each epoch.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for lora in loras.values():
        lora.A.requires_grad = True
        lora.B.requires_grad = True

    # Gradient checkpointing (enabled in Stage 1) is still active. With the
    # re-entrant torch checkpoint that transformers uses by default, a
    # checkpointed layer whose inputs do not require grad yields no gradients
    # for anything inside it. With the embeddings frozen, that would silently
    # leave the LoRA factors untrained. Making the embedding output require
    # grad keeps the graph connected.
    model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for lora in loras.values() for p in (lora.A, lora.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast, per the PyTorch AMP guidance.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

        acc, mcc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage CoLA pipeline: full FT + KE, high-KE pruning, LoRA FT."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = load_dataset("glue", "cola", split="train").shuffle(seed).select(range(1000))
    dev_ds = load_dataset("glue", "cola", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split_ds):
        # Tokenize, drop the raw text / idx columns, and rename label -> labels
        # (the key the training loops read).
        encoded = split_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                               batched=True,
                               remove_columns=["sentence", "idx"])
        return encoded.rename_column("label", "labels")

    train, dev = encode(train_ds), encode(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Knowledge Entropy values over 6 epochs.
# One dict per epoch (list position 0 -> epoch 1), each mapping a 0-based
# encoder layer index to its KE value. The values are hard-coded.
# NOTE(review): presumably pasted from the "[Epoch N] approx Knowledge Entropy"
# printout of the CoLA run above; regenerate if the KE computation or the
# training setup changes.
ke_values = [
    {0: 2.9448, 1: 3.0885, 2: 3.1181, 3: 3.0932, 4: 3.0441, 5: 2.9380, 6: 2.9606, 7: 2.9447, 8: 2.9687, 9: 2.9775, 10: 2.9016, 11: 2.9350},
    {0: 2.9439, 1: 3.0961, 2: 3.1100, 3: 3.0816, 4: 3.0435, 5: 2.9906, 6: 3.0009, 7: 2.9850, 8: 2.9817, 9: 2.9741, 10: 2.8645, 11: 2.8777},
    {0: 2.9447, 1: 3.0925, 2: 3.1075, 3: 3.0703, 4: 3.0372, 5: 2.9933, 6: 3.0018, 7: 2.9899, 8: 2.9857, 9: 2.9657, 10: 2.8410, 11: 2.8267},
    {0: 2.9449, 1: 3.0911, 2: 3.1028, 3: 3.0704, 4: 3.0389, 5: 2.9948, 6: 2.9996, 7: 2.9699, 8: 2.9628, 9: 2.9423, 10: 2.8048, 11: 2.7881},
    {0: 2.9457, 1: 3.0933, 2: 3.1071, 3: 3.0762, 4: 3.0477, 5: 3.0011, 6: 3.0038, 7: 2.9738, 8: 2.9548, 9: 2.9251, 10: 2.7719, 11: 2.7567},
    {0: 2.9456, 1: 3.0939, 2: 3.1071, 3: 3.0768, 4: 3.0461, 5: 3.0022, 6: 3.0023, 7: 2.9735, 8: 2.9553, 9: 2.9216, 10: 2.7640, 11: 2.7494},
]

# One KE-vs-layer curve per epoch, with layers shown 1-based (1-12).
plt.figure(figsize=(10, 6))
for i, epoch_ke in enumerate(ke_values, 1):
    layers = [idx + 1 for idx in epoch_ke]
    entropy = [epoch_ke[idx] for idx in epoch_ke]
    plt.plot(layers, entropy, marker='o', label=f"Epoch {i}")

# Axis labels, tick sizes and layout.
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
#plt.title("Knowledge Entropy vs. Layer Index (CoLA)", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach forward hooks that record the input of every encoder FFN.

    For each layer in ``model.roberta.encoder.layer`` a hook is registered on
    ``layer.intermediate.dense``. On every forward call it stores a detached
    reference to that module's first positional input in
    ``activations[layer_idx]['pre_act']``, overwriting any earlier value.

    Returns:
        ``(hooks, activations)``: the hook handles (release them with
        ``remove_hooks``) and the ``{layer_idx: {'pre_act': None}}`` buffers
        that the hooks fill in.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {layer_idx: {'pre_act': None} for layer_idx in range(len(encoder_layers))}

    def make_capture(layer_idx):
        # Factory so each hook closes over its own layer index.
        def capture(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return capture

    hooks = [
        enc_layer.intermediate.dense.register_forward_hook(make_capture(layer_idx))
        for layer_idx, enc_layer in enumerate(encoder_layers)
    ]
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8, dim=-1):
    """Compute a per-layer knowledge-entropy (KE) score from captured activations.

    For each buffer in ``activations`` that holds a tensor under ``'pre_act'``:
    apply ``activation_fn``, normalise the result into a distribution along
    ``dim``, take its Shannon entropy (natural log) and average that over all
    remaining positions. Each consumed buffer is reset to ``None``.

    Args:
        activations: ``{layer_idx: {'pre_act': Tensor | None}}``, as filled by
            the hooks from ``register_ke_hooks``.
        activation_fn: activation applied before normalisation; it should be
            non-negative (the default ReLU is).
        eps: small constant guarding both the division and the logarithm.
        dim: axis to normalise over. Defaults to the last (feature) axis.
            Assuming ``pre_act`` is (batch, seq, hidden), each token's hidden
            vector then becomes one distribution. Pass ``dim=1`` to normalise
            over axis 1 instead.

    Returns:
        ``{layer_idx: float}`` for every layer that had a captured tensor.

    NOTE(review): the mean includes padded positions (no attention mask is
    available here).
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in float32: captured tensors may be half precision under autocast.
        act = activation_fn(pre_act.float())
        probs = act / (act.sum(dim=dim, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=dim).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None  # consumed; the next forward pass refills it
    return ke_scores

def remove_hooks(hooks):
    """Detach every hook handle in *hooks*, e.g. those returned by ``register_ke_hooks``."""
    for handle in hooks:
        handle.remove()


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for an encoder layer's ``output`` sub-module that skips the FFN.

    ``prune_ke_layers`` installs it as ``layer.output``. It ignores its first
    argument and hands back ``input_tensor`` untouched. This presumably means
    the layer passes its attention output through unchanged; confirm against
    how the encoder layer calls ``output``.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor  # the FFN branch (hidden_states) is discarded
        return residual

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Remove the FFN of the ``num_prune`` encoder layers with the highest KE.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity()``
    and ``output`` becomes ``SkipFF()``. The model is modified in place. Ties
    keep the insertion order of ``ke_scores``.

    Returns:
        The pruned layer indices, highest KE first.
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    encoder_layers = model.roberta.encoder.layer
    for idx in prune_idxs:
        target = encoder_layers[idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update around a frozen copy of a weight matrix.

    Calling the module returns ``W0 + (alpha / r) * (B @ A)``. ``A`` starts
    at zero, so the returned weight initially equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        out_features, in_features = W0.shape
        # Buffer, not a parameter: moves with .to() but is not trained.
        self.register_buffer("W0", W0.clone().detach())
        self.B = nn.Parameter(torch.randn(out_features, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_features))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route every encoder layer's ``output.dense`` through a fresh LoRA weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. Examples
    are layers whose output was replaced by ``SkipFF`` in ``prune_ke_layers``.

    For each remaining layer, the ``forward`` of ``output.dense`` is replaced
    by ``F.linear(x, lora(), output.dense.bias)``. The bias is looked up at
    call time.

    Returns:
        ``{layer_idx: LoRA}``. The adapters are not registered as sub-modules
        of ``model``.
    """
    adapters = {}

    def bind(enc_layer, adapter):
        # Closure so each patched forward uses its own layer/adapter pair.
        def patched_forward(x):
            return F.linear(x, adapter(), enc_layer.output.dense.bias)
        return patched_forward

    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(enc_layer.output, 'dense'):
            continue
        base_weight = enc_layer.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
        enc_layer.output.dense.forward = bind(enc_layer, adapter)
        adapters[layer_idx] = adapter
    return adapters


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize a batch of QNLI (question, sentence) pairs to fixed-length encodings.

    Both columns are passed to ``tok`` as a text pair, with truncation and
    padding to exactly ``max_length`` tokens. Whatever ``tok`` returns is
    returned unchanged.
    """
    text_pair = (examples['question'], examples['sentence'])
    return tok(*text_pair, truncation=True, padding='max_length', max_length=max_length)

def evaluate_model(model, dl, device):
    """Return the accuracy of *model* over every batch of *dl*.

    Predictions are the argmax of the logits, scored with the ``evaluate``
    "accuracy" metric. The model is switched to eval mode and left there.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=batch['input_ids'].to(device),
                           attention_mask=batch['attention_mask'].to(device)).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]


# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base on QNLI while estimating knowledge entropy.

    Hooks from ``register_ke_hooks`` capture each encoder layer's FFN input.
    After every optimizer step, that batch's per-layer KE is computed and
    averaged over the epoch. The model is evaluated once on ``dev_loader``
    after training.

    Args:
        train_loader: yields dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        dev_loader: same format, used for evaluation.
        device: torch device to train on.
        num_epochs: number of training epochs; it also sizes the LR schedule.

    Returns:
        ``(model, last_ke)``. ``last_ke`` maps layer index to the mean KE over
        the final epoch, or is ``None`` when ``num_epochs`` is 0.
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    for epoch in range(num_epochs):
        ke_sums, ke_counts = defaultdict(float), defaultdict(int)
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # Only the forward pass and loss run under autocast; the PyTorch
            # AMP docs advise against running backward inside the context.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

            # Batch-level KE from the activations the hooks just captured.
            batch_ke = compute_batch_knowledge_entropy(activations)
            for idx, v in batch_ke.items():
                ke_sums[idx] += v
                ke_counts[idx] += 1

        # Epoch-level KE: mean of the batch-level values for each layer.
        epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                    for idx in ke_sums if ke_counts[idx] > 0}
        print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
        last_ke = epoch_ke

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QNLI Acc: {acc:.4f}")

    remove_hooks(hooks)
    return model, last_ke


def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):
    """Stage 2: drop the FFN of the 4 highest-KE layers, then fine-tune the model.

    Args:
        model: the model returned by Stage 1; it is modified in place.
        train_loader / dev_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.
        device: torch device.
        ke_scores: layer index -> KE, used to choose the layers to prune.
        num_epochs: training epochs; the linear LR schedule spans exactly these.

    Returns:
        The pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the linear decay to the epochs actually run. A shorter horizon
    # drives the LR to zero early and wastes the remaining epochs.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA adapters plus the classifier.

    LoRA adapters of rank ``r`` are attached to ``layer.output.dense`` of
    every encoder layer that still has one. All ``model.roberta`` parameters
    are frozen, and the classifier plus the adapters' ``A``/``B`` factors are
    trained with AMP. The model is evaluated after each epoch.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for lora in loras.values():
        lora.A.requires_grad = True
        lora.B.requires_grad = True

    # Gradient checkpointing (enabled in Stage 1) is still active. With the
    # re-entrant torch checkpoint that transformers uses by default, a
    # checkpointed layer whose inputs do not require grad yields no gradients
    # for anything inside it. With the embeddings frozen, that would silently
    # leave the LoRA factors untrained. Making the embedding output require
    # grad keeps the graph connected.
    model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for lora in loras.values() for p in (lora.A, lora.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast, per the PyTorch AMP guidance.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage QNLI pipeline: full FT + KE, high-KE pruning, LoRA FT."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = load_dataset("glue", "qnli", split="train").shuffle(seed).select(range(2000))
    dev_ds = load_dataset("glue", "qnli", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split_ds):
        # Tokenize, drop the raw text / idx columns, and rename label -> labels
        # (the key the training loops read).
        encoded = split_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                               batched=True,
                               remove_columns=["question", "sentence", "idx"])
        return encoded.rename_column("label", "labels")

    train, dev = encode(train_ds), encode(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Knowledge Entropy (KE) values across layers for each epoch.
# Maps 1-based epoch number to a dict of 0-based encoder layer index -> KE.
# The values are hard-coded.
# NOTE(review): presumably pasted from the "[Epoch N] approx Knowledge Entropy"
# printout of the QNLI run above; regenerate if the KE computation or the
# training setup changes.
ke_epochs = {
    1: {0: 3.2304, 1: 3.2924, 2: 3.2409, 3: 3.1981, 4: 3.1778, 5: 3.1335, 6: 3.1346, 7: 3.0527, 8: 3.0132, 9: 2.9415, 10: 2.8503, 11: 2.8969},
    2: {0: 3.2314, 1: 3.2867, 2: 3.2390, 3: 3.2144, 4: 3.1975, 5: 3.1637, 6: 3.1601, 7: 3.0645, 8: 2.9672, 9: 2.8309, 10: 2.6973, 11: 2.6584},
    3: {0: 3.2336, 1: 3.2851, 2: 3.2359, 3: 3.2133, 4: 3.2056, 5: 3.1733, 6: 3.1724, 7: 3.0849, 8: 2.9905, 9: 2.8596, 10: 2.7224, 11: 2.6792},
    4: {0: 3.2347, 1: 3.2857, 2: 3.2389, 3: 3.2136, 4: 3.2055, 5: 3.1750, 6: 3.1752, 7: 3.0889, 8: 2.9981, 9: 2.8731, 10: 2.7439, 11: 2.6969},
    5: {0: 3.2345, 1: 3.2861, 2: 3.2386, 3: 3.2143, 4: 3.2060, 5: 3.1730, 6: 3.1725, 7: 3.0911, 8: 2.9987, 9: 2.8678, 10: 2.7374, 11: 2.6870},
    6: {0: 3.2349, 1: 3.2869, 2: 3.2387, 3: 3.2147, 4: 3.2044, 5: 3.1738, 6: 3.1718, 7: 3.0958, 8: 3.0050, 9: 2.8763, 10: 2.7463, 11: 2.6961}
}

# Plot one KE-vs-layer curve per epoch, with layers shown 1-based (1-12).
plt.figure(figsize=(10, 6))
for epoch in ke_epochs:
    values = ke_epochs[epoch]
    layers = [idx + 1 for idx in values]
    entropy = [values[idx] for idx in values]
    plt.plot(layers, entropy, marker='o', label=f"Epoch {epoch}")

# Axis labels, legend, grid, tick sizes and layout.
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
#plt.title("Knowledge Entropy vs. Layer Index", fontsize=16)
plt.legend(fontsize=12)
plt.grid(True)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach forward hooks that record the input of every encoder FFN.

    For each layer in ``model.roberta.encoder.layer`` a hook is registered on
    ``layer.intermediate.dense``. On every forward call it stores a detached
    reference to that module's first positional input in
    ``activations[layer_idx]['pre_act']``, overwriting any earlier value.

    Returns:
        ``(hooks, activations)``: the hook handles (release them with
        ``remove_hooks``) and the ``{layer_idx: {'pre_act': None}}`` buffers
        that the hooks fill in.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {layer_idx: {'pre_act': None} for layer_idx in range(len(encoder_layers))}

    def make_capture(layer_idx):
        # Factory so each hook closes over its own layer index.
        def capture(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return capture

    hooks = [
        enc_layer.intermediate.dense.register_forward_hook(make_capture(layer_idx))
        for layer_idx, enc_layer in enumerate(encoder_layers)
    ]
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8, dim=-1):
    """Compute a per-layer knowledge-entropy (KE) score from captured activations.

    For each buffer in ``activations`` that holds a tensor under ``'pre_act'``:
    apply ``activation_fn``, normalise the result into a distribution along
    ``dim``, take its Shannon entropy (natural log) and average that over all
    remaining positions. Each consumed buffer is reset to ``None``.

    Args:
        activations: ``{layer_idx: {'pre_act': Tensor | None}}``, as filled by
            the hooks from ``register_ke_hooks``.
        activation_fn: activation applied before normalisation; it should be
            non-negative (the default ReLU is).
        eps: small constant guarding both the division and the logarithm.
        dim: axis to normalise over. Defaults to the last (feature) axis.
            Assuming ``pre_act`` is (batch, seq, hidden), each token's hidden
            vector then becomes one distribution. Pass ``dim=1`` to normalise
            over axis 1 instead.

    Returns:
        ``{layer_idx: float}`` for every layer that had a captured tensor.

    NOTE(review): the mean includes padded positions (no attention mask is
    available here).
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in float32: captured tensors may be half precision under autocast.
        act = activation_fn(pre_act.float())
        probs = act / (act.sum(dim=dim, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=dim).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None  # consumed; the next forward pass refills it
    return ke_scores

def remove_hooks(hooks):
    """Detach every hook handle in *hooks*, e.g. those returned by ``register_ke_hooks``."""
    for handle in hooks:
        handle.remove()

# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for an encoder layer's ``output`` sub-module that skips the FFN.

    ``prune_ke_layers`` installs it as ``layer.output``. It ignores its first
    argument and hands back ``input_tensor`` untouched. This presumably means
    the layer passes its attention output through unchanged; confirm against
    how the encoder layer calls ``output``.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor  # the FFN branch (hidden_states) is discarded
        return residual

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Remove the FFN of the ``num_prune`` encoder layers with the highest KE.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity()``
    and ``output`` becomes ``SkipFF()``. The model is modified in place. Ties
    keep the insertion order of ``ke_scores``.

    Returns:
        The pruned layer indices, highest KE first.
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    encoder_layers = model.roberta.encoder.layer
    for idx in prune_idxs:
        target = encoder_layers[idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update around a frozen copy of a weight matrix.

    Calling the module returns ``W0 + (alpha / r) * (B @ A)``. ``A`` starts
    at zero, so the returned weight initially equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        out_features, in_features = W0.shape
        # Buffer, not a parameter: moves with .to() but is not trained.
        self.register_buffer("W0", W0.clone().detach())
        self.B = nn.Parameter(torch.randn(out_features, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_features))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route every encoder layer's ``output.dense`` through a fresh LoRA weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped. Examples
    are layers whose output was replaced by ``SkipFF`` in ``prune_ke_layers``.

    For each remaining layer, the ``forward`` of ``output.dense`` is replaced
    by ``F.linear(x, lora(), output.dense.bias)``. The bias is looked up at
    call time.

    Returns:
        ``{layer_idx: LoRA}``. The adapters are not registered as sub-modules
        of ``model``.
    """
    adapters = {}

    def bind(enc_layer, adapter):
        # Closure so each patched forward uses its own layer/adapter pair.
        def patched_forward(x):
            return F.linear(x, adapter(), enc_layer.output.dense.bias)
        return patched_forward

    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(enc_layer.output, 'dense'):
            continue
        base_weight = enc_layer.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
        enc_layer.output.dense.forward = bind(enc_layer, adapter)
        adapters[layer_idx] = adapter
    return adapters

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize a batch of QQP (question1, question2) pairs to fixed-length encodings.

    Both columns are passed to ``tok`` as a text pair, with truncation and
    padding to exactly ``max_length`` tokens. Whatever ``tok`` returns is
    returned unchanged.
    """
    text_pair = (examples['question1'], examples['question2'])
    return tok(*text_pair, truncation=True, padding='max_length', max_length=max_length)

def evaluate_model(model, dl, device):
    """Return the accuracy of *model* over every batch of *dl*.

    Predictions are the argmax of the logits, scored with the ``evaluate``
    "accuracy" metric. The model is switched to eval mode and left there.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=batch['input_ids'].to(device),
                           attention_mask=batch['attention_mask'].to(device)).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base on QQP while estimating knowledge entropy.

    Hooks from ``register_ke_hooks`` capture each encoder layer's FFN input.
    After every optimizer step, that batch's per-layer KE is computed and
    averaged over the epoch. The model is evaluated once on ``dev_loader``
    after training.

    Args:
        train_loader: yields dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        dev_loader: same format, used for evaluation.
        device: torch device to train on.
        num_epochs: number of training epochs; it also sizes the LR schedule.

    Returns:
        ``(model, last_ke)``. ``last_ke`` maps layer index to the mean KE over
        the final epoch, or is ``None`` when ``num_epochs`` is 0.
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    for epoch in range(num_epochs):
        ke_sums, ke_counts = defaultdict(float), defaultdict(int)
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # Only the forward pass and loss run under autocast; the PyTorch
            # AMP docs advise against running backward inside the context.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

            # Batch-level KE from the activations the hooks just captured.
            batch_ke = compute_batch_knowledge_entropy(activations)
            for idx, v in batch_ke.items():
                ke_sums[idx] += v
                ke_counts[idx] += 1

        # Epoch-level KE: mean of the batch-level values for each layer.
        epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                    for idx in ke_sums if ke_counts[idx] > 0}
        print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
        last_ke = epoch_ke

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QQP Acc: {acc:.4f}")

    remove_hooks(hooks)
    return model, last_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):
    """Stage 2: drop the FFN of the 4 highest-KE layers, then fine-tune the model.

    Args:
        model: the model returned by Stage 1; it is modified in place.
        train_loader / dev_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.
        device: torch device.
        ke_scores: layer index -> KE, used to choose the layers to prune.
        num_epochs: training epochs; the linear LR schedule spans exactly these.

    Returns:
        The pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the linear decay to the epochs actually run. A shorter horizon
    # drives the LR to zero early and wastes the remaining epochs.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA adapters plus the classifier.

    LoRA adapters of rank ``r`` are attached to ``layer.output.dense`` of
    every encoder layer that still has one. All ``model.roberta`` parameters
    are frozen, and the classifier plus the adapters' ``A``/``B`` factors are
    trained with AMP. The model is evaluated after each epoch.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for lora in loras.values():
        lora.A.requires_grad = True
        lora.B.requires_grad = True

    # Gradient checkpointing (enabled in Stage 1) is still active. With the
    # re-entrant torch checkpoint that transformers uses by default, a
    # checkpointed layer whose inputs do not require grad yields no gradients
    # for anything inside it. With the embeddings frozen, that would silently
    # leave the LoRA factors untrained. Making the embedding output require
    # grad keeps the graph connected.
    model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for lora in loras.values() for p in (lora.A, lora.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast, per the PyTorch AMP guidance.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage QQP pipeline: full FT + KE, high-KE pruning, LoRA FT."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = load_dataset("glue", "qqp", split="train").shuffle(seed).select(range(3000))
    dev_ds = load_dataset("glue", "qqp", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split_ds):
        # Tokenize, drop the raw text / idx columns, and rename label -> labels
        # (the key the training loops read).
        encoded = split_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                               batched=True,
                               remove_columns=["question1", "question2", "idx"])
        return encoded.rename_column("label", "labels")

    train, dev = encode(train_ds), encode(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Hard-coded per-layer Knowledge Entropy values for each epoch (outer key: epoch
# number, inner key: 0-based encoder layer index) -- presumably copied from the
# Stage-1 "[Epoch N] approx Knowledge Entropy" printouts.
ke_epochs = {
    1: {0: 3.1468, 1: 3.2210, 2: 3.1869, 3: 3.1624, 4: 3.1201, 5: 3.0795, 6: 3.1073, 7: 3.0597, 8: 3.0786, 9: 3.0529, 10: 2.9826, 11: 2.9888},
    2: {0: 3.1473, 1: 3.2196, 2: 3.1859, 3: 3.1655, 4: 3.1310, 5: 3.1000, 6: 3.1146, 7: 3.0777, 8: 3.0792, 9: 3.0537, 10: 2.9414, 11: 2.9214},
    3: {0: 3.1469, 1: 3.2201, 2: 3.1880, 3: 3.1641, 4: 3.1386, 5: 3.1055, 6: 3.1073, 7: 3.0607, 8: 3.0420, 9: 2.9884, 10: 2.8447, 11: 2.8140},
    4: {0: 3.1465, 1: 3.2205, 2: 3.1869, 3: 3.1643, 4: 3.1412, 5: 3.1092, 6: 3.1093, 7: 3.0606, 8: 3.0383, 9: 2.9692, 10: 2.8255, 11: 2.7930},
    5: {0: 3.1463, 1: 3.2201, 2: 3.1878, 3: 3.1660, 4: 3.1404, 5: 3.1064, 6: 3.1102, 7: 3.0624, 8: 3.0388, 9: 2.9652, 10: 2.8183, 11: 2.7961},
    6: {0: 3.1460, 1: 3.2191, 2: 3.1884, 3: 3.1685, 4: 3.1425, 5: 3.1084, 6: 3.1098, 7: 3.0626, 8: 3.0334, 9: 2.9510, 10: 2.8009, 11: 2.7761},
}

# One curve per epoch; layer indices are shown 1-based (1-12) on the x axis.
plt.figure(figsize=(10, 6))
for epoch, ke in ke_epochs.items():
    plt.plot(
        [layer_idx + 1 for layer_idx in ke],
        [ke[layer_idx] for layer_idx in ke],
        marker='o',
        label=f"Epoch {epoch}",
    )

plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

# Silence UserWarning / FutureWarning noise (e.g. deprecation notices) for this cell's run.
warnings.filterwarnings(action="ignore", category=UserWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)

# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach a forward hook to each encoder layer's ``intermediate.dense``.

    Each hook stores a detached view of the module's first positional input in
    ``activations[layer_idx]['pre_act']``, overwriting whatever was there.

    Returns:
        ``(hooks, activations)``: the list of hook handles (pass it to
        ``remove_hooks``) and the dict ``layer_idx -> {'pre_act': Tensor | None}``
        that the hooks write into.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {layer_idx: {'pre_act': None} for layer_idx in range(len(encoder_layers))}

    def capture_input_of(layer_idx):
        # Factory so every hook closes over its own layer index.
        def hook(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return hook

    hooks = [
        enc_layer.intermediate.dense.register_forward_hook(capture_input_of(layer_idx))
        for layer_idx, enc_layer in enumerate(encoder_layers)
    ]
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8, dim=-1):
    """Reduce one batch of captured FFN inputs to a Knowledge-Entropy score per layer.

    For every layer whose buffer holds a tensor, ``activation_fn`` is applied,
    the result is normalised to sum to one along ``dim``, and the Shannon
    entropy (natural log) along ``dim`` is averaged over all remaining
    positions. Consumed buffers are reset to ``None`` so a batch is never
    counted twice; layers whose buffer is ``None`` are skipped.

    Args:
        activations: dict ``layer_idx -> {'pre_act': Tensor | None}`` as filled
            by ``register_ke_hooks``.
        activation_fn: activation applied before normalising (ReLU by default).
        eps: added to the normaliser and inside the log to avoid 0/0 and log(0).
        dim: axis the distribution is taken over. Defaults to the last
            (feature) axis, matching ``compute_ke_batch`` at the top of this
            notebook. The previous code hard-coded ``dim=1``. For
            ``(batch, seq, hidden)`` activations that is the token axis, giving
            a distribution over positions (padding included) rather than over
            units. Pass ``dim=1`` to reproduce the earlier numbers.

    Returns:
        dict ``layer_idx -> float`` entropy in nats.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in fp32: in fp16, eps=1e-8 rounds to 0 and log(0) * 0 gives NaN.
        act = activation_fn(pre_act.float())
        probs = act / (act.sum(dim=dim, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=dim).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Unregister every hook handle in ``hooks``.

    Args:
        hooks: iterable of handles returned by ``register_forward_hook``.
    """
    for handle in hooks:
        handle.remove()

# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a pruned layer's ``output`` submodule.

    It accepts the same ``(hidden_states, input_tensor)`` call, ignores
    ``hidden_states`` and returns ``input_tensor`` untouched. The feed-forward
    contribution is therefore skipped.
    """

    def forward(self, hidden_states, input_tensor=None):
        del hidden_states  # intentionally unused: the FFN branch is dropped
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the feed-forward block of the ``num_prune`` highest-KE layers.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity`` and
    ``output`` becomes ``SkipFF``. The latter returns its second argument, so
    (assuming the layer calls ``output(intermediate_out, attention_out)`` as HF
    RoBERTa does) the layer passes its attention output straight through.

    Args:
        model: model exposing ``model.roberta.encoder.layer``.
        ke_scores: dict ``layer_idx -> KE``. Ties keep dict order.
        num_prune: how many layers to prune.

    Returns:
        list of pruned layer indices, highest KE first.
    """
    by_ke_desc = sorted(ke_scores.items(), key=lambda item: item[1], reverse=True)
    prune_idxs = [layer_idx for layer_idx, _score in by_ke_desc[:num_prune]]
    for layer_idx in prune_idxs:
        encoder_layer = model.roberta.encoder.layer[layer_idx]
        encoder_layer.intermediate.dense = nn.Identity()
        encoder_layer.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank reparameterisation of a frozen weight matrix.

    ``forward()`` takes no input and returns the effective weight
    ``W0 + (alpha / r) * (B @ A)``. ``W0`` is a non-trainable buffer copy of
    the weight passed in. ``B`` has shape ``(rows(W0), r)`` and starts as small
    Gaussian noise. ``A`` has shape ``(r, cols(W0))`` and starts at zero, so the
    initial effective weight equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(torch.randn(n_rows, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap every surviving FFN output projection with a LoRA adapter.

    Encoder layers whose ``output`` has no ``dense`` attribute (those pruned to
    ``SkipFF``) are skipped. For the rest, a ``LoRA`` module is built from the
    current ``output.dense`` weight, and that Linear's ``forward`` is replaced
    so it computes ``F.linear(x, W0 + scaling * B @ A, bias)``.

    The adapter is also registered as a child module (``output.dense.lora``).
    Without that it lived only in the returned dict, so ``model.state_dict()``,
    ``model.to()`` and ``model.parameters()`` never saw the trained A/B
    matrices.

    Args:
        model: model exposing ``model.roberta.encoder.layer``.
        r: LoRA rank.
        alpha: LoRA scaling numerator (scaling = alpha / r).

    Returns:
        dict ``layer_idx -> LoRA`` for every wrapped layer.
    """
    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(layer.output, 'dense'):
            continue
        dense = layer.output.dense
        W0 = dense.weight.data
        lora = LoRA(W0, r, alpha).to(W0.device)
        dense.lora = lora  # nn.Module.__setattr__ registers it as a submodule

        def fwd(x, dense=dense, lora=lora):
            # Default args bind this layer's modules (avoids late-binding closures).
            return F.linear(x, lora(), dense.bias)

        dense.forward = fwd
        loras[idx] = lora
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenise a batch of sentence pairs.

    Each ``sentence1``/``sentence2`` pair is truncated and padded to exactly
    ``max_length`` tokens. Returns whatever ``tok`` returns for the pair.
    """
    first_sentences = examples['sentence1']
    second_sentences = examples['sentence2']
    return tok(first_sentences, second_sentences,
               truncation=True, padding='max_length', max_length=max_length)

def evaluate_model(model, dl, device):
    """Return the classification accuracy of ``model`` on ``dl``.

    Puts the model in eval mode and leaves it there; the training loops call
    ``model.train()`` again. Each batch is run without gradients, the arg-max
    over the logits is taken, and the result is compared with ``labels``.

    Accuracy is a plain mean of matches, so it is computed directly with NumPy
    instead of reloading ``evaluate.load("accuracy")`` on every call.

    Returns:
        float accuracy in [0, 1], or NaN when ``dl`` yields no batches.
    """
    model.eval()
    preds, labs = [], []
    with torch.no_grad():
        for b in dl:
            ids = b['input_ids'].to(device)
            mask = b['attention_mask'].to(device)
            labs.append(b['labels'].cpu().numpy())
            out = model(input_ids=ids, attention_mask=mask)
            preds.append(torch.argmax(out.logits, -1).cpu().numpy())
    if not labs:
        return float('nan')
    return float(np.mean(np.concatenate(preds) == np.concatenate(labs)))

# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base on RTE while tracking Knowledge Entropy.

    Hooks from ``register_ke_hooks`` capture the input of every encoder layer's
    ``intermediate.dense``. After each optimiser step the captured tensors are
    reduced with ``compute_batch_knowledge_entropy``, and the per-layer values
    are averaged over the epoch. Training uses AMP (``autocast`` +
    ``GradScaler``) and gradient checkpointing.

    Args:
        train_loader: DataLoader yielding ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: validation DataLoader with the same keys.
        device: torch device to train on.
        num_epochs: epochs to train. It sizes both the loop and the LR schedule
            (default 6, as before).

    Returns:
        ``(model, last_ke)``: ``last_ke`` maps layer index to the mean KE of the
        final epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None
    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward outside autocast, as the PyTorch AMP docs recommend.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # batch-level KE (also clears the activation buffers)
                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx]   += v
                    ke_counts[idx] += 1

            # epoch-level KE
            epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke
    finally:
        # Always detach the hooks, even on error. Doing it before evaluation
        # also stops dev-set forward passes from refilling the buffers.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune RTE Acc: {acc:.4f}")
    return model, last_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):
    """Stage 2: prune the FFN of the 4 highest-KE layers, then fine-tune the rest.

    Uses ``prune_ke_layers`` on ``ke_scores`` and then trains all remaining
    parameters with Adam (lr 1e-5, linear decay, no AMP). RTE accuracy is
    printed after every epoch.

    Args:
        model: model returned by stage 1.
        train_loader, dev_loader, device: as in stage 1.
        ke_scores: dict ``layer_idx -> KE`` from stage 1.
        num_epochs: epochs to train (default 5, as the loop ran before).

    Returns:
        the pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the schedule by the same epoch count the loop uses. It was sized for
    # 3 epochs while the loop ran 5, so the linear decay reached lr=0 after
    # epoch 3 and the last two epochs made no updates.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: freeze the (pruned) backbone and train only LoRA adapters + classifier.

    Adapters are attached via ``apply_lora_to_all_layers``. Every
    ``model.roberta`` parameter is frozen. The classification head and the
    adapters' ``A``/``B`` matrices are trained with Adam under AMP. RTE accuracy
    is printed after every epoch.

    Args:
        model: model returned by stage 2.
        train_loader, dev_loader, device: as in the earlier stages.
        r, alpha: LoRA rank and scaling numerator.
        num_epochs: epochs to train (default 6, as before).
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for l in loras.values() for p in (l.A, l.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 turned on gradient checkpointing. HF's checkpointing is reentrant
    # by default, and with the embeddings now frozen no checkpointed layer
    # receives an input that requires grad. Its outputs then don't either, so
    # the LoRA matrices inside the encoder would get no gradient and only the
    # classifier would train. Making the embedding output require grad restores
    # the gradient path.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast, as the PyTorch AMP docs recommend.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage RTE pipeline.

    Seeds Python, NumPy and PyTorch with 42, tokenises the shuffled RTE train
    split and the validation split, builds the data loaders, then runs full
    fine-tuning, KE-guided pruning + fine-tuning and LoRA fine-tuning.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = load_dataset("glue", "rte", split="train").shuffle(seed)
    dev_ds = load_dataset("glue", "rte", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def tokenize_split(split_ds):
        # Tokenise sentence pairs, drop raw text columns, rename the target to "labels".
        encoded = split_ds.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2", "idx"],
        )
        return encoded.rename_column("label", "labels")

    train, dev = tokenize_split(train_ds), tokenize_split(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Hard-coded per-layer Knowledge Entropy values for each epoch (outer key: epoch
# number, inner key: 0-based encoder layer index) -- presumably copied from the
# Stage-1 "[Epoch N] approx Knowledge Entropy" printouts of the RTE run.
ke_data = {
    1: {0: 3.234637494270618, 1: 3.2954297256775393, 2: 3.2472059474541592, 3: 3.2122703920572233, 4: 3.1996606771762552, 5: 3.166536705616193, 6: 3.192006678917469, 7: 3.1797323601368146, 8: 3.1996014416217804, 9: 3.198551667042268, 10: 3.1776862190319943, 11: 3.233222233179288},
    2: {0: 3.2364144898377933, 1: 3.292912809512554, 2: 3.241034468779197, 3: 3.214976036395782, 4: 3.198220124611488, 5: 3.16093972325325, 6: 3.176389996057902, 7: 3.153661161661148, 8: 3.156559161650829, 9: 3.124322636769368, 10: 3.0624067340141687, 11: 3.0741044863676414},
    3: {0: 3.2363225794755497, 1: 3.2945617039998374, 2: 3.2429233766519108, 3: 3.214366410023127, 4: 3.2025416669173117, 5: 3.16937870092881, 6: 3.180643222270868, 7: 3.16009650627772, 8: 3.1594576827990704, 9: 3.1243840715824027, 10: 3.053374951466536, 11: 3.0580264222927585},
    4: {0: 3.2360405830236583, 1: 3.2945619431825786, 2: 3.2415615954460244, 3: 3.2135055791109037, 4: 3.2035804826479692, 5: 3.1748380936109104, 6: 3.1875738394566073, 7: 3.1665822023000474, 8: 3.1637751505925107, 9: 3.1298167407512665, 10: 3.052433254627081, 11: 3.038477917512258},
    5: {0: 3.2361522951187234, 1: 3.294039017114884, 2: 3.241480989333911, 3: 3.2140066035282917, 4: 3.204068192304709, 5: 3.175252984731625, 6: 3.1887962214457684, 7: 3.1690831688734202, 8: 3.1620407494214864, 9: 3.1216949025789895, 10: 3.035130114127428, 11: 3.0117857922346163},
    6: {0: 3.2358735440633235, 1: 3.2946181755799513, 2: 3.2427005072434745, 3: 3.213492139791831, 4: 3.204735367726057, 5: 3.174306563077829, 6: 3.1876255694108133, 7: 3.167302029255109, 8: 3.1590061837281938, 9: 3.115201425093871, 10: 3.0225293246599345, 11: 2.994033750051107}
}

# One curve per epoch; layer indices are shown 1-based (1-12) on the x axis.
plt.figure(figsize=(10, 6))
for epoch, layer_data in ke_data.items():
    plt.plot(
        [layer_idx + 1 for layer_idx in layer_data],
        [layer_data[layer_idx] for layer_idx in layer_data],
        marker='o',
        label=f"Epoch {epoch}",
    )

plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

# Silence UserWarning / FutureWarning noise (e.g. deprecation notices) for this cell's run.
warnings.filterwarnings(action="ignore", category=UserWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)

# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach a forward hook to each encoder layer's ``intermediate.dense``.

    Each hook stores a detached view of the module's first positional input in
    ``activations[layer_idx]['pre_act']``, overwriting whatever was there.

    Returns:
        ``(hooks, activations)``: the list of hook handles (pass it to
        ``remove_hooks``) and the dict ``layer_idx -> {'pre_act': Tensor | None}``
        that the hooks write into.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {layer_idx: {'pre_act': None} for layer_idx in range(len(encoder_layers))}

    def capture_input_of(layer_idx):
        # Factory so every hook closes over its own layer index.
        def hook(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return hook

    hooks = [
        enc_layer.intermediate.dense.register_forward_hook(capture_input_of(layer_idx))
        for layer_idx, enc_layer in enumerate(encoder_layers)
    ]
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8, dim=-1):
    """Reduce one batch of captured FFN inputs to a Knowledge-Entropy score per layer.

    For every layer whose buffer holds a tensor, ``activation_fn`` is applied,
    the result is normalised to sum to one along ``dim``, and the Shannon
    entropy (natural log) along ``dim`` is averaged over all remaining
    positions. Consumed buffers are reset to ``None`` so a batch is never
    counted twice; layers whose buffer is ``None`` are skipped.

    Args:
        activations: dict ``layer_idx -> {'pre_act': Tensor | None}`` as filled
            by ``register_ke_hooks``.
        activation_fn: activation applied before normalising (ReLU by default).
        eps: added to the normaliser and inside the log to avoid 0/0 and log(0).
        dim: axis the distribution is taken over. Defaults to the last
            (feature) axis, matching ``compute_ke_batch`` at the top of this
            notebook. The previous code hard-coded ``dim=1``. For
            ``(batch, seq, hidden)`` activations that is the token axis, giving
            a distribution over positions (padding included) rather than over
            units. Pass ``dim=1`` to reproduce the earlier numbers.

    Returns:
        dict ``layer_idx -> float`` entropy in nats.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in fp32: in fp16, eps=1e-8 rounds to 0 and log(0) * 0 gives NaN.
        act = activation_fn(pre_act.float())
        probs = act / (act.sum(dim=dim, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=dim).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Unregister every hook handle in ``hooks``.

    Args:
        hooks: iterable of handles returned by ``register_forward_hook``.
    """
    for handle in hooks:
        handle.remove()

# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a pruned layer's ``output`` submodule.

    It accepts the same ``(hidden_states, input_tensor)`` call, ignores
    ``hidden_states`` and returns ``input_tensor`` untouched. The feed-forward
    contribution is therefore skipped.
    """

    def forward(self, hidden_states, input_tensor=None):
        del hidden_states  # intentionally unused: the FFN branch is dropped
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the feed-forward block of the ``num_prune`` highest-KE layers.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity`` and
    ``output`` becomes ``SkipFF``. The latter returns its second argument, so
    (assuming the layer calls ``output(intermediate_out, attention_out)`` as HF
    RoBERTa does) the layer passes its attention output straight through.

    Args:
        model: model exposing ``model.roberta.encoder.layer``.
        ke_scores: dict ``layer_idx -> KE``. Ties keep dict order.
        num_prune: how many layers to prune.

    Returns:
        list of pruned layer indices, highest KE first.
    """
    by_ke_desc = sorted(ke_scores.items(), key=lambda item: item[1], reverse=True)
    prune_idxs = [layer_idx for layer_idx, _score in by_ke_desc[:num_prune]]
    for layer_idx in prune_idxs:
        encoder_layer = model.roberta.encoder.layer[layer_idx]
        encoder_layer.intermediate.dense = nn.Identity()
        encoder_layer.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank reparameterisation of a frozen weight matrix.

    ``forward()`` takes no input and returns the effective weight
    ``W0 + (alpha / r) * (B @ A)``. ``W0`` is a non-trainable buffer copy of
    the weight passed in. ``B`` has shape ``(rows(W0), r)`` and starts as small
    Gaussian noise. ``A`` has shape ``(r, cols(W0))`` and starts at zero, so the
    initial effective weight equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(torch.randn(n_rows, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap every surviving FFN output projection with a LoRA adapter.

    Encoder layers whose ``output`` has no ``dense`` attribute (those pruned to
    ``SkipFF``) are skipped. For the rest, a ``LoRA`` module is built from the
    current ``output.dense`` weight, and that Linear's ``forward`` is replaced
    so it computes ``F.linear(x, W0 + scaling * B @ A, bias)``.

    The adapter is also registered as a child module (``output.dense.lora``).
    Without that it lived only in the returned dict, so ``model.state_dict()``,
    ``model.to()`` and ``model.parameters()`` never saw the trained A/B
    matrices.

    Args:
        model: model exposing ``model.roberta.encoder.layer``.
        r: LoRA rank.
        alpha: LoRA scaling numerator (scaling = alpha / r).

    Returns:
        dict ``layer_idx -> LoRA`` for every wrapped layer.
    """
    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(layer.output, 'dense'):
            continue
        dense = layer.output.dense
        W0 = dense.weight.data
        lora = LoRA(W0, r, alpha).to(W0.device)
        dense.lora = lora  # nn.Module.__setattr__ registers it as a submodule

        def fwd(x, dense=dense, lora=lora):
            # Default args bind this layer's modules (avoids late-binding closures).
            return F.linear(x, lora(), dense.bias)

        dense.forward = fwd
        loras[idx] = lora
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenise a batch of sentence pairs.

    Each ``sentence1``/``sentence2`` pair is truncated and padded to exactly
    ``max_length`` tokens. Returns whatever ``tok`` returns for the pair.
    """
    first_sentences = examples['sentence1']
    second_sentences = examples['sentence2']
    return tok(first_sentences, second_sentences,
               truncation=True, padding='max_length', max_length=max_length)

def evaluate_model(model, dl, device):
    """Return the Pearson correlation between model scores and labels on ``dl``.

    Puts the model in eval mode and leaves it there; the training loops call
    ``model.train()`` again. Each batch is run without gradients, and the
    ``[B, 1]`` regression logits are flattened to one score per example.

    Pearson r is computed directly with NumPy (``np.corrcoef``) instead of
    reloading ``evaluate.load("pearsonr")`` on every call.

    Returns:
        float Pearson r, or NaN when ``dl`` yields no batches.
    """
    model.eval()
    preds, labs = [], []
    with torch.no_grad():
        for b in dl:
            ids = b['input_ids'].to(device)
            mask = b['attention_mask'].to(device)
            labs.append(b['labels'].float().cpu().numpy())
            out = model(input_ids=ids, attention_mask=mask)
            # Regression head: output is [B, 1]
            preds.append(out.logits.view(-1).float().cpu().numpy())
    if not labs:
        return float('nan')
    preds_arr = np.concatenate(preds).astype(np.float64)
    labs_arr = np.concatenate(labs).astype(np.float64)
    return float(np.corrcoef(preds_arr, labs_arr)[0, 1])

# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base as an STS-B regressor while tracking KE.

    ``num_labels=1`` gives a single-output head. Labels are cast to float and
    reshaped to ``(B, 1)`` to match it. Hooks from ``register_ke_hooks`` capture
    the input of every encoder layer's ``intermediate.dense``. After each
    optimiser step the captured tensors are reduced with
    ``compute_batch_knowledge_entropy``, and the per-layer values are averaged
    over the epoch. Training uses AMP and gradient checkpointing.

    Args:
        train_loader: DataLoader yielding ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: validation DataLoader with the same keys.
        device: torch device to train on.
        num_epochs: epochs to train. It sizes both the loop and the LR schedule
            (default 6, as before).

    Returns:
        ``(model, last_ke)``: ``last_ke`` maps layer index to the mean KE of the
        final epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=1  # num_labels=1 for regression!
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None
    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))
                # Backward outside autocast, as the PyTorch AMP docs recommend.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # batch-level KE (also clears the activation buffers)
                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx]   += v
                    ke_counts[idx] += 1

            # epoch-level KE
            epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke
    finally:
        # Always detach the hooks, even on error. Doing it before evaluation
        # also stops dev-set forward passes from refilling the buffers.
        remove_hooks(hooks)

    pearson = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune STS-B Pearson: {pearson:.4f}")
    return model, last_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):
    """Stage 2: prune the FFN of the 4 highest-KE layers, then fine-tune the rest.

    Uses ``prune_ke_layers`` on ``ke_scores`` and then trains all remaining
    parameters with Adam (lr 1e-5, linear decay, no AMP) on float ``(B, 1)``
    regression labels. STS-B Pearson r is printed after every epoch.

    Args:
        model: model returned by stage 1.
        train_loader, dev_loader, device: as in stage 1.
        ke_scores: dict ``layer_idx -> KE`` from stage 1.
        num_epochs: epochs to train (default 5, as the loop ran before).

    Returns:
        the pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the schedule by the same epoch count the loop uses. It was sized for
    # 3 epochs while the loop ran 5, so the linear decay reached lr=0 after
    # epoch 3 and the last two epochs made no updates.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))
            out.loss.backward()
            opt.step()
            sched.step()
        pearson = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: freeze the (pruned) backbone and train only LoRA adapters + head.

    Adapters are attached via ``apply_lora_to_all_layers``. Every
    ``model.roberta`` parameter is frozen. The regression head and the
    adapters' ``A``/``B`` matrices are trained with Adam under AMP on float
    ``(B, 1)`` labels. STS-B Pearson r is printed after every epoch.

    Args:
        model: model returned by stage 2.
        train_loader, dev_loader, device: as in the earlier stages.
        r, alpha: LoRA rank and scaling numerator.
        num_epochs: epochs to train (default 6, as before).
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for l in loras.values() for p in (l.A, l.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 turned on gradient checkpointing. HF's checkpointing is reentrant
    # by default, and with the embeddings now frozen no checkpointed layer
    # receives an input that requires grad. Its outputs then don't either, so
    # the LoRA matrices inside the encoder would get no gradient and only the
    # head would train. Making the embedding output require grad restores the
    # gradient path.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))
            # Backward outside autocast, as the PyTorch AMP docs recommend.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        pearson = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage STS-B pipeline.

    Seeds Python, NumPy and PyTorch with 42, tokenises the shuffled STS-B
    train split and the validation split, builds the data loaders, then runs
    full fine-tuning, KE-guided pruning + fine-tuning and LoRA fine-tuning.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = load_dataset("glue", "stsb", split="train").shuffle(seed)
    dev_ds = load_dataset("glue", "stsb", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def tokenize_split(split_ds):
        # Tokenise sentence pairs, drop raw text columns, rename the target to "labels".
        encoded = split_ds.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2", "idx"],
        )
        return encoded.rename_column("label", "labels")

    train, dev = tokenize_split(train_ds), tokenize_split(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Hard-coded per-layer Knowledge Entropy values for each epoch (outer key: epoch
# number, inner key: 0-based encoder layer index) -- presumably copied from the
# Stage-1 "[Epoch N] approx Knowledge Entropy" printouts of the STS-B run.
ke_epochs = {
    1: {0: 3.1216782207117624, 1: 3.205926341373831, 2: 3.1831439436725515, 3: 3.1425496576890164, 4: 3.108628351930451,
        5: 3.0245226891879744, 6: 3.0113194900693085, 7: 2.96199728318481, 8: 2.9216685882032496, 9: 2.9011312645234386,
        10: 2.7949637216056007, 11: 2.8036840522405337},
    2: {0: 3.122385612946724, 1: 3.2053615688780255, 2: 3.1827276502431516, 3: 3.14758318588034, 4: 3.128111186047423,
        5: 3.0808556003597083, 6: 3.0710045684527953, 7: 3.012045979334018, 8: 2.9473394376014634, 9: 2.890920449032737,
        10: 2.706535710744632, 11: 2.694957201271959},
    3: {0: 3.1233545473785824, 1: 3.2060421793119303, 2: 3.180608749721246, 3: 3.149725641096749, 4: 3.1337831318792944,
        5: 3.087708413518022, 6: 3.0823230342175267, 7: 3.0231315847563978, 8: 2.9449381267908383, 9: 2.8736855148108513,
        10: 2.6799341229636413, 11: 2.6526868296929624},
    4: {0: 3.122849219697573, 1: 3.208289318190828, 2: 3.1786043458256836, 3: 3.152634641225547, 4: 3.140414782127518,
        5: 3.0983164837695294, 6: 3.087455921942402, 7: 3.0256239744486164, 8: 2.950759364434509, 9: 2.883653296548899,
        10: 2.686351388817205, 11: 2.6519999305130875},
    5: {0: 3.1224169508969832, 1: 3.2059260482416696, 2: 3.17878786678606, 3: 3.156504187365069, 4: 3.1431414029860196,
        5: 3.0961787657545403, 6: 3.0855182536951524, 7: 3.0326405927766844, 8: 2.9571537298353725, 9: 2.8834909653298877,
        10: 2.6865991667348252, 11: 2.644667492788259},
    6: {0: 3.122367330651953, 1: 3.20639664672512, 2: 3.179138040343644, 3: 3.157353766604491, 4: 3.1413243549755454,
        5: 3.0952708057965954, 6: 3.0864931734612986, 7: 3.033865492267635, 8: 2.9587372055637324, 9: 2.8831359905725726,
        10: 2.6869357043413524, 11: 2.6412745341140473}
}

# One curve per epoch; layer indices are shown 1-based (1-12) on the x axis.
plt.figure(figsize=(10, 6))
for epoch, ke in ke_epochs.items():
    plt.plot(
        [layer_idx + 1 for layer_idx in ke],
        [ke[layer_idx] for layer_idx in ke],
        marker='o',
        label=f'Epoch {epoch}',
    )

plt.xlabel('Layer Index', fontsize=16)
plt.ylabel('Knowledge Entropy', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()
