In [None]:
!pip uninstall -y datasets
!pip install datasets==2.18.0
!pip install evaluate

In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach a forward hook to every encoder layer's ``intermediate.dense``.

    Each hook stores the detached first positional input of that dense
    module in ``activations[layer_index]['pre_act']``.  Note that, despite
    the key name, the stored tensor is what is fed *into* the dense layer,
    not the layer's output.

    Args:
        model: Object exposing ``model.roberta.encoder.layer`` (an iterable
            of layers, each with ``intermediate.dense``).

    Returns:
        ``(hooks, activations)``: the list of hook handles and the dict the
        hooks write into (one ``{'pre_act': tensor-or-None}`` per layer).
    """
    activations = {}
    hooks = []

    def make_capture(layer_idx):
        # A factory gives every hook its own layer index.
        def capture(module, input, output):
            activations[layer_idx]['pre_act'] = input[0].detach()
        return capture

    for layer_idx, encoder_layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = encoder_layer.intermediate.dense.register_forward_hook(
            make_capture(layer_idx)
        )
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8, dim=1):
    """Turn the captured per-layer tensors into scalar entropy scores.

    For every layer whose buffer holds a tensor, the tensor is passed
    through ``activation_fn``, normalised to sum to one along ``dim``, and
    the Shannon entropy along ``dim`` is averaged over all remaining
    positions.  The buffer is then reset to ``None`` so the tensor can be
    freed and is not counted twice.

    Args:
        activations: Dict ``{layer_idx: {'pre_act': tensor-or-None}}`` as
            filled by the forward hooks.
        activation_fn: Non-negative activation applied before normalising.
        eps: Small constant guarding the division and the logarithm.
        dim: Axis that is normalised and reduced.  Defaults to ``1``, the
            original hard-coded value.
            NOTE(review): if the captured tensors are (batch, seq, hidden),
            ``dim=1`` is the token axis, not the feature axis — confirm
            this is the intended definition.

    Returns:
        Dict ``{layer_idx: float}`` for the layers that had a tensor.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        captured = buf['pre_act']
        if captured is None:
            continue
        # Compute in float32: in float16 the default eps (1e-8) rounds to
        # zero, so all-zero slices would give 0/0 and 0*log(0) = NaN.
        act = activation_fn(captured.float())
        probs = act / (act.sum(dim=dim, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=dim).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Call ``remove()`` on every hook handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()


# ========================================================
# 3) Pruning Utilities
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's output sub-module that skips it.

    ``forward`` ignores ``hidden_states`` and hands back ``input_tensor``
    unchanged (``None`` if the caller did not supply one).
    """

    def forward(self, hidden_states, input_tensor=None):
        # ``hidden_states`` is accepted only to keep the call signature.
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the feed-forward part of the ``num_prune`` highest-scoring layers.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity``
    and ``output`` becomes ``SkipFF``.  The model is modified in place.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        ke_scores: Dict ``{layer_idx: score}``.
        num_prune: How many of the top-scoring layers to prune.

    Returns:
        The pruned layer indices, highest score first.
    """
    # Stable descending sort: ties keep the dict's insertion order.
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    encoder_layers = model.roberta.encoder.layer
    for layer_idx in prune_idxs:
        target = encoder_layers[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen weight matrix.

    The effective weight is ``W0 + (alpha / r) * (B @ A)``.  ``W0`` is kept
    as a buffer (a detached copy of the given tensor); ``A`` starts at zero
    so the update is zero right after construction.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the full effective weight matrix (same shape as ``W0``)."""
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap each layer's ``output.dense`` with a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped.  For the
    others a ``LoRA`` is built from the current dense weight and the dense
    module's ``forward`` is replaced so it multiplies by the LoRA weight
    (the original bias is still used).  The LoRA modules are only returned
    to the caller; they are not attached to ``model`` as sub-modules.

    Returns:
        Dict ``{layer_idx: LoRA}`` for every wrapped layer.
    """
    def make_forward(layer, lora):
        def patched_forward(x):
            return F.linear(x, lora(), layer.output.dense.bias)
        return patched_forward

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[idx] = adapter
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tokenizer, max_length=128):
    """Tokenize premise/hypothesis pairs, truncated and padded to ``max_length``."""
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": max_length,
    }
    premises = examples["premise"]
    hypotheses = examples["hypothesis"]
    return tokenizer(premises, hypotheses, **tokenizer_options)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over all batches of ``dl``.

    Puts the model in eval mode (it is not switched back afterwards), runs
    it without gradients, and scores arg-max predictions against the
    ``labels`` of each batch.  The ``accuracy`` metric is loaded through
    ``evaluate.load`` on every call.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    scores = metric.compute(predictions=predictions, references=references)
    return scores["accuracy"]


# ========================================================
# 6) Training Stages
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune ``roberta-base`` (3 labels) and estimate KE.

    Trains with Adam, a linear-decay schedule without warmup, mixed
    precision and gradient checkpointing.  After every optimisation step
    the tensors captured by the KE hooks are reduced to per-layer entropy
    scores, which are averaged over each epoch.

    Args:
        train_loader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels``.
        dev_loader: Batches used for the final accuracy report.
        device: Device the model and batches are moved to.
        num_epochs: Number of epochs; also sizes the LR schedule so the
            two cannot drift apart.

    Returns:
        ``(model, last_ke)``: the fine-tuned model and the per-layer KE
        averages of the last epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=3
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Only the forward pass belongs under autocast; PyTorch
                # recommends running backward outside the context.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx] += v
                    ke_counts[idx] += 1

            epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke
    finally:
        # Always detach the hooks, even if training raised, so the model
        # does not keep capturing activations afterwards.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MNLI Acc: {acc:.4f}")
    return model, last_ke


def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores,
                         num_epochs=5):
    """Stage 2: prune the four highest-KE layers, then fine-tune the rest.

    Args:
        model: Model returned by Stage 1; modified in place.
        train_loader: Training batches.
        dev_loader: Batches used for the per-epoch accuracy report.
        device: Device batches are moved to.
        ke_scores: Dict ``{layer_idx: score}`` used to pick the layers.
        num_epochs: Number of fine-tuning epochs.  The LR schedule is sized
            from the same value; previously it covered only 3 epochs while
            the loop ran 5, so the linear decay hit zero and the last two
            epochs could not update the model.

    Returns:
        The pruned and fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] MNLI Acc: {acc:.4f}")
    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: train only the classifier head and the LoRA adapters.

    All ``model.roberta`` parameters are frozen; the classifier parameters
    and each adapter's ``A``/``B`` matrices are optimised with Adam, a
    linear-decay schedule and mixed precision.

    Args:
        model: Model from the previous stage; modified in place.
        train_loader: Training batches.
        dev_loader: Batches used for the per-epoch accuracy report.
        device: Device batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator (effective scale is ``alpha / r``).
        num_epochs: Number of epochs; also sizes the LR schedule.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for lora in loras.values():
        lora.A.requires_grad = True
        lora.B.requires_grad = True

    # Gradient checkpointing is still enabled from Stage 1.  With every
    # encoder parameter frozen, the embedding output no longer requires
    # grad, and reentrant checkpointing then cuts the graph through the
    # encoder, so the LoRA matrices would never receive gradients.  Making
    # the embedding output require grad restores the path.
    # NOTE(review): depends on the installed transformers defaulting to
    # reentrant checkpointing — confirm; the call is harmless otherwise.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    trainable = list(model.classifier.parameters())
    trainable += [p for lora in loras.values() for p in (lora.A, lora.B)]
    opt = torch.optim.Adam(trainable, lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward runs outside autocast, as PyTorch recommends.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MNLI Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage pipeline on a 10,000-example MNLI training subset."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    dataset = load_dataset("glue", "mnli")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize the sentence pair, drop the raw text columns, and expose
        # the target under "labels", the key the training loops read.
        tokenized = split.map(lambda ex: preprocess_function(ex, tokenizer),
                              batched=True,
                              remove_columns=["premise", "hypothesis"])
        return tokenized.rename_column("label", "labels")

    train_ds = dataset["train"].shuffle(seed).select(range(10000))  # smaller for speed
    dev_ds = dataset["validation_matched"]
    train = encode(train_ds)
    dev = encode(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1 -> Stage 2 -> Stage 3, each working on the same model object.
    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Per-epoch KE values, keyed by epoch number and then by 0-based layer index
# (presumably copied from the Stage 1 training log — verify against the run).
ke_epochs = {
    1: {0: 3.7403, 1: 3.8396, 2: 3.8184, 3: 3.7880, 4: 3.7816, 5: 3.7467, 6: 3.7648, 7: 3.7296, 8: 3.7192, 9: 3.6667, 10: 3.5540, 11: 3.5193},
    2: {0: 3.7418, 1: 3.8355, 2: 3.8190, 3: 3.7905, 4: 3.7855, 5: 3.7613, 6: 3.7775, 7: 3.7418, 8: 3.7250, 9: 3.6640, 10: 3.5458, 11: 3.5028},
    3: {0: 3.7434, 1: 3.8356, 2: 3.8193, 3: 3.7932, 4: 3.7921, 5: 3.7653, 6: 3.7784, 7: 3.7541, 8: 3.7352, 9: 3.6658, 10: 3.5442, 11: 3.5088},
    4: {0: 3.7441, 1: 3.8372, 2: 3.8191, 3: 3.7979, 4: 3.7927, 5: 3.7678, 6: 3.7775, 7: 3.7542, 8: 3.7398, 9: 3.6725, 10: 3.5492, 11: 3.5134},
    5: {0: 3.7451, 1: 3.8349, 2: 3.8234, 3: 3.8037, 4: 3.7976, 5: 3.7695, 6: 3.7790, 7: 3.7539, 8: 3.7457, 9: 3.6797, 10: 3.5499, 11: 3.5118},
    6: {0: 3.7452, 1: 3.8372, 2: 3.8257, 3: 3.8062, 4: 3.8008, 5: 3.7704, 6: 3.7808, 7: 3.7537, 8: 3.7452, 9: 3.6729, 10: 3.5400, 11: 3.5013}
}

# One line per epoch, KE on the y-axis against the 1-based layer number.
plt.figure(figsize=(10, 6))
for epoch_number, layer_scores in ke_epochs.items():
    layer_numbers = [layer_idx + 1 for layer_idx in layer_scores]  # 0-based -> 1–12
    ke_values = [layer_scores[layer_idx] for layer_idx in layer_scores]
    plt.plot(layer_numbers, ke_values, marker='o', label=f"Epoch {epoch_number}")

#plt.title("Knowledge Entropy vs Layer Index", fontsize=16)
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach a forward hook to every encoder layer's ``intermediate.dense``.

    Each hook stores the detached first positional input of that dense
    module in ``activations[layer_index]['pre_act']``.  Note that, despite
    the key name, the stored tensor is what is fed *into* the dense layer,
    not the layer's output.

    Args:
        model: Object exposing ``model.roberta.encoder.layer`` (an iterable
            of layers, each with ``intermediate.dense``).

    Returns:
        ``(hooks, activations)``: the list of hook handles and the dict the
        hooks write into (one ``{'pre_act': tensor-or-None}`` per layer).
    """
    activations = {}
    hooks = []

    def make_capture(layer_idx):
        # A factory gives every hook its own layer index.
        def capture(module, input, output):
            activations[layer_idx]['pre_act'] = input[0].detach()
        return capture

    for layer_idx, encoder_layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = encoder_layer.intermediate.dense.register_forward_hook(
            make_capture(layer_idx)
        )
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8, dim=1):
    """Turn the captured per-layer tensors into scalar entropy scores.

    For every layer whose buffer holds a tensor, the tensor is passed
    through ``activation_fn``, normalised to sum to one along ``dim``, and
    the Shannon entropy along ``dim`` is averaged over all remaining
    positions.  The buffer is then reset to ``None`` so the tensor can be
    freed and is not counted twice.

    Args:
        activations: Dict ``{layer_idx: {'pre_act': tensor-or-None}}`` as
            filled by the forward hooks.
        activation_fn: Non-negative activation applied before normalising.
        eps: Small constant guarding the division and the logarithm.
        dim: Axis that is normalised and reduced.  Defaults to ``1``, the
            original hard-coded value.
            NOTE(review): if the captured tensors are (batch, seq, hidden),
            ``dim=1`` is the token axis, not the feature axis — confirm
            this is the intended definition.

    Returns:
        Dict ``{layer_idx: float}`` for the layers that had a tensor.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        captured = buf['pre_act']
        if captured is None:
            continue
        # Compute in float32: in float16 the default eps (1e-8) rounds to
        # zero, so all-zero slices would give 0/0 and 0*log(0) = NaN.
        act = activation_fn(captured.float())
        probs = act / (act.sum(dim=dim, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=dim).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Call ``remove()`` on every hook handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's output sub-module that skips it.

    ``forward`` ignores ``hidden_states`` and hands back ``input_tensor``
    unchanged (``None`` if the caller did not supply one).
    """

    def forward(self, hidden_states, input_tensor=None):
        # ``hidden_states`` is accepted only to keep the call signature.
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the feed-forward part of the ``num_prune`` highest-scoring layers.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity``
    and ``output`` becomes ``SkipFF``.  The model is modified in place.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        ke_scores: Dict ``{layer_idx: score}``.
        num_prune: How many of the top-scoring layers to prune.

    Returns:
        The pruned layer indices, highest score first.
    """
    # Stable descending sort: ties keep the dict's insertion order.
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    encoder_layers = model.roberta.encoder.layer
    for layer_idx in prune_idxs:
        target = encoder_layers[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen weight matrix.

    The effective weight is ``W0 + (alpha / r) * (B @ A)``.  ``W0`` is kept
    as a buffer (a detached copy of the given tensor); ``A`` starts at zero
    so the update is zero right after construction.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the full effective weight matrix (same shape as ``W0``)."""
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap each layer's ``output.dense`` with a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped.  For the
    others a ``LoRA`` is built from the current dense weight and the dense
    module's ``forward`` is replaced so it multiplies by the LoRA weight
    (the original bias is still used).  The LoRA modules are only returned
    to the caller; they are not attached to ``model`` as sub-modules.

    Returns:
        Dict ``{layer_idx: LoRA}`` for every wrapped layer.
    """
    def make_forward(layer, lora):
        def patched_forward(x):
            return F.linear(x, lora(), layer.output.dense.bias)
        return patched_forward

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[idx] = adapter
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize sentence1/sentence2 pairs, truncated and padded to ``max_length``."""
    tokenizer_options = {
        "truncation": True,
        "padding": 'max_length',
        "max_length": max_length,
    }
    first_sentences = examples['sentence1']
    second_sentences = examples['sentence2']
    return tok(first_sentences, second_sentences, **tokenizer_options)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over all batches of ``dl``.

    Puts the model in eval mode (it is not switched back afterwards), runs
    it without gradients, and scores arg-max predictions against the
    ``labels`` of each batch.  The ``accuracy`` metric is loaded through
    ``evaluate.load`` on every call.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    scores = metric.compute(predictions=predictions, references=references)
    return scores["accuracy"]


# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune ``roberta-base`` (2 labels) and estimate KE.

    Trains with Adam, a linear-decay schedule without warmup, mixed
    precision and gradient checkpointing.  After every optimisation step
    the tensors captured by the KE hooks are reduced to per-layer entropy
    scores, which are averaged over each epoch.

    Args:
        train_loader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels``.
        dev_loader: Batches used for the final accuracy report.
        device: Device the model and batches are moved to.
        num_epochs: Number of epochs; also sizes the LR schedule so the
            two cannot drift apart.

    Returns:
        ``(model, last_ke)``: the fine-tuned model and the per-layer KE
        averages of the last epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Only the forward pass belongs under autocast; PyTorch
                # recommends running backward outside the context.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # batch-level KE
                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx] += v
                    ke_counts[idx] += 1

            # epoch-level KE
            epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke
    finally:
        # Always detach the hooks, even if training raised, so the model
        # does not keep capturing activations afterwards.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MRPC Acc: {acc:.4f}")

    return model, last_ke


def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores,
                         num_epochs=5):
    """Stage 2: prune the four highest-KE layers, then fine-tune the rest.

    Args:
        model: Model returned by Stage 1; modified in place.
        train_loader: Training batches.
        dev_loader: Batches used for the per-epoch accuracy report.
        device: Device batches are moved to.
        ke_scores: Dict ``{layer_idx: score}`` used to pick the layers.
        num_epochs: Number of fine-tuning epochs.  The LR schedule is sized
            from the same value; previously it covered only 3 epochs while
            the loop ran 5, so the linear decay hit zero and the last two
            epochs could not update the model.

    Returns:
        The pruned and fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: train only the classifier head and the LoRA adapters.

    All ``model.roberta`` parameters are frozen; the classifier parameters
    and each adapter's ``A``/``B`` matrices are optimised with Adam, a
    linear-decay schedule and mixed precision.

    Args:
        model: Model from the previous stage; modified in place.
        train_loader: Training batches.
        dev_loader: Batches used for the per-epoch accuracy report.
        device: Device batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator (effective scale is ``alpha / r``).
        num_epochs: Number of epochs; also sizes the LR schedule.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for lora in loras.values():
        lora.A.requires_grad = True
        lora.B.requires_grad = True

    # Gradient checkpointing is still enabled from Stage 1.  With every
    # encoder parameter frozen, the embedding output no longer requires
    # grad, and reentrant checkpointing then cuts the graph through the
    # encoder, so the LoRA matrices would never receive gradients.  Making
    # the embedding output require grad restores the path.
    # NOTE(review): depends on the installed transformers defaulting to
    # reentrant checkpointing — confirm; the call is harmless otherwise.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    trainable = list(model.classifier.parameters())
    trainable += [p for lora in loras.values() for p in (lora.A, lora.B)]
    opt = torch.optim.Adam(trainable, lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward runs outside autocast, as PyTorch recommends.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage pipeline on a 1,000-example MRPC training subset."""
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    train_ds = load_dataset("glue", "mrpc", split="train").shuffle(seed).select(range(1000))
    dev_ds = load_dataset("glue", "mrpc", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize the sentence pair, drop the raw text and index columns,
        # and expose the target under "labels", the key the loops read.
        tokenized = split.map(lambda ex: preprocess_function(ex, tokenizer),
                              batched=True,
                              remove_columns=["sentence1","sentence2","idx"])
        return tokenized.rename_column("label","labels")

    train = encode(train_ds)
    dev = encode(dev_ds)

    collator = DataCollatorWithPadding(tokenizer,
                                       padding="max_length",
                                       max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True,
                              collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False,
                            collate_fn=collator)

    # Stage 1 -> Stage 2 -> Stage 3, each working on the same model object.
    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Per-epoch Knowledge Entropy values: one dict per epoch, keyed by 0-based
# layer index (presumably copied from the Stage 1 training log — verify).
ke_epochs = [
    {0: 3.239791645050049, 1: 3.2986808376312258, 2: 3.253356773376465, 3: 3.228304588317871, 4: 3.204816005706787, 5: 3.1601281089782716, 6: 3.167900978088379, 7: 3.14154211807251, 8: 3.136232151031494, 9: 3.108181224822998, 10: 3.070255741119385, 11: 3.150283061981201},
    {0: 3.2390173606872557, 1: 3.2963816051483152, 2: 3.2457099170684813, 3: 3.224952474594116, 4: 3.199368993759155, 5: 3.1539372539520265, 6: 3.149820531845093, 7: 3.1081827564239504, 8: 3.0931982421875, 9: 3.0528202323913574, 10: 3.028363660812378, 11: 3.050001382827759},
    {0: 3.239009340286255, 1: 3.2983984470367433, 2: 3.2463859825134276, 3: 3.2241379623413087, 4: 3.1976428546905518, 5: 3.1447392463684083, 6: 3.1279000110626223, 7: 3.0747772026062012, 8: 3.0223362197875976, 9: 2.9473570098876953, 10: 2.87524440574646, 11: 2.8674370727539062},
    {0: 3.238827512741089, 1: 3.2974642314910887, 2: 3.2431242713928223, 3: 3.2216903343200682, 4: 3.199512315750122, 5: 3.1563696422576903, 6: 3.134328540802002, 7: 3.0738258819580078, 8: 2.99842280960083, 9: 2.912938259124756, 10: 2.8102419834136962, 11: 2.771777828216553},
    {0: 3.2383913173675536, 1: 3.296243072509766, 2: 3.242586082458496, 3: 3.221456657409668, 4: 3.2004083824157714, 5: 3.152312059402466, 6: 3.1258499088287355, 7: 3.064895736694336, 8: 2.980252359390259, 9: 2.860691904067993, 10: 2.749246000289917, 11: 2.710493507385254},
    {0: 3.238114908218384, 1: 3.2962895565032957, 2: 3.243076961517334, 3: 3.221760944366455, 4: 3.2005157108306883, 5: 3.1507289390563966, 6: 3.123993368148804, 7: 3.0618533477783205, 8: 2.9700901947021485, 9: 2.843407918930054, 10: 2.727255252838135, 11: 2.6856814765930177}
]

# One line per epoch, KE on the y-axis against the 1-based layer number.
plt.figure(figsize=(10, 6))
for epoch_number, layer_scores in enumerate(ke_epochs, start=1):
    layer_numbers = [layer_idx + 1 for layer_idx in layer_scores]  # 0-based -> 1–12
    ke_values = [layer_scores[layer_idx] for layer_idx in layer_scores]
    plt.plot(layer_numbers, ke_values, label=f"Epoch {epoch_number}", marker='o')

plt.xlabel("Layer", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
#plt.title("Knowledge Entropy vs Layers", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach a forward hook to every encoder layer's ``intermediate.dense``.

    Each hook stores the detached first positional input of that dense
    module in ``activations[layer_index]['pre_act']``.  Note that, despite
    the key name, the stored tensor is what is fed *into* the dense layer,
    not the layer's output.

    Args:
        model: Object exposing ``model.roberta.encoder.layer`` (an iterable
            of layers, each with ``intermediate.dense``).

    Returns:
        ``(hooks, activations)``: the list of hook handles and the dict the
        hooks write into (one ``{'pre_act': tensor-or-None}`` per layer).
    """
    activations = {}
    hooks = []

    def make_capture(layer_idx):
        # A factory gives every hook its own layer index.
        def capture(module, input, output):
            activations[layer_idx]['pre_act'] = input[0].detach()
        return capture

    for layer_idx, encoder_layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = encoder_layer.intermediate.dense.register_forward_hook(
            make_capture(layer_idx)
        )
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8, dim=1):
    """Turn the captured per-layer tensors into scalar entropy scores.

    For every layer whose buffer holds a tensor, the tensor is passed
    through ``activation_fn``, normalised to sum to one along ``dim``, and
    the Shannon entropy along ``dim`` is averaged over all remaining
    positions.  The buffer is then reset to ``None`` so the tensor can be
    freed and is not counted twice.

    Args:
        activations: Dict ``{layer_idx: {'pre_act': tensor-or-None}}`` as
            filled by the forward hooks.
        activation_fn: Non-negative activation applied before normalising.
        eps: Small constant guarding the division and the logarithm.
        dim: Axis that is normalised and reduced.  Defaults to ``1``, the
            original hard-coded value.
            NOTE(review): if the captured tensors are (batch, seq, hidden),
            ``dim=1`` is the token axis, not the feature axis — confirm
            this is the intended definition.

    Returns:
        Dict ``{layer_idx: float}`` for the layers that had a tensor.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        captured = buf['pre_act']
        if captured is None:
            continue
        # Compute in float32: in float16 the default eps (1e-8) rounds to
        # zero, so all-zero slices would give 0/0 and 0*log(0) = NaN.
        act = activation_fn(captured.float())
        probs = act / (act.sum(dim=dim, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=dim).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Call ``remove()`` on every hook handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's output sub-module that skips it.

    ``forward`` ignores ``hidden_states`` and hands back ``input_tensor``
    unchanged (``None`` if the caller did not supply one).
    """

    def forward(self, hidden_states, input_tensor=None):
        # ``hidden_states`` is accepted only to keep the call signature.
        return input_tensor

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the feed-forward part of the ``num_prune`` highest-scoring layers.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity``
    and ``output`` becomes ``SkipFF``.  The model is modified in place.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        ke_scores: Dict ``{layer_idx: score}``.
        num_prune: How many of the top-scoring layers to prune.

    Returns:
        The pruned layer indices, highest score first.
    """
    # Stable descending sort: ties keep the dict's insertion order.
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    encoder_layers = model.roberta.encoder.layer
    for layer_idx in prune_idxs:
        target = encoder_layers[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen weight matrix.

    The effective weight is ``W0 + (alpha / r) * (B @ A)``.  ``W0`` is kept
    as a buffer (a detached copy of the given tensor); ``A`` starts at zero
    so the update is zero right after construction.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the full effective weight matrix (same shape as ``W0``)."""
        low_rank_delta = self.B @ self.A
        return self.W0 + self.scaling * low_rank_delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap each layer's ``output.dense`` with a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are skipped.  For the
    others a ``LoRA`` is built from the current dense weight and the dense
    module's ``forward`` is replaced so it multiplies by the LoRA weight
    (the original bias is still used).  The LoRA modules are only returned
    to the caller; they are not attached to ``model`` as sub-modules.

    Returns:
        Dict ``{layer_idx: LoRA}`` for every wrapped layer.
    """
    def make_forward(layer, lora):
        def patched_forward(x):
            return F.linear(x, lora(), layer.output.dense.bias)
        return patched_forward

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[idx] = adapter
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ``sentence`` column, truncated and padded to ``max_length``."""
    tokenizer_options = {
        "truncation": True,
        "padding": 'max_length',
        "max_length": max_length,
    }
    sentences = examples['sentence']
    return tok(sentences, **tokenizer_options)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model`` over all batches of ``dl``.

    Puts the model in eval mode (it is not switched back afterwards), runs
    it without gradients, and scores arg-max predictions against the
    ``labels`` of each batch.  The ``accuracy`` metric is loaded through
    ``evaluate.load`` on every call.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    scores = metric.compute(predictions=predictions, references=references)
    return scores["accuracy"]


# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune ``roberta-base`` (2 labels) and estimate KE.

    Trains with Adam, a linear-decay schedule without warmup, mixed
    precision and gradient checkpointing.  After every optimisation step
    the tensors captured by the KE hooks are reduced to per-layer entropy
    scores, which are averaged over each epoch.

    Args:
        train_loader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels``.
        dev_loader: Batches used for the final accuracy report.
        device: Device the model and batches are moved to.
        num_epochs: Number of epochs; also sizes the LR schedule so the
            two cannot drift apart.

    Returns:
        ``(model, last_ke)``: the fine-tuned model and the per-layer KE
        averages of the last epoch (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Only the forward pass belongs under autocast; PyTorch
                # recommends running backward outside the context.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx] += v
                    ke_counts[idx] += 1

            epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke
    finally:
        # Always detach the hooks, even if training raised, so the model
        # does not keep capturing activations afterwards.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune SST-2 Acc: {acc:.4f}")

    return model, last_ke


def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores,
                         num_epochs=5):
    """Stage 2: prune the four highest-KE layers, then fine-tune the rest.

    Args:
        model: Model returned by Stage 1; modified in place.
        train_loader: Training batches.
        dev_loader: Batches used for the per-epoch accuracy report.
        device: Device batches are moved to.
        ke_scores: Dict ``{layer_idx: score}`` used to pick the layers.
        num_epochs: Number of fine-tuning epochs.  The LR schedule is sized
            from the same value; previously it covered only 3 epochs while
            the loop ran 5, so the linear decay hit zero and the last two
            epochs could not update the model.

    Returns:
        The pruned and fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] SST-2 Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the encoder and train only LoRA factors plus the classifier.

    Args:
        model: classifier whose ``roberta`` encoder gets LoRA-patched in place.
        train_loader: iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: batches handed to ``evaluate_model`` after every epoch.
        device: device the batch tensors are moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to ``apply_lora_to_all_layers``.
        num_epochs: passes over ``train_loader``; also sizes the LR schedule
            (default 6, the previously hard-coded value).

    Returns:
        None. ``model`` is trained in place and accuracy is printed per epoch.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for lora in loras.values():
        lora.A.requires_grad = True
        lora.B.requires_grad = True

    # With the whole encoder frozen, no tensor entering the encoder layers
    # requires grad. If the incoming model has gradient checkpointing enabled
    # in its re-entrant form (historically the transformers default -- verify
    # for the installed version), checkpointed layers then produce outputs
    # detached from autograd and the LoRA factors never receive gradients.
    # Making the embedding output require grad keeps the graph connected.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for lora in loras.values() for p in (lora.A, lora.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # Autocast wraps the forward pass and loss only; PyTorch advises
            # against running backward() under autocast.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] SST-2 Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the SST-2 pipeline: full fine-tune, prune + fine-tune, then LoRA."""
    # Seed Python, NumPy and PyTorch with the same value before touching data.
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 5000 shuffled training examples; the full validation split for scoring.
    train_ds = load_dataset("glue", "sst2", split="train").shuffle(seed).select(range(5000))
    dev_ds = load_dataset("glue", "sst2", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def _tokenize(split):
        # Drop the raw text/id columns and expose the target as "labels".
        encoded = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["sentence", "idx"],
        )
        return encoded.rename_column("label", "labels")

    train = _tokenize(train_ds)
    dev = _tokenize(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Approximate knowledge entropy per encoder layer (12 values, in layer order),
# keyed by 1-based epoch number.
ke_by_epoch = {
    1: [2.9651833793640137, 3.089046128463745, 3.0924095928192137, 3.0681774433135987, 3.0340894985198976, 2.9792041194915773, 2.9694000019073488, 2.9185178718566895, 2.9200110771179197, 2.82025255355835, 2.667954217147827, 2.6612180709838866],
    2: [2.965893549346924, 3.085789897155762, 3.092639580154419, 3.078977996826172, 3.0649560813903807, 3.0321216072082517, 3.018477940368652, 2.9576354961395266, 2.911316421508789, 2.782052672576904, 2.6040575157165526, 2.5998096366882324],
    3: [2.9650946979522703, 3.084409861755371, 3.0894264568328857, 3.077181778717041, 3.066806471252441, 3.0330316158294677, 3.026228702163696, 2.9687501598358152, 2.918505425262451, 2.781764380264282, 2.6073847118377684, 2.5956354915618896],
    4: [2.9647007961273193, 3.0866590961456297, 3.088959024810791, 3.076308888244629, 3.0617130882263184, 3.0239853103637695, 3.0255811374664305, 2.9724936283111574, 2.91553058013916, 2.779527521133423, 2.608109250640869, 2.599828709793091],
    5: [2.9640205013275147, 3.089625425720215, 3.0891447685241697, 3.0789716785430907, 3.066036195373535, 3.0240665218353273, 3.0201424140930175, 2.9597047756195067, 2.888725273895264, 2.7510012699127198, 2.584951690673828, 2.5864554821014405],
    6: [2.963668116760254, 3.090019404220581, 3.089617932128906, 3.0806209617614746, 3.0694873622894288, 3.0306819889068604, 3.0209427997589113, 2.956924412536621, 2.8712899379730223, 2.7296667186737062, 2.57259878578186, 2.5810291679382322]
}

# One curve per epoch, layers shown on the x axis as 1..12.
plt.figure(figsize=(10, 6))
for epoch, ke_values in ke_by_epoch.items():
    plt.plot(range(1, 13), ke_values, label=f'Epoch {epoch}', marker='o')

# Title left disabled: plt.title("Knowledge Entropy vs. Layer", fontsize=16)
plt.grid(True)
plt.legend(fontsize=12)
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Approx. Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach a forward hook to every encoder layer's ``intermediate.dense``.

    Each hook stores the detached first positional input of that dense module
    under ``activations[layer_idx]['pre_act']``.
    NOTE(review): despite the key name, this is the tensor entering the dense
    layer, not its output -- confirm that is the intended quantity.

    Returns:
        ``(hooks, activations)``: the hook handles (for ``remove_hooks``) and
        the per-layer buffer dict, initialised with ``None`` entries.
    """
    activations = {}
    hooks = []

    def _capture_into(layer_idx):
        # Factory so each hook is bound to its own layer index.
        def _hook(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return _hook

    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = layer.intermediate.dense.register_forward_hook(_capture_into(layer_idx))
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Turn the captured activations into one entropy score per layer, then clear them.

    For every layer whose buffer holds a tensor, the tensor is passed through
    ``activation_fn``, normalised to sum to one along dim 1, and the Shannon
    entropy along dim 1 is averaged over the remaining dimensions.

    Args:
        activations: mapping ``idx -> {'pre_act': tensor or None}`` as built by
            ``register_ke_hooks``; consumed entries are reset to ``None``.
        activation_fn: non-negative activation applied before normalising.
        eps: stabiliser added to the normaliser and inside the logarithm.

    Returns:
        dict ``idx -> float`` for the layers that had a captured tensor.
    """
    # NOTE(review): for 3-D (batch, seq, hidden) inputs dim 1 is presumably the
    # sequence axis rather than the feature axis -- confirm this is intended.
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in float32: in half precision eps=1e-8 rounds to 0, so zero
        # probabilities would yield 0 * log(0) = NaN. A no-op for float32 input.
        act = activation_fn(pre_act.float())
        probs = act / (act.sum(dim=1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Call ``remove()`` on every hook handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in for a layer's ``output`` module that bypasses the feed-forward path.

    ``hidden_states`` is ignored and ``input_tensor`` is handed back untouched
    (``None`` when it is not supplied).
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the feed-forward block of the ``num_prune`` highest-entropy layers.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity`` and
    ``output`` becomes ``SkipFF``. The model is modified in place.

    Returns:
        The selected layer indices, highest entropy first.
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update ``W0 + (alpha / r) * B @ A`` of a fixed weight.

    ``W0`` is kept as a buffer holding a detached copy of the given matrix.
    ``B`` (rows x r) starts as small Gaussian noise and ``A`` (r x cols) as
    zeros, so the effective weight initially equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen_copy = W0.detach().clone()
        self.register_buffer("W0", frozen_copy)
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix."""
        low_rank = self.B @ self.A
        return self.W0 + self.scaling * low_rank

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Patch every encoder layer's ``output.dense`` to use a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are left alone. For the
    rest, a ``LoRA`` module is built from the current weight and the dense
    module's ``forward`` is replaced on the instance by a closure that applies
    ``F.linear`` with the adapter's effective weight and the original bias.

    Returns:
        dict ``layer index -> LoRA`` for the layers that were patched.
    """
    def _patched_forward(layer, adapter):
        def _forward(x):
            return F.linear(x, adapter(), layer.output.dense.bias)
        return _forward

    adapters = {}
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = _patched_forward(layer, adapter)
            adapters[layer_idx] = adapter
    return adapters

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize ``examples['sentence']``, truncating and padding to ``max_length``."""
    options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
    }
    sentences = examples['sentence']
    return tok(sentences, **options)


from sklearn.metrics import matthews_corrcoef

def evaluate_model(model, dl, device):
    """Score ``model`` on ``dl`` and return ``(accuracy, matthews_correlation)``.

    The model is switched to eval mode (and left there); predictions are the
    argmax over the logits, collected without gradient tracking.
    """
    model.eval()
    acc_metric = evaluate.load("accuracy")
    mcc_metric = evaluate.load("matthews_correlation")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    scored = {'predictions': predictions, 'references': references}
    acc = acc_metric.compute(**scored)["accuracy"]
    mcc = mcc_metric.compute(**scored)["matthews_correlation"]
    return acc, mcc



# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fine-tune all of roberta-base while tracking per-layer KE.

    Args:
        train_loader: iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: batches handed to ``evaluate_model`` after each epoch.
        device: device the model and batch tensors are moved to.
        num_epochs: passes over ``train_loader``; also sizes the LR schedule
            (default 6, the previously hard-coded value).

    Returns:
        ``(model, last_ke)`` where ``last_ke`` maps layer index to the mean
        per-batch knowledge entropy of the final epoch (``None`` if no epoch
        ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    # try/finally so the forward hooks never outlive this stage, even when
    # training is interrupted by an exception.
    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                # Autocast wraps the forward pass and loss only; PyTorch
                # advises against running backward() under autocast.
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Batch-level KE; this also clears the hook buffers.
                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx] += v
                    ke_counts[idx] += 1

            # Epoch-level KE: mean of the batch scores for each layer.
            epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke

            acc, mcc = evaluate_model(model, dev_loader, device)
            print(f"-> Full Finetune CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}")
    finally:
        remove_hooks(hooks)

    return model, last_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores,
                         num_epochs=5):
    """Stage 2: skip the highest-KE feed-forward blocks, then fine-tune the rest.

    Args:
        model: classifier from stage 1; modified in place by
            ``prune_ke_layers``.
        train_loader: iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: batches handed to ``evaluate_model`` after every epoch.
        device: device the batch tensors are moved to.
        ke_scores: mapping ``layer index -> knowledge entropy`` used to pick
            the layers to prune.
        num_epochs: passes over ``train_loader`` (default 5, the value that
            used to be hard-coded in the loop).

    Returns:
        The same ``model`` object after pruning and fine-tuning.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The schedule has to span exactly the optimizer steps that are taken.
    # It used to be sized for 3 epochs while the loop ran 5, so the linear
    # decay reached zero after epoch 3 and the remaining epochs trained with
    # a learning rate of 0.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()

        acc, mcc = evaluate_model(model, dev_loader, device)
        # Label fixed: this stage used to log itself as "Full Finetune".
        print(f"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the encoder and train only LoRA factors plus the classifier.

    Args:
        model: classifier whose ``roberta`` encoder gets LoRA-patched in place.
        train_loader: iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: batches handed to ``evaluate_model`` after every epoch.
        device: device the batch tensors are moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to ``apply_lora_to_all_layers``.
        num_epochs: passes over ``train_loader``; also sizes the LR schedule
            (default 6, the previously hard-coded value).

    Returns:
        None. ``model`` is trained in place and metrics are printed per epoch.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for lora in loras.values():
        lora.A.requires_grad = True
        lora.B.requires_grad = True

    # With the whole encoder frozen, no tensor entering the encoder layers
    # requires grad. If the incoming model has gradient checkpointing enabled
    # in its re-entrant form (historically the transformers default -- verify
    # for the installed version), checkpointed layers then produce outputs
    # detached from autograd and the LoRA factors never receive gradients.
    # Making the embedding output require grad keeps the graph connected.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for lora in loras.values() for p in (lora.A, lora.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # Autocast wraps the forward pass and loss only; PyTorch advises
            # against running backward() under autocast.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

        acc, mcc = evaluate_model(model, dev_loader, device)
        # Label fixed: this stage used to log itself as "Full Finetune".
        print(f"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the CoLA pipeline: full fine-tune, prune + fine-tune, then LoRA."""
    # Seed Python, NumPy and PyTorch with the same value before touching data.
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1000 shuffled training examples; the full validation split for scoring.
    train_ds = load_dataset("glue", "cola", split="train").shuffle(seed).select(range(1000))
    dev_ds = load_dataset("glue", "cola", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def _tokenize(split):
        # Drop the raw text/id columns and expose the target as "labels".
        encoded = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["sentence", "idx"],
        )
        return encoded.rename_column("label", "labels")

    train = _tokenize(train_ds)
    dev = _tokenize(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Knowledge entropy per epoch (six entries, epoch order); each dict maps the
# 0-based encoder layer index to that layer's entropy.
ke_values = [
    {0: 2.9448, 1: 3.0885, 2: 3.1181, 3: 3.0932, 4: 3.0441, 5: 2.9380, 6: 2.9606, 7: 2.9447, 8: 2.9687, 9: 2.9775, 10: 2.9016, 11: 2.9350},
    {0: 2.9439, 1: 3.0961, 2: 3.1100, 3: 3.0816, 4: 3.0435, 5: 2.9906, 6: 3.0009, 7: 2.9850, 8: 2.9817, 9: 2.9741, 10: 2.8645, 11: 2.8777},
    {0: 2.9447, 1: 3.0925, 2: 3.1075, 3: 3.0703, 4: 3.0372, 5: 2.9933, 6: 3.0018, 7: 2.9899, 8: 2.9857, 9: 2.9657, 10: 2.8410, 11: 2.8267},
    {0: 2.9449, 1: 3.0911, 2: 3.1028, 3: 3.0704, 4: 3.0389, 5: 2.9948, 6: 2.9996, 7: 2.9699, 8: 2.9628, 9: 2.9423, 10: 2.8048, 11: 2.7881},
    {0: 2.9457, 1: 3.0933, 2: 3.1071, 3: 3.0762, 4: 3.0477, 5: 3.0011, 6: 3.0038, 7: 2.9738, 8: 2.9548, 9: 2.9251, 10: 2.7719, 11: 2.7567},
    {0: 2.9456, 1: 3.0939, 2: 3.1071, 3: 3.0768, 4: 3.0461, 5: 3.0022, 6: 3.0023, 7: 2.9735, 8: 2.9553, 9: 2.9216, 10: 2.7640, 11: 2.7494},
]

# One curve per epoch; layers are displayed as 1-12 instead of 0-11.
plt.figure(figsize=(10, 6))
for i, epoch_ke in enumerate(ke_values, start=1):
    layers = [layer_idx + 1 for layer_idx in epoch_ke]
    entropy = [epoch_ke[layer_idx] for layer_idx in epoch_ke]
    plt.plot(layers, entropy, label=f"Epoch {i}", marker='o')

# Title left disabled: plt.title("Knowledge Entropy vs. Layer Index (CoLA)", fontsize=16)
plt.grid(True)
plt.legend(fontsize=12)
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach a forward hook to every encoder layer's ``intermediate.dense``.

    Each hook stores the detached first positional input of that dense module
    under ``activations[layer_idx]['pre_act']``.
    NOTE(review): despite the key name, this is the tensor entering the dense
    layer, not its output -- confirm that is the intended quantity.

    Returns:
        ``(hooks, activations)``: the hook handles (for ``remove_hooks``) and
        the per-layer buffer dict, initialised with ``None`` entries.
    """
    activations = {}
    hooks = []

    def _capture_into(layer_idx):
        # Factory so each hook is bound to its own layer index.
        def _hook(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return _hook

    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = layer.intermediate.dense.register_forward_hook(_capture_into(layer_idx))
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Turn the captured activations into one entropy score per layer, then clear them.

    For every layer whose buffer holds a tensor, the tensor is passed through
    ``activation_fn``, normalised to sum to one along dim 1, and the Shannon
    entropy along dim 1 is averaged over the remaining dimensions.

    Args:
        activations: mapping ``idx -> {'pre_act': tensor or None}`` as built by
            ``register_ke_hooks``; consumed entries are reset to ``None``.
        activation_fn: non-negative activation applied before normalising.
        eps: stabiliser added to the normaliser and inside the logarithm.

    Returns:
        dict ``idx -> float`` for the layers that had a captured tensor.
    """
    # NOTE(review): for 3-D (batch, seq, hidden) inputs dim 1 is presumably the
    # sequence axis rather than the feature axis -- confirm this is intended.
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in float32: in half precision eps=1e-8 rounds to 0, so zero
        # probabilities would yield 0 * log(0) = NaN. A no-op for float32 input.
        act = activation_fn(pre_act.float())
        probs = act / (act.sum(dim=1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Call ``remove()`` on every hook handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in for a layer's ``output`` module that bypasses the feed-forward path.

    ``hidden_states`` is ignored and ``input_tensor`` is handed back untouched
    (``None`` when it is not supplied).
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the feed-forward block of the ``num_prune`` highest-entropy layers.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity`` and
    ``output`` becomes ``SkipFF``. The model is modified in place.

    Returns:
        The selected layer indices, highest entropy first.
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update ``W0 + (alpha / r) * B @ A`` of a fixed weight.

    ``W0`` is kept as a buffer holding a detached copy of the given matrix.
    ``B`` (rows x r) starts as small Gaussian noise and ``A`` (r x cols) as
    zeros, so the effective weight initially equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen_copy = W0.detach().clone()
        self.register_buffer("W0", frozen_copy)
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix."""
        low_rank = self.B @ self.A
        return self.W0 + self.scaling * low_rank

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Patch every encoder layer's ``output.dense`` to use a LoRA-adjusted weight.

    Layers whose ``output`` has no ``dense`` attribute are left alone. For the
    rest, a ``LoRA`` module is built from the current weight and the dense
    module's ``forward`` is replaced on the instance by a closure that applies
    ``F.linear`` with the adapter's effective weight and the original bias.

    Returns:
        dict ``layer index -> LoRA`` for the layers that were patched.
    """
    def _patched_forward(layer, adapter):
        def _forward(x):
            return F.linear(x, adapter(), layer.output.dense.bias)
        return _forward

    adapters = {}
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = _patched_forward(layer, adapter)
            adapters[layer_idx] = adapter
    return adapters


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize (question, sentence) pairs, truncating and padding to ``max_length``."""
    options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
    }
    questions = examples['question']
    sentences = examples['sentence']
    return tok(questions, sentences, **options)

def evaluate_model(model, dl, device):
    """Return the accuracy of ``model``'s argmax predictions over ``dl``.

    The model is switched to eval mode (and left there); inference runs
    without gradient tracking.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]


# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fine-tune all of roberta-base while tracking per-layer KE.

    Args:
        train_loader: iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: batches handed to ``evaluate_model`` once training ends.
        device: device the model and batch tensors are moved to.
        num_epochs: passes over ``train_loader``; also sizes the LR schedule
            (default 6, the previously hard-coded value).

    Returns:
        ``(model, last_ke)`` where ``last_ke`` maps layer index to the mean
        per-batch knowledge entropy of the final epoch (``None`` if no epoch
        ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    # try/finally so the forward hooks never outlive this stage, even when
    # training is interrupted by an exception.
    try:
        for epoch in range(num_epochs):
            ke_sums, ke_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                # Autocast wraps the forward pass and loss only; PyTorch
                # advises against running backward() under autocast.
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Batch-level KE; this also clears the hook buffers.
                batch_ke = compute_batch_knowledge_entropy(activations)
                for idx, v in batch_ke.items():
                    ke_sums[idx] += v
                    ke_counts[idx] += 1

            # Epoch-level KE: mean of the batch scores for each layer.
            epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]
                        for idx in ke_sums if ke_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
            last_ke = epoch_ke

        acc = evaluate_model(model, dev_loader, device)
        print(f"-> Full Finetune QNLI Acc: {acc:.4f}")
    finally:
        remove_hooks(hooks)

    return model, last_ke


def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores,
                         num_epochs=5):
    """Stage 2: skip the highest-KE feed-forward blocks, then fine-tune the rest.

    Args:
        model: classifier from stage 1; modified in place by
            ``prune_ke_layers``.
        train_loader: iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: batches handed to ``evaluate_model`` after every epoch.
        device: device the batch tensors are moved to.
        ke_scores: mapping ``layer index -> knowledge entropy`` used to pick
            the layers to prune.
        num_epochs: passes over ``train_loader`` (default 5, the value that
            used to be hard-coded in the loop).

    Returns:
        The same ``model`` object after pruning and fine-tuning.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The schedule has to span exactly the optimizer steps that are taken.
    # It used to be sized for 3 epochs while the loop ran 5, so the linear
    # decay reached zero after epoch 3 and the remaining epochs trained with
    # a learning rate of 0.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the encoder and train only LoRA factors plus the classifier.

    Args:
        model: classifier whose ``roberta`` encoder gets LoRA-patched in place.
        train_loader: iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        dev_loader: batches handed to ``evaluate_model`` after every epoch.
        device: device the batch tensors are moved to.
        r: LoRA rank forwarded to ``apply_lora_to_all_layers``.
        alpha: LoRA scaling numerator forwarded to ``apply_lora_to_all_layers``.
        num_epochs: passes over ``train_loader``; also sizes the LR schedule
            (default 6, the previously hard-coded value).

    Returns:
        None. ``model`` is trained in place and accuracy is printed per epoch.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for lora in loras.values():
        lora.A.requires_grad = True
        lora.B.requires_grad = True

    # With the whole encoder frozen, no tensor entering the encoder layers
    # requires grad. If the incoming model has gradient checkpointing enabled
    # in its re-entrant form (historically the transformers default -- verify
    # for the installed version), checkpointed layers then produce outputs
    # detached from autograd and the LoRA factors never receive gradients.
    # Making the embedding output require grad keeps the graph connected.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters())
        + [p for lora in loras.values() for p in (lora.A, lora.B)],
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # Autocast wraps the forward pass and loss only; PyTorch advises
            # against running backward() under autocast.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the QNLI pipeline: full fine-tune, prune + fine-tune, then LoRA."""
    # Seed Python, NumPy and PyTorch with the same value before touching data.
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 2000 shuffled training examples; the full validation split for scoring.
    train_ds = load_dataset("glue", "qnli", split="train").shuffle(seed).select(range(2000))
    dev_ds = load_dataset("glue", "qnli", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def _tokenize(split):
        # Drop the raw text/id columns and expose the target as "labels".
        encoded = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["question", "sentence", "idx"],
        )
        return encoded.rename_column("label", "labels")

    train = _tokenize(train_ds)
    dev = _tokenize(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Knowledge entropy keyed by 1-based epoch number; each inner dict maps the
# 0-based encoder layer index to that layer's entropy.
ke_epochs = {
    1: {0: 3.2304, 1: 3.2924, 2: 3.2409, 3: 3.1981, 4: 3.1778, 5: 3.1335, 6: 3.1346, 7: 3.0527, 8: 3.0132, 9: 2.9415, 10: 2.8503, 11: 2.8969},
    2: {0: 3.2314, 1: 3.2867, 2: 3.2390, 3: 3.2144, 4: 3.1975, 5: 3.1637, 6: 3.1601, 7: 3.0645, 8: 2.9672, 9: 2.8309, 10: 2.6973, 11: 2.6584},
    3: {0: 3.2336, 1: 3.2851, 2: 3.2359, 3: 3.2133, 4: 3.2056, 5: 3.1733, 6: 3.1724, 7: 3.0849, 8: 2.9905, 9: 2.8596, 10: 2.7224, 11: 2.6792},
    4: {0: 3.2347, 1: 3.2857, 2: 3.2389, 3: 3.2136, 4: 3.2055, 5: 3.1750, 6: 3.1752, 7: 3.0889, 8: 2.9981, 9: 2.8731, 10: 2.7439, 11: 2.6969},
    5: {0: 3.2345, 1: 3.2861, 2: 3.2386, 3: 3.2143, 4: 3.2060, 5: 3.1730, 6: 3.1725, 7: 3.0911, 8: 2.9987, 9: 2.8678, 10: 2.7374, 11: 2.6870},
    6: {0: 3.2349, 1: 3.2869, 2: 3.2387, 3: 3.2147, 4: 3.2044, 5: 3.1738, 6: 3.1718, 7: 3.0958, 8: 3.0050, 9: 2.8763, 10: 2.7463, 11: 2.6961}
}

# One curve per epoch; layers are displayed as 1-12 instead of 0-11.
plt.figure(figsize=(10, 6))
for epoch, values in ke_epochs.items():
    layers = [layer_idx + 1 for layer_idx in values]
    entropy = [values[layer_idx] for layer_idx in values]
    plt.plot(layers, entropy, label=f"Epoch {epoch}", marker='o')

# Title left disabled: plt.title("Knowledge Entropy vs. Layer Index", fontsize=16)
plt.grid(True)
plt.legend(fontsize=12)
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach a forward hook to every encoder layer's ``intermediate.dense``.

    Each hook stores the detached first positional input of that dense module
    under ``activations[layer_idx]['pre_act']``.
    NOTE(review): despite the key name, this is the tensor entering the dense
    layer, not its output -- confirm that is the intended quantity.

    Returns:
        ``(hooks, activations)``: the hook handles (for ``remove_hooks``) and
        the per-layer buffer dict, initialised with ``None`` entries.
    """
    activations = {}
    hooks = []

    def _capture_into(layer_idx):
        # Factory so each hook is bound to its own layer index.
        def _hook(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return _hook

    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        activations[layer_idx] = {'pre_act': None}
        handle = layer.intermediate.dense.register_forward_hook(_capture_into(layer_idx))
        hooks.append(handle)
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Turn the captured activations into one entropy score per layer, then clear them.

    For every layer whose buffer holds a tensor, the tensor is passed through
    ``activation_fn``, normalised to sum to one along dim 1, and the Shannon
    entropy along dim 1 is averaged over the remaining dimensions.

    Args:
        activations: mapping ``idx -> {'pre_act': tensor or None}`` as built by
            ``register_ke_hooks``; consumed entries are reset to ``None``.
        activation_fn: non-negative activation applied before normalising.
        eps: stabiliser added to the normaliser and inside the logarithm.

    Returns:
        dict ``idx -> float`` for the layers that had a captured tensor.
    """
    # NOTE(review): for 3-D (batch, seq, hidden) inputs dim 1 is presumably the
    # sequence axis rather than the feature axis -- confirm this is intended.
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Work in float32: in half precision eps=1e-8 rounds to 0, so zero
        # probabilities would yield 0 * log(0) = NaN. A no-op for float32 input.
        act = activation_fn(pre_act.float())
        probs = act / (act.sum(dim=1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None
    return ke_scores

def remove_hooks(hooks):
    """Call ``remove()`` on every hook handle in ``hooks``, in order."""
    for handle in hooks:
        handle.remove()

# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in for a layer's ``output`` module that bypasses the feed-forward path.

    ``hidden_states`` is ignored and ``input_tensor`` is handed back untouched
    (``None`` when it is not supplied).
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Disable the feed-forward block of the ``num_prune`` highest-entropy layers.

    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity`` and
    ``output`` becomes ``SkipFF``. The model is modified in place.

    Returns:
        The selected layer indices, highest entropy first.
    """
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen copy of a weight matrix.

    The effective weight is ``W0 + (alpha / r) * B @ A``. ``A`` starts at
    zero, so the effective weight initially equals ``W0``.

    Args:
        W0: 2-D weight tensor; a detached copy is stored as a buffer.
        r: Rank of the update.
        alpha: Numerator of the scaling factor ``alpha / r``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen_copy = W0.clone().detach()
        self.register_buffer("W0", frozen_copy)
        n_rows, n_cols = W0.shape
        # B is small random noise, A is all zeros.
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix."""
        low_rank = self.B @ self.A
        return self.W0 + self.scaling * low_rank

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route every remaining ``output.dense`` through a LoRA-adjusted weight.

    Layers whose ``output`` block has no ``dense`` attribute are skipped.
    For the others, the instance-level ``forward`` of ``output.dense`` is
    replaced by a function that applies the LoRA weight with the layer's
    existing bias.

    NOTE: the LoRA modules are not attached to ``model``; they are only
    reachable through the returned dict.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        r: LoRA rank.
        alpha: LoRA scaling numerator.

    Returns:
        Dict mapping layer index to its ``LoRA`` module.
    """
    def make_forward(layer, adapter):
        # The bias is looked up at call time, exactly where the layer holds it.
        def forward(x):
            return F.linear(x, adapter(), layer.output.dense.bias)
        return forward

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[idx] = adapter
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ``question1`` / ``question2`` pair of ``examples``.

    Args:
        examples: Mapping with ``'question1'`` and ``'question2'`` entries.
        tok: Callable tokenizer invoked with both texts.
        max_length: Length to which inputs are truncated and padded.

    Returns:
        Whatever ``tok`` returns.
    """
    first, second = examples['question1'], examples['question2']
    tokenizer_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
    }
    return tok(first, second, **tokenizer_options)

def evaluate_model(model, dl, device):
    """Score ``model`` on ``dl`` with the ``accuracy`` metric.

    The model is switched to eval mode (and left there); predictions are the
    argmax over the last logits axis, computed without gradient tracking.

    Args:
        model: Model called with ``input_ids`` / ``attention_mask`` whose
            result exposes ``.logits``.
        dl: Iterable of batches with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        device: Device the inputs are moved to.

    Returns:
        The ``"accuracy"`` entry of the metric result.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fine-tune all of roberta-base and track per-layer KE.

    Loads ``roberta-base`` with a 2-label head, enables gradient
    checkpointing, and trains with Adam (lr 2e-5), a linear schedule without
    warm-up and mixed precision. After every step the captured activations
    are converted to per-layer entropy scores, which are averaged per epoch
    and printed.

    Args:
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the final accuracy evaluation.
        device: Device the model and batches are moved to.
        num_epochs: Number of epochs; also sizes the LR schedule so that the
            two can never drift apart.

    Returns:
        ``(model, last_ke)`` where ``last_ke`` is the final epoch's dict of
        layer index -> mean KE (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    for epoch in range(num_epochs):
        ke_sums, ke_counts = defaultdict(float), defaultdict(int)
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # autocast should wrap the forward pass only; the backward pass
            # is run outside the context, as the AMP docs recommend.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

            # Batch-level KE from the activations captured by the hooks.
            batch_ke = compute_batch_knowledge_entropy(activations)
            for idx, v in batch_ke.items():
                ke_sums[idx]   += v
                ke_counts[idx] += 1

        # Epoch-level KE: mean of the batch scores per layer.
        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]
                    for idx in ke_sums if ke_counts[idx] > 0}
        print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
        last_ke = epoch_ke

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QQP Acc: {acc:.4f}")

    remove_hooks(hooks)
    return model, last_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores,
                         num_epochs=5):
    """Stage 2: prune the highest-KE layers, then fine-tune the remainder.

    The LR schedule is sized from the same ``num_epochs`` as the training
    loop. Previously the schedule covered 3 epochs while the loop ran 5, so
    the linear decay reached zero after epoch 3 and the last two epochs
    trained with a learning rate of 0.

    Args:
        model: Model returned by stage 1; modified in place.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the per-epoch accuracy evaluation.
        device: Device the batches are moved to.
        ke_scores: Dict mapping layer index to KE score.
        num_epochs: Number of fine-tuning epochs (also sizes the schedule).

    Returns:
        The pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()  # evaluate_model leaves the model in eval mode
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA factors + classifier.

    Args:
        model: Model from stage 2; its ``output.dense`` forwards are patched.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the per-epoch accuracy evaluation.
        device: Device the batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        num_epochs: Number of epochs; also sizes the LR schedule.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for l in loras.values() for p in (l.A, l.B)]
    for p in lora_params:
        p.requires_grad = True

    # Gradient checkpointing is still enabled from stage 1. With the
    # embeddings frozen, no input of a checkpointed layer requires grad, and
    # reentrant checkpointing then produces no gradients for the LoRA
    # factors inside those layers. Making the embedding output require grad
    # restores the gradient path.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()  # evaluate_model leaves the model in eval mode
        for b in train_loader:
            opt.zero_grad()
            # autocast wraps the forward pass only; backward runs outside.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage pipeline on a 3000-example QQP training subset.

    Seeds the Python, NumPy and torch RNGs, tokenizes the GLUE QQP train and
    validation splits, builds the data loaders, and then runs full
    fine-tuning, prune + fine-tuning and LoRA fine-tuning in that order.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    full_train = load_dataset("glue", "qqp", split="train")
    train_ds = full_train.shuffle(seed).select(range(3000))
    dev_ds = load_dataset("glue", "qqp", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize, drop the raw text/id columns, expose the target as "labels".
        tokenized = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["question1", "question2", "idx"],
        )
        return tokenized.rename_column("label", "labels")

    train = encode(train_ds)
    dev = encode(dev_ds)

    collator = DataCollatorWithPadding(
        tokenizer, padding="max_length", max_length=64
    )
    train_loader = DataLoader(
        train, batch_size=8, shuffle=True, collate_fn=collator
    )
    dev_loader = DataLoader(
        dev, batch_size=16, shuffle=False, collate_fn=collator
    )

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device, ke_scores
    )
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Hard-coded KE values: outer key = epoch (1-6), inner key = 0-based layer
# index (0-11), value = knowledge entropy for that layer in that epoch.
# NOTE(review): presumably copied from the Stage 1 log of the QQP run above
# -- confirm before reusing these numbers.
ke_epochs = {
    1: {0: 3.1468, 1: 3.2210, 2: 3.1869, 3: 3.1624, 4: 3.1201, 5: 3.0795, 6: 3.1073, 7: 3.0597, 8: 3.0786, 9: 3.0529, 10: 2.9826, 11: 2.9888},
    2: {0: 3.1473, 1: 3.2196, 2: 3.1859, 3: 3.1655, 4: 3.1310, 5: 3.1000, 6: 3.1146, 7: 3.0777, 8: 3.0792, 9: 3.0537, 10: 2.9414, 11: 2.9214},
    3: {0: 3.1469, 1: 3.2201, 2: 3.1880, 3: 3.1641, 4: 3.1386, 5: 3.1055, 6: 3.1073, 7: 3.0607, 8: 3.0420, 9: 2.9884, 10: 2.8447, 11: 2.8140},
    4: {0: 3.1465, 1: 3.2205, 2: 3.1869, 3: 3.1643, 4: 3.1412, 5: 3.1092, 6: 3.1093, 7: 3.0606, 8: 3.0383, 9: 2.9692, 10: 2.8255, 11: 2.7930},
    5: {0: 3.1463, 1: 3.2201, 2: 3.1878, 3: 3.1660, 4: 3.1404, 5: 3.1064, 6: 3.1102, 7: 3.0624, 8: 3.0388, 9: 2.9652, 10: 2.8183, 11: 2.7961},
    6: {0: 3.1460, 1: 3.2191, 2: 3.1884, 3: 3.1685, 4: 3.1425, 5: 3.1084, 6: 3.1098, 7: 3.0626, 8: 3.0334, 9: 2.9510, 10: 2.8009, 11: 2.7761},
}

# Plotting: one line per epoch, KE on the y-axis against the layer number.
plt.figure(figsize=(10, 6))
for epoch, ke in ke_epochs.items():
    layers = [l + 1 for l in ke.keys()]  # shift layer indices to 1–12
    values = list(ke.values())  # same order as the keys above
    plt.plot(layers, values, marker='o', label=f"Epoch {epoch}")

# Title intentionally disabled; axis labels carry the description.
#plt.title("Knowledge Entropy vs Layers", fontsize=16)
plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach capture hooks to every encoder layer's ``intermediate.dense``.

    Each hook stores, detached from autograd, the first positional input
    that ``intermediate.dense`` receives during a forward pass.

    NOTE(review): the buffer key is ``'pre_act'`` although the stored tensor
    is the module's *input*, not its output -- confirm this is intended.

    Args:
        model: Object exposing ``model.roberta.encoder.layer`` as a sequence
            of layers, each owning an ``intermediate.dense`` module.

    Returns:
        ``(hooks, activations)``: the list of hook handles, and a dict
        mapping layer index to ``{'pre_act': tensor_or_None}`` that the
        hooks overwrite on every forward pass.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    for layer_idx in range(len(encoder_layers)):
        activations[layer_idx] = {'pre_act': None}

    def make_capture(layer_idx):
        # A factory gives every hook its own layer index (no late binding).
        def capture(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return capture

    hooks = [
        layer.intermediate.dense.register_forward_hook(make_capture(layer_idx))
        for layer_idx, layer in enumerate(encoder_layers)
    ]
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Turn the captured activations into one entropy score per layer.

    For every layer holding a captured tensor, the tensor is passed through
    ``activation_fn``, normalised along dim 1 into a distribution, and the
    entropy of that distribution is averaged into a Python float. Consumed
    buffers are reset to ``None`` so each capture is counted only once.

    Half-precision captures are upcast to float32 first: ``eps=1e-8`` is
    below the smallest positive float16 value, so it would round to zero and
    ``0 * log(0)`` would turn the whole score into NaN.

    NOTE(review): normalisation runs over dim 1; for a (batch, seq, hidden)
    capture that would be the sequence axis -- confirm this is intended.

    Args:
        activations: Dict mapping layer index to ``{'pre_act': tensor|None}``.
        activation_fn: Non-negative activation applied before normalising.
        eps: Stabiliser added to the normaliser and inside the logarithm.

    Returns:
        Dict mapping layer index to its mean entropy (float); layers whose
        buffer is ``None`` are omitted.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Only reduced-precision inputs are touched; fp32/fp64 pass through.
        if pre_act.dtype in (torch.float16, torch.bfloat16):
            pre_act = pre_act.float()
        act = activation_fn(pre_act)
        probs = act / (act.sum(dim=1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None  # mark the capture as consumed
    return ke_scores

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` by calling its ``remove()``.

    Args:
        hooks: Iterable of handle objects exposing a ``remove()`` method.
    """
    for handle in hooks:
        handle.remove()

# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Parameter-free stand-in for a layer's ``output`` block.

    It ignores ``hidden_states`` and hands back ``input_tensor`` untouched,
    so whatever was computed into ``hidden_states`` is discarded.
    """

    def forward(self, hidden_states, input_tensor=None):
        # Pass the second argument straight through (None if not supplied).
        passthrough = input_tensor
        return passthrough

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Replace the feed-forward path of the highest-scoring layers.

    The ``num_prune`` layers with the largest score get their
    ``intermediate.dense`` swapped for ``nn.Identity`` and their ``output``
    block swapped for ``SkipFF``. The model is modified in place.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        ke_scores: Dict mapping layer index to score.
        num_prune: How many top-scoring layers to prune.

    Returns:
        List of pruned layer indices, highest score first.
    """
    # Sorting the keys by their value keeps the same (stable) tie order as
    # sorting the (index, score) pairs.
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen copy of a weight matrix.

    The effective weight is ``W0 + (alpha / r) * B @ A``. ``A`` starts at
    zero, so the effective weight initially equals ``W0``.

    Args:
        W0: 2-D weight tensor; a detached copy is stored as a buffer.
        r: Rank of the update.
        alpha: Numerator of the scaling factor ``alpha / r``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen_copy = W0.clone().detach()
        self.register_buffer("W0", frozen_copy)
        n_rows, n_cols = W0.shape
        # B is small random noise, A is all zeros.
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix."""
        low_rank = self.B @ self.A
        return self.W0 + self.scaling * low_rank

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route every remaining ``output.dense`` through a LoRA-adjusted weight.

    Layers whose ``output`` block has no ``dense`` attribute are skipped.
    For the others, the instance-level ``forward`` of ``output.dense`` is
    replaced by a function that applies the LoRA weight with the layer's
    existing bias.

    NOTE: the LoRA modules are not attached to ``model``; they are only
    reachable through the returned dict.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        r: LoRA rank.
        alpha: LoRA scaling numerator.

    Returns:
        Dict mapping layer index to its ``LoRA`` module.
    """
    def make_forward(layer, adapter):
        # The bias is looked up at call time, exactly where the layer holds it.
        def forward(x):
            return F.linear(x, adapter(), layer.output.dense.bias)
        return forward

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[idx] = adapter
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ``sentence1`` / ``sentence2`` pair of ``examples``.

    Args:
        examples: Mapping with ``'sentence1'`` and ``'sentence2'`` entries.
        tok: Callable tokenizer invoked with both texts.
        max_length: Length to which inputs are truncated and padded.

    Returns:
        Whatever ``tok`` returns.
    """
    first, second = examples['sentence1'], examples['sentence2']
    tokenizer_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
    }
    return tok(first, second, **tokenizer_options)

def evaluate_model(model, dl, device):
    """Score ``model`` on ``dl`` with the ``accuracy`` metric.

    The model is switched to eval mode (and left there); predictions are the
    argmax over the last logits axis, computed without gradient tracking.

    Args:
        model: Model called with ``input_ids`` / ``attention_mask`` whose
            result exposes ``.logits``.
        dl: Iterable of batches with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        device: Device the inputs are moved to.

    Returns:
        The ``"accuracy"`` entry of the metric result.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            predictions.extend(torch.argmax(logits, -1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["accuracy"]

# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fine-tune all of roberta-base and track per-layer KE.

    Loads ``roberta-base`` with a 2-label head, enables gradient
    checkpointing, and trains with Adam (lr 2e-5), a linear schedule without
    warm-up and mixed precision. After every step the captured activations
    are converted to per-layer entropy scores, which are averaged per epoch
    and printed.

    Args:
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the final accuracy evaluation.
        device: Device the model and batches are moved to.
        num_epochs: Number of epochs; also sizes the LR schedule so that the
            two can never drift apart.

    Returns:
        ``(model, last_ke)`` where ``last_ke`` is the final epoch's dict of
        layer index -> mean KE (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    for epoch in range(num_epochs):
        ke_sums, ke_counts = defaultdict(float), defaultdict(int)
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # autocast should wrap the forward pass only; the backward pass
            # is run outside the context, as the AMP docs recommend.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

            # Batch-level KE from the activations captured by the hooks.
            batch_ke = compute_batch_knowledge_entropy(activations)
            for idx, v in batch_ke.items():
                ke_sums[idx]   += v
                ke_counts[idx] += 1

        # Epoch-level KE: mean of the batch scores per layer.
        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]
                    for idx in ke_sums if ke_counts[idx] > 0}
        print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
        last_ke = epoch_ke

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune RTE Acc: {acc:.4f}")

    remove_hooks(hooks)
    return model, last_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores,
                         num_epochs=5):
    """Stage 2: prune the highest-KE layers, then fine-tune the remainder.

    The LR schedule is sized from the same ``num_epochs`` as the training
    loop. Previously the schedule covered 3 epochs while the loop ran 5, so
    the linear decay reached zero after epoch 3 and the last two epochs
    trained with a learning rate of 0.

    Args:
        model: Model returned by stage 1; modified in place.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the per-epoch accuracy evaluation.
        device: Device the batches are moved to.
        ke_scores: Dict mapping layer index to KE score.
        num_epochs: Number of fine-tuning epochs (also sizes the schedule).

    Returns:
        The pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()  # evaluate_model leaves the model in eval mode
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA factors + classifier.

    Args:
        model: Model from stage 2; its ``output.dense`` forwards are patched.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the per-epoch accuracy evaluation.
        device: Device the batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        num_epochs: Number of epochs; also sizes the LR schedule.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for l in loras.values() for p in (l.A, l.B)]
    for p in lora_params:
        p.requires_grad = True

    # Gradient checkpointing is still enabled from stage 1. With the
    # embeddings frozen, no input of a checkpointed layer requires grad, and
    # reentrant checkpointing then produces no gradients for the LoRA
    # factors inside those layers. Making the embedding output require grad
    # restores the gradient path.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()  # evaluate_model leaves the model in eval mode
        for b in train_loader:
            opt.zero_grad()
            # autocast wraps the forward pass only; backward runs outside.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage pipeline on the GLUE RTE task.

    Seeds the Python, NumPy and torch RNGs, tokenizes the shuffled RTE train
    split and the validation split, builds the data loaders, and then runs
    full fine-tuning, prune + fine-tuning and LoRA fine-tuning in that order.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    full_train = load_dataset("glue", "rte", split="train")
    train_ds = full_train.shuffle(seed)
    dev_ds = load_dataset("glue", "rte", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize, drop the raw text/id columns, expose the target as "labels".
        tokenized = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2", "idx"],
        )
        return tokenized.rename_column("label", "labels")

    train = encode(train_ds)
    dev = encode(dev_ds)

    collator = DataCollatorWithPadding(
        tokenizer, padding="max_length", max_length=64
    )
    train_loader = DataLoader(
        train, batch_size=8, shuffle=True, collate_fn=collator
    )
    dev_loader = DataLoader(
        dev, batch_size=16, shuffle=False, collate_fn=collator
    )

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device, ke_scores
    )
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Hard-coded KE values: outer key = epoch (1-6), inner key = 0-based layer
# index (0-11), value = knowledge entropy for that layer in that epoch.
# NOTE(review): presumably copied from the Stage 1 log of the RTE run above
# -- confirm before reusing these numbers.
ke_data = {
    1: {0: 3.234637494270618, 1: 3.2954297256775393, 2: 3.2472059474541592, 3: 3.2122703920572233, 4: 3.1996606771762552, 5: 3.166536705616193, 6: 3.192006678917469, 7: 3.1797323601368146, 8: 3.1996014416217804, 9: 3.198551667042268, 10: 3.1776862190319943, 11: 3.233222233179288},
    2: {0: 3.2364144898377933, 1: 3.292912809512554, 2: 3.241034468779197, 3: 3.214976036395782, 4: 3.198220124611488, 5: 3.16093972325325, 6: 3.176389996057902, 7: 3.153661161661148, 8: 3.156559161650829, 9: 3.124322636769368, 10: 3.0624067340141687, 11: 3.0741044863676414},
    3: {0: 3.2363225794755497, 1: 3.2945617039998374, 2: 3.2429233766519108, 3: 3.214366410023127, 4: 3.2025416669173117, 5: 3.16937870092881, 6: 3.180643222270868, 7: 3.16009650627772, 8: 3.1594576827990704, 9: 3.1243840715824027, 10: 3.053374951466536, 11: 3.0580264222927585},
    4: {0: 3.2360405830236583, 1: 3.2945619431825786, 2: 3.2415615954460244, 3: 3.2135055791109037, 4: 3.2035804826479692, 5: 3.1748380936109104, 6: 3.1875738394566073, 7: 3.1665822023000474, 8: 3.1637751505925107, 9: 3.1298167407512665, 10: 3.052433254627081, 11: 3.038477917512258},
    5: {0: 3.2361522951187234, 1: 3.294039017114884, 2: 3.241480989333911, 3: 3.2140066035282917, 4: 3.204068192304709, 5: 3.175252984731625, 6: 3.1887962214457684, 7: 3.1690831688734202, 8: 3.1620407494214864, 9: 3.1216949025789895, 10: 3.035130114127428, 11: 3.0117857922346163},
    6: {0: 3.2358735440633235, 1: 3.2946181755799513, 2: 3.2427005072434745, 3: 3.213492139791831, 4: 3.204735367726057, 5: 3.174306563077829, 6: 3.1876255694108133, 7: 3.167302029255109, 8: 3.1590061837281938, 9: 3.115201425093871, 10: 3.0225293246599345, 11: 2.994033750051107}
}

# Plot: one line per epoch, KE on the y-axis against the layer number.
plt.figure(figsize=(10, 6))
for epoch, layer_data in ke_data.items():
    layers = [l + 1 for l in layer_data.keys()]  # shift layer indices to 1–12
    ke_values = list(layer_data.values())  # same order as the keys above
    plt.plot(layers, ke_values, marker='o', label=f"Epoch {epoch}")

plt.xlabel("Layer Index", fontsize=16)
plt.ylabel("Knowledge Entropy", fontsize=16)
# Title intentionally disabled; axis labels carry the description.
#plt.title("Knowledge Entropy vs. Layer", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 2) Knowledge Entropy / Hook Utilities
# ========================================================
def register_ke_hooks(model):
    """Attach capture hooks to every encoder layer's ``intermediate.dense``.

    Each hook stores, detached from autograd, the first positional input
    that ``intermediate.dense`` receives during a forward pass.

    NOTE(review): the buffer key is ``'pre_act'`` although the stored tensor
    is the module's *input*, not its output -- confirm this is intended.

    Args:
        model: Object exposing ``model.roberta.encoder.layer`` as a sequence
            of layers, each owning an ``intermediate.dense`` module.

    Returns:
        ``(hooks, activations)``: the list of hook handles, and a dict
        mapping layer index to ``{'pre_act': tensor_or_None}`` that the
        hooks overwrite on every forward pass.
    """
    encoder_layers = model.roberta.encoder.layer
    activations = {}
    for layer_idx in range(len(encoder_layers)):
        activations[layer_idx] = {'pre_act': None}

    def make_capture(layer_idx):
        # A factory gives every hook its own layer index (no late binding).
        def capture(module, inputs, output):
            activations[layer_idx]['pre_act'] = inputs[0].detach()
        return capture

    hooks = [
        layer.intermediate.dense.register_forward_hook(make_capture(layer_idx))
        for layer_idx, layer in enumerate(encoder_layers)
    ]
    return hooks, activations

def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):
    """Turn the captured activations into one entropy score per layer.

    For every layer holding a captured tensor, the tensor is passed through
    ``activation_fn``, normalised along dim 1 into a distribution, and the
    entropy of that distribution is averaged into a Python float. Consumed
    buffers are reset to ``None`` so each capture is counted only once.

    Half-precision captures are upcast to float32 first: ``eps=1e-8`` is
    below the smallest positive float16 value, so it would round to zero and
    ``0 * log(0)`` would turn the whole score into NaN.

    NOTE(review): normalisation runs over dim 1; for a (batch, seq, hidden)
    capture that would be the sequence axis -- confirm this is intended.

    Args:
        activations: Dict mapping layer index to ``{'pre_act': tensor|None}``.
        activation_fn: Non-negative activation applied before normalising.
        eps: Stabiliser added to the normaliser and inside the logarithm.

    Returns:
        Dict mapping layer index to its mean entropy (float); layers whose
        buffer is ``None`` are omitted.
    """
    ke_scores = {}
    for idx, buf in activations.items():
        pre_act = buf['pre_act']
        if pre_act is None:
            continue
        # Only reduced-precision inputs are touched; fp32/fp64 pass through.
        if pre_act.dtype in (torch.float16, torch.bfloat16):
            pre_act = pre_act.float()
        act = activation_fn(pre_act)
        probs = act / (act.sum(dim=1, keepdim=True) + eps)
        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()
        ke_scores[idx] = entropy.item()
        buf['pre_act'] = None  # mark the capture as consumed
    return ke_scores

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` by calling its ``remove()``.

    Args:
        hooks: Iterable of handle objects exposing a ``remove()`` method.
    """
    for handle in hooks:
        handle.remove()

# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‐KE)
# ========================================================
class SkipFF(nn.Module):
    """Parameter-free stand-in for a layer's ``output`` block.

    It ignores ``hidden_states`` and hands back ``input_tensor`` untouched,
    so whatever was computed into ``hidden_states`` is discarded.
    """

    def forward(self, hidden_states, input_tensor=None):
        # Pass the second argument straight through (None if not supplied).
        passthrough = input_tensor
        return passthrough

def prune_ke_layers(model, ke_scores, num_prune=4):
    """Replace the feed-forward path of the highest-scoring layers.

    The ``num_prune`` layers with the largest score get their
    ``intermediate.dense`` swapped for ``nn.Identity`` and their ``output``
    block swapped for ``SkipFF``. The model is modified in place.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        ke_scores: Dict mapping layer index to score.
        num_prune: How many top-scoring layers to prune.

    Returns:
        List of pruned layer indices, highest score first.
    """
    # Sorting the keys by their value keeps the same (stable) tie order as
    # sorting the (index, score) pairs.
    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)
    prune_idxs = ranked[:num_prune]
    for layer_idx in prune_idxs:
        target = model.roberta.encoder.layer[layer_idx]
        target.intermediate.dense = nn.Identity()
        target.output = SkipFF()
    return prune_idxs

# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update on top of a frozen copy of a weight matrix.

    The effective weight is ``W0 + (alpha / r) * B @ A``. ``A`` starts at
    zero, so the effective weight initially equals ``W0``.

    Args:
        W0: 2-D weight tensor; a detached copy is stored as a buffer.
        r: Rank of the update.
        alpha: Numerator of the scaling factor ``alpha / r``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen_copy = W0.clone().detach()
        self.register_buffer("W0", frozen_copy)
        n_rows, n_cols = W0.shape
        # B is small random noise, A is all zeros.
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the effective weight matrix."""
        low_rank = self.B @ self.A
        return self.W0 + self.scaling * low_rank

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route every remaining ``output.dense`` through a LoRA-adjusted weight.

    Layers whose ``output`` block has no ``dense`` attribute are skipped.
    For the others, the instance-level ``forward`` of ``output.dense`` is
    replaced by a function that applies the LoRA weight with the layer's
    existing bias.

    NOTE: the LoRA modules are not attached to ``model``; they are only
    reachable through the returned dict.

    Args:
        model: Object exposing ``model.roberta.encoder.layer``.
        r: LoRA rank.
        alpha: LoRA scaling numerator.

    Returns:
        Dict mapping layer index to its ``LoRA`` module.
    """
    def make_forward(layer, adapter):
        # The bias is looked up at call time, exactly where the layer holds it.
        def forward(x):
            return F.linear(x, adapter(), layer.output.dense.bias)
        return forward

    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if hasattr(layer.output, 'dense'):
            base_weight = layer.output.dense.weight.data
            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)
            layer.output.dense.forward = make_forward(layer, adapter)
            loras[idx] = adapter
    return loras

# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ``sentence1`` / ``sentence2`` pair of ``examples``.

    Args:
        examples: Mapping with ``'sentence1'`` and ``'sentence2'`` entries.
        tok: Callable tokenizer invoked with both texts.
        max_length: Length to which inputs are truncated and padded.

    Returns:
        Whatever ``tok`` returns.
    """
    first, second = examples['sentence1'], examples['sentence2']
    tokenizer_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
    }
    return tok(first, second, **tokenizer_options)

def evaluate_model(model, dl, device):
    """Score ``model`` on ``dl`` with the ``pearsonr`` metric.

    The model is switched to eval mode (and left there); predictions are the
    flattened logits, computed without gradient tracking.

    Args:
        model: Model called with ``input_ids`` / ``attention_mask`` whose
            result exposes ``.logits``.
        dl: Iterable of batches with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        device: Device the inputs are moved to.

    Returns:
        The ``"pearsonr"`` entry of the metric result.
    """
    model.eval()
    metric = evaluate.load("pearsonr")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            references.extend(batch['labels'].cpu().numpy())
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            # Regression head: flatten the logits to a 1-D list of scores.
            predictions.extend(logits.view(-1).cpu().numpy())
    result = metric.compute(predictions=predictions, references=references)
    return result["pearsonr"]

# ========================================================
# 6) Training Stages (using KE instead of ER)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fine-tune all of roberta-base (regression) and track KE.

    Loads ``roberta-base`` with a single-output head, enables gradient
    checkpointing, and trains with Adam (lr 2e-5), a linear schedule without
    warm-up and mixed precision. Labels are cast to float and given a
    trailing axis of size 1. After every step the captured activations are
    converted to per-layer entropy scores, which are averaged per epoch and
    printed.

    Args:
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the final Pearson evaluation.
        device: Device the model and batches are moved to.
        num_epochs: Number of epochs; also sizes the LR schedule so that the
            two can never drift apart.

    Returns:
        ``(model, last_ke)`` where ``last_ke`` is the final epoch's dict of
        layer index -> mean KE (``None`` if no epoch ran).
    """
    print("=== Stage 1: Full Finetuning & KE Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=1  # num_labels=1 for regression!
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_ke_hooks(model)
    last_ke = None

    for epoch in range(num_epochs):
        ke_sums, ke_counts = defaultdict(float), defaultdict(int)
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # autocast should wrap the forward pass only; the backward pass
            # is run outside the context, as the AMP docs recommend.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

            # Batch-level KE from the activations captured by the hooks.
            batch_ke = compute_batch_knowledge_entropy(activations)
            for idx, v in batch_ke.items():
                ke_sums[idx]   += v
                ke_counts[idx] += 1

        # Epoch-level KE: mean of the batch scores per layer.
        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]
                    for idx in ke_sums if ke_counts[idx] > 0}
        print(f"[Epoch {epoch+1}] approx Knowledge Entropy:", epoch_ke)
        last_ke = epoch_ke

    pearson = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune STS-B Pearson: {pearson:.4f}")

    remove_hooks(hooks)
    return model, last_ke

def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores,
                         num_epochs=5):
    """Stage 2: prune the highest-KE layers, then fine-tune the remainder.

    The LR schedule is sized from the same ``num_epochs`` as the training
    loop. Previously the schedule covered 3 epochs while the loop ran 5, so
    the linear decay reached zero after epoch 3 and the last two epochs
    trained with a learning rate of 0.

    Args:
        model: Model returned by stage 1; modified in place.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the per-epoch Pearson evaluation.
        device: Device the batches are moved to.
        ke_scores: Dict mapping layer index to KE score.
        num_epochs: Number of fine-tuning epochs (also sizes the schedule).

    Returns:
        The pruned, fine-tuned model.
    """
    print("=== Stage 2: Prune (High‐KE) & Finetuning ===")
    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)
    print("Pruned layers (highest‐KE):", prune_idxs)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()  # evaluate_model leaves the model in eval mode
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))
            out.loss.backward()
            opt.step()
            sched.step()
        pearson = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}")

    return model

def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """Stage 3: freeze the backbone and train only LoRA factors + classifier.

    Args:
        model: Model from stage 2; its ``output.dense`` forwards are patched.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``labels``.
        dev_loader: Batches used for the per-epoch Pearson evaluation.
        device: Device the batches are moved to.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        num_epochs: Number of epochs; also sizes the LR schedule.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for l in loras.values() for p in (l.A, l.B)]
    for p in lora_params:
        p.requires_grad = True

    # Gradient checkpointing is still enabled from stage 1. With the
    # embeddings frozen, no input of a checkpointed layer requires grad, and
    # reentrant checkpointing then produces no gradients for the LoRA
    # factors inside those layers. Making the embedding output require grad
    # restores the gradient path.
    if getattr(model, "is_gradient_checkpointing", False):
        model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()  # evaluate_model leaves the model in eval mode
        for b in train_loader:
            opt.zero_grad()
            # autocast wraps the forward pass only; backward runs outside.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        pearson = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}")

# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the three-stage pipeline on the GLUE STS-B task.

    Seeds the Python, NumPy and torch RNGs, tokenizes the shuffled STS-B
    train split and the validation split, builds the data loaders, and then
    runs full fine-tuning, prune + fine-tuning and LoRA fine-tuning in that
    order.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    full_train = load_dataset("glue", "stsb", split="train")
    train_ds = full_train.shuffle(seed)
    dev_ds = load_dataset("glue", "stsb", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(split):
        # Tokenize, drop the raw text/id columns, expose the target as "labels".
        tokenized = split.map(
            lambda ex: preprocess_function(ex, tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2", "idx"],
        )
        return tokenized.rename_column("label", "labels")

    train = encode(train_ds)
    dev = encode(dev_ds)

    collator = DataCollatorWithPadding(
        tokenizer, padding="max_length", max_length=64
    )
    train_loader = DataLoader(
        train, batch_size=8, shuffle=True, collate_fn=collator
    )
    dev_loader = DataLoader(
        dev, batch_size=16, shuffle=False, collate_fn=collator
    )

    model, ke_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(
        model, train_loader, dev_loader, device, ke_scores
    )
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt

# Provided knowledge-entropy (KE) values, hard-coded here for plotting.
# Outer keys are epoch numbers (1-6); inner keys are zero-based layer
# indices (0-11) mapping to that layer's KE value.
ke_epochs = {
    1: {0: 3.1216782207117624, 1: 3.205926341373831, 2: 3.1831439436725515, 3: 3.1425496576890164, 4: 3.108628351930451,
        5: 3.0245226891879744, 6: 3.0113194900693085, 7: 2.96199728318481, 8: 2.9216685882032496, 9: 2.9011312645234386,
        10: 2.7949637216056007, 11: 2.8036840522405337},
    2: {0: 3.122385612946724, 1: 3.2053615688780255, 2: 3.1827276502431516, 3: 3.14758318588034, 4: 3.128111186047423,
        5: 3.0808556003597083, 6: 3.0710045684527953, 7: 3.012045979334018, 8: 2.9473394376014634, 9: 2.890920449032737,
        10: 2.706535710744632, 11: 2.694957201271959},
    3: {0: 3.1233545473785824, 1: 3.2060421793119303, 2: 3.180608749721246, 3: 3.149725641096749, 4: 3.1337831318792944,
        5: 3.087708413518022, 6: 3.0823230342175267, 7: 3.0231315847563978, 8: 2.9449381267908383, 9: 2.8736855148108513,
        10: 2.6799341229636413, 11: 2.6526868296929624},
    4: {0: 3.122849219697573, 1: 3.208289318190828, 2: 3.1786043458256836, 3: 3.152634641225547, 4: 3.140414782127518,
        5: 3.0983164837695294, 6: 3.087455921942402, 7: 3.0256239744486164, 8: 2.950759364434509, 9: 2.883653296548899,
        10: 2.686351388817205, 11: 2.6519999305130875},
    5: {0: 3.1224169508969832, 1: 3.2059260482416696, 2: 3.17878786678606, 3: 3.156504187365069, 4: 3.1431414029860196,
        5: 3.0961787657545403, 6: 3.0855182536951524, 7: 3.0326405927766844, 8: 2.9571537298353725, 9: 2.8834909653298877,
        10: 2.6865991667348252, 11: 2.644667492788259},
    6: {0: 3.122367330651953, 1: 3.20639664672512, 2: 3.179138040343644, 3: 3.157353766604491, 4: 3.1413243549755454,
        5: 3.0952708057965954, 6: 3.0864931734612986, 7: 3.033865492267635, 8: 2.9587372055637324, 9: 2.8831359905725726,
        10: 2.6869357043413524, 11: 2.6412745341140473}
}

# Draw one knowledge-entropy curve per epoch on a shared set of axes.
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111)
for epoch_no, layer_ke in ke_epochs.items():
    # Stored layer indices are zero-based; display them as 1-12.
    layer_numbers = [idx + 1 for idx in layer_ke]
    ax.plot(layer_numbers, list(layer_ke.values()), marker='o',
            label=f'Epoch {epoch_no}')

ax.set_xlabel('Layer Index', fontsize=16)
ax.set_ylabel('Knowledge Entropy', fontsize=16)
# Enlarge the tick labels on both axes.
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
ax.legend(fontsize=12)
ax.grid(True)
fig.tight_layout()
plt.show()
