In [None]:
!pip uninstall -y datasets
!pip install datasets==2.18.0
!pip install evaluate


Found existing installation: datasets 4.0.0
Uninstalling datasets-4.0.0:
  Successfully uninstalled datasets-4.0.0
Collecting datasets==2.18.0
  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets==2.18.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Collecting fsspec<=2024.2.0,>=2023.1.0 (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets==2.18.0)
  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)
Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.2.0-py3-none-any.whl (170 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.9/170.9 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow_hotfix-0.7-py3-none-any.whl (7.9 kB)
Installing collected packages: pyarrow-hotfix, fsspec, datasets
  Attempting uninstall: fsspec


In [None]:
# Prune Layers Based on Entropy Rate Using MNLI Dataset

# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)
# ========================================================
def register_er_hooks(model):
    """Attach forward hooks that capture adjacent-layer FFN output activations.

    For every neighbouring encoder-layer pair ``(i, i+1)`` the (detached)
    output of ``layers[i].output.dense`` is written to
    ``activations[i]['curr_X']`` and that of ``layers[i+1].output.dense`` to
    ``activations[i]['curr_Y']`` on each forward pass.

    Returns:
        (hooks, activations): the hook handles (two per pair; pass them to
        ``remove_hooks``) and the per-pair buffers read by
        ``compute_batch_entropy``.
    """
    encoder_layers = model.roberta.encoder.layer
    num_pairs = len(encoder_layers) - 1

    activations = {}
    for pair_idx in range(num_pairs):
        activations[pair_idx] = {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}

    def _make_hook(pair_idx, slot):
        # Factory so every closure captures its own pair index and slot name.
        def _hook(module, inp, out):
            activations[pair_idx][slot] = out.detach()
        return _hook

    hooks = []
    for pair_idx in range(num_pairs):
        lower = encoder_layers[pair_idx].output.dense
        upper = encoder_layers[pair_idx + 1].output.dense
        hooks.append(lower.register_forward_hook(_make_hook(pair_idx, 'curr_X')))
        hooks.append(upper.register_forward_hook(_make_hook(pair_idx, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` (e.g. from ``register_er_hooks``)."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """Approximate the per-layer-pair entropy-rate term for the latest batch.

    For each adjacent layer pair ``idx`` the estimator is
        H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)
    where ΔX / ΔY are the row-wise differences between the current and the
    previous batch's captured activations. Only the cosine term is returned;
    the additive constant (and hence ``sigma2``, kept for API compatibility)
    is the same for every layer and does not change the pruning order.

    Side effects: each buffer's ``curr_*`` tensors are shifted into ``prev_*``
    and ``curr_*`` is reset to ``None``.

    Returns:
        dict mapping pair index -> float score. Pairs without usable history
        (first batch, or a batch whose shape differs from the previous one)
        are omitted.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Always advance the history, whether or not a score is produced.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity checks: `None in (tensor, ...)` would call Tensor.__eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # A differently sized batch (e.g. a short final batch) cannot be
        # differenced against the previous one.
        if X_curr.shape != X_prev.shape or Y_curr.shape != Y_prev.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er_scores[idx] = 0.0
            continue

        # Flatten all non-batch dims; work in float32 so fp16 activations
        # captured under autocast are not reduced in half precision.
        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)
        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)

        # Rows 1..B-1 (B-1 terms) match the 1/(2(B-1)) normalisation.
        # The cosine must be squared, otherwise opposite-signed terms cancel.
        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
        er_scores[idx] = (cos.pow(2).sum() / (2 * (B - 1))).item()

    return er_scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‑ER)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's ``output`` module that bypasses the feed-forward result.

    ``forward`` ignores ``hidden_states`` and returns ``input_tensor`` unchanged
    (``None`` if it is not given). Installed as ``layer.output`` by
    ``prune_er_layers``; presumably the encoder layer calls it as
    ``output(ffn_hidden, attention_output)`` so the layer collapses to its
    attention sub-block — confirm against the installed transformers version.
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_er_layers(model, er_scores, num_prune=4):
    """Disable the feed-forward block of the layers with the highest ER scores.

    ``er_scores`` maps a layer-pair index ``idx`` (pair ``(idx, idx+1)``) to
    its score; the score is attributed to the pair's second layer ``idx + 1``.
    The ``num_prune`` highest-scoring pairs are chosen (ties keep their
    original order), and each chosen layer gets an identity
    ``intermediate.dense`` and a ``SkipFF`` output module.

    Returns:
        list of pruned layer indices, highest score first.
    """
    encoder_layers = model.roberta.encoder.layer
    num_layers = len(encoder_layers)
    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)

    prune_idxs = []
    for pair_idx, _score in ranked[:num_prune]:
        target = pair_idx + 1
        if target < num_layers:
            prune_idxs.append(target)

    for target in prune_idxs:
        victim = encoder_layers[target]
        victim.intermediate.dense = nn.Identity()
        victim.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update to a frozen weight matrix.

    Calling the module (no arguments) returns the effective weight
    ``W0 + (alpha / r) * B @ A``, shaped like ``W0``. ``W0`` is a detached copy
    held as a buffer (not trained). ``B`` starts as small Gaussian noise and
    ``A`` as zeros, so the initial update is exactly zero.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        out_features, in_features = W0.shape
        self.B = nn.Parameter(torch.randn(out_features, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_features))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap every un-pruned layer's ``output.dense`` projection with a LoRA update.

    Layers whose ``output`` has no ``dense`` attribute (replaced by ``SkipFF``
    in ``prune_er_layers``) are skipped. For the others, ``dense.forward`` is
    replaced so it computes ``F.linear(x, lora(), dense.bias)``, i.e. it uses
    the effective weight ``W0 + scaling * B @ A``.

    Each LoRA module is also registered as the child ``dense.lora``. Its
    parameters and ``W0`` buffer therefore appear in ``model.parameters()``
    and ``model.state_dict()`` and follow ``model.to(...)``. Otherwise the
    trained adapters would exist only in the returned dict and would be lost
    when the model is saved.

    Returns:
        dict mapping layer index -> LoRA module.
    """
    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(layer.output, 'dense'):
            continue
        dense = layer.output.dense
        W0 = dense.weight.data
        lora = LoRA(W0, r, alpha).to(W0.device)
        # nn.Module.__setattr__ registers this as a submodule of `dense`.
        dense.lora = lora

        def fwd(x, dense=dense, lora=lora):
            return F.linear(x, lora(), dense.bias)

        # Instance attribute shadows Linear.forward, so module hooks still fire.
        dense.forward = fwd
        loras[idx] = lora
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================


def preprocess_function(examples, tok, max_length=128):
    """Tokenise a batch of MNLI (premise, hypothesis) pairs.

    Each pair is truncated and padded to exactly ``max_length`` tokens. The
    tokenizer's output is returned as-is.
    """
    first_segments = examples['premise']
    second_segments = examples['hypothesis']
    return tok(
        first_segments,
        second_segments,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )



def evaluate_model(model, dl, device):
    """Return the classification accuracy of ``model`` over the batches of ``dl``.

    Switches the model to eval mode (it is left there; the training loops call
    ``model.train()`` at the start of each epoch) and runs inference without
    gradients. Accuracy is counted directly from argmax predictions rather
    than through ``evaluate.load("accuracy")``, which re-loads the metric
    module (possibly over the network) on every call.

    Returns:
        float accuracy in [0, 1]; ``nan`` if ``dl`` yields no examples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            ids = b['input_ids'].to(device)
            mask = b['attention_mask'].to(device)
            labels = b['labels'].to(device)
            logits = model(input_ids=ids, attention_mask=mask).logits
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            total += labels.numel()
    return correct / total if total else float('nan')


# ========================================================
# 6) Training Stages (using ER instead of MIR)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base (3 labels) while estimating entropy rates.

    All parameters are trained with Adam (lr 2e-5) under mixed precision, with
    a linear LR decay spanning all ``num_epochs`` epochs. Forward hooks capture
    adjacent-layer activations. After every step ``compute_batch_entropy``
    scores each layer pair, and the scores are averaged per epoch.

    Returns:
        (model, last_er): the fine-tuned model and the epoch-averaged ER dict
        of the final epoch (``None`` when ``num_epochs`` is 0).
    """
    print("=== Stage 1: Full Finetuning & ER Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=3).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    # The schedule must cover every epoch that is run. It was sized for 3 of
    # the 6 epochs, which drove the LR to 0 and froze the second half of training.
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()
    hooks, activations = register_er_hooks(model)
    last_er = None

    try:
        for epoch in range(num_epochs):
            er_sums, er_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward outside the autocast region, per the AMP recipe.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                batch_er = compute_batch_entropy(activations)
                for idx, v in batch_er.items():
                    er_sums[idx] += v
                    er_counts[idx] += 1

            epoch_er = {idx: er_sums[idx] / er_counts[idx] for idx in er_sums if er_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Entropy Rate:", epoch_er)
            last_er = epoch_er
    finally:
        # Detach the hooks before evaluation, so dev batches are not captured,
        # and also if training raises.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MNLI Acc: {acc:.4f}")
    return model, last_er


def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):
    """Stage 2: prune the 4 highest-ER layers, then fine-tune the pruned model.

    After ``prune_er_layers`` every parameter is trained without autocast,
    using Adam (lr 1e-5) and a linear schedule spanning the 3 epochs run. Dev
    accuracy is printed after each epoch.

    Returns:
        the same model object, pruned and fine-tuned in place.
    """
    print("=== Stage 2: Prune (High‐ER) & Finetuning ===")
    pruned = prune_er_layers(model, er_scores, num_prune=4)
    print("Pruned layers (highest‐ER):", pruned)

    n_epochs = 3
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    scheduler = get_linear_schedule_with_warmup(optimizer, 0, len(train_loader) * n_epochs)

    for epoch_no in range(1, n_epochs + 1):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            loss = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                labels=batch['labels'].to(device),
            ).loss
            loss.backward()
            optimizer.step()
            scheduler.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch_no}] MNLI Acc: {acc:.4f}")
    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=3):
    """Stage 3: freeze the encoder and train LoRA adapters plus the classifier head.

    ``apply_lora_to_all_layers`` adds a rank-``r`` update to every remaining
    ``output.dense`` projection. All ``model.roberta`` weights are frozen, and
    only the classifier and the LoRA ``A``/``B`` matrices are optimised (Adam,
    lr 2e-5, linear decay over ``num_epochs`` epochs, mixed precision). Dev
    accuracy is printed after each epoch. The model is modified in place.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 enabled gradient checkpointing, which transformers runs
    # re-entrantly by default (verify for the installed version). With the
    # embeddings frozen, no checkpointed layer input requires grad, so the
    # LoRA parameters inside the layers would silently receive no gradient
    # (PyTorch's warning about this is hidden by the UserWarning filter).
    # Making the embedding output require grad restores gradient flow.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        model.gradient_checkpointing_disable()

    opt = torch.optim.Adam(list(model.classifier.parameters()) + lora_params, lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)
    scaler = GradScaler()
    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside the autocast region, per the AMP recipe.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MNLI Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Entry point for the MNLI experiment.

    Seeds the Python, NumPy and torch RNGs. Tokenises the first 5000 MNLI
    training pairs (then shuffled) and the first 1000 matched-validation pairs.
    Then runs the three stages in order: full fine-tuning with ER estimation,
    ER-based pruning plus fine-tuning, and LoRA-only fine-tuning.
    """
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One sequence length shared by tokenisation and collation. preprocess_function
    # already pads every example to this length, and DataCollatorWithPadding only
    # pads (never truncates). The collator's former max_length=64 therefore had
    # no effect and only obscured the real length of 128.
    max_len = 128

    train_ds = load_dataset("glue", "mnli", split="train[:5000]").shuffle(seed)
    dev_ds = load_dataset("glue", "mnli", split="validation_matched[:1000]")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def tokenize_split(ds):
        # Tokenise pairs, drop raw-text columns, rename the target to 'labels'
        # (the key the training loops read).
        encoded = ds.map(lambda ex: preprocess_function(ex, tokenizer, max_length=max_len),
                         batched=True,
                         remove_columns=["premise", "hypothesis", "idx"])
        return encoded.rename_column("label", "labels")

    train = tokenize_split(train_ds)
    dev = tokenize_split(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=max_len)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)
    model, er_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

if __name__ == "__main__":
    # Notebook cells execute with __name__ == "__main__", so running this cell
    # immediately launches all three training stages via main().
    main()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)
# ========================================================
def register_er_hooks(model):
    """Attach forward hooks that capture adjacent-layer FFN output activations.

    For every neighbouring encoder-layer pair ``(i, i+1)`` the (detached)
    output of ``layers[i].output.dense`` is written to
    ``activations[i]['curr_X']`` and that of ``layers[i+1].output.dense`` to
    ``activations[i]['curr_Y']`` on each forward pass.

    Returns:
        (hooks, activations): the hook handles (two per pair; pass them to
        ``remove_hooks``) and the per-pair buffers read by
        ``compute_batch_entropy``.
    """
    encoder_layers = model.roberta.encoder.layer
    num_pairs = len(encoder_layers) - 1

    activations = {}
    for pair_idx in range(num_pairs):
        activations[pair_idx] = {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}

    def _make_hook(pair_idx, slot):
        # Factory so every closure captures its own pair index and slot name.
        def _hook(module, inp, out):
            activations[pair_idx][slot] = out.detach()
        return _hook

    hooks = []
    for pair_idx in range(num_pairs):
        lower = encoder_layers[pair_idx].output.dense
        upper = encoder_layers[pair_idx + 1].output.dense
        hooks.append(lower.register_forward_hook(_make_hook(pair_idx, 'curr_X')))
        hooks.append(upper.register_forward_hook(_make_hook(pair_idx, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` (e.g. from ``register_er_hooks``)."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """Approximate the per-layer-pair entropy-rate term for the latest batch.

    For each adjacent layer pair ``idx`` the estimator is
        H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)
    where ΔX / ΔY are the row-wise differences between the current and the
    previous batch's captured activations. Only the cosine term is returned;
    the additive constant (and hence ``sigma2``, kept for API compatibility)
    is the same for every layer and does not change the pruning order.

    Side effects: each buffer's ``curr_*`` tensors are shifted into ``prev_*``
    and ``curr_*`` is reset to ``None``.

    Returns:
        dict mapping pair index -> float score. Pairs without usable history
        (first batch, or a batch whose shape differs from the previous one)
        are omitted.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Always advance the history, whether or not a score is produced.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity checks: `None in (tensor, ...)` would call Tensor.__eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # A differently sized batch (e.g. a short final batch) cannot be
        # differenced against the previous one.
        if X_curr.shape != X_prev.shape or Y_curr.shape != Y_prev.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er_scores[idx] = 0.0
            continue

        # Flatten all non-batch dims; work in float32 so fp16 activations
        # captured under autocast are not reduced in half precision.
        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)
        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)

        # Rows 1..B-1 (B-1 terms) match the 1/(2(B-1)) normalisation.
        # The cosine must be squared, otherwise opposite-signed terms cancel.
        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
        er_scores[idx] = (cos.pow(2).sum() / (2 * (B - 1))).item()

    return er_scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‑ER)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's ``output`` module that bypasses the feed-forward result.

    ``forward`` ignores ``hidden_states`` and returns ``input_tensor`` unchanged
    (``None`` if it is not given). Installed as ``layer.output`` by
    ``prune_er_layers``; presumably the encoder layer calls it as
    ``output(ffn_hidden, attention_output)`` so the layer collapses to its
    attention sub-block — confirm against the installed transformers version.
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_er_layers(model, er_scores, num_prune=4):
    """Disable the feed-forward block of the layers with the highest ER scores.

    ``er_scores`` maps a layer-pair index ``idx`` (pair ``(idx, idx+1)``) to
    its score; the score is attributed to the pair's second layer ``idx + 1``.
    The ``num_prune`` highest-scoring pairs are chosen (ties keep their
    original order), and each chosen layer gets an identity
    ``intermediate.dense`` and a ``SkipFF`` output module.

    Returns:
        list of pruned layer indices, highest score first.
    """
    encoder_layers = model.roberta.encoder.layer
    num_layers = len(encoder_layers)
    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)

    prune_idxs = []
    for pair_idx, _score in ranked[:num_prune]:
        target = pair_idx + 1
        if target < num_layers:
            prune_idxs.append(target)

    for target in prune_idxs:
        victim = encoder_layers[target]
        victim.intermediate.dense = nn.Identity()
        victim.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update to a frozen weight matrix.

    Calling the module (no arguments) returns the effective weight
    ``W0 + (alpha / r) * B @ A``, shaped like ``W0``. ``W0`` is a detached copy
    held as a buffer (not trained). ``B`` starts as small Gaussian noise and
    ``A`` as zeros, so the initial update is exactly zero.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        out_features, in_features = W0.shape
        self.B = nn.Parameter(torch.randn(out_features, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_features))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap every un-pruned layer's ``output.dense`` projection with a LoRA update.

    Layers whose ``output`` has no ``dense`` attribute (replaced by ``SkipFF``
    in ``prune_er_layers``) are skipped. For the others, ``dense.forward`` is
    replaced so it computes ``F.linear(x, lora(), dense.bias)``, i.e. it uses
    the effective weight ``W0 + scaling * B @ A``.

    Each LoRA module is also registered as the child ``dense.lora``. Its
    parameters and ``W0`` buffer therefore appear in ``model.parameters()``
    and ``model.state_dict()`` and follow ``model.to(...)``. Otherwise the
    trained adapters would exist only in the returned dict and would be lost
    when the model is saved.

    Returns:
        dict mapping layer index -> LoRA module.
    """
    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(layer.output, 'dense'):
            continue
        dense = layer.output.dense
        W0 = dense.weight.data
        lora = LoRA(W0, r, alpha).to(W0.device)
        # nn.Module.__setattr__ registers this as a submodule of `dense`.
        dense.lora = lora

        def fwd(x, dense=dense, lora=lora):
            return F.linear(x, lora(), dense.bias)

        # Instance attribute shadows Linear.forward, so module hooks still fire.
        dense.forward = fwd
        loras[idx] = lora
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenise a batch of (sentence1, sentence2) pairs.

    Each pair is truncated and padded to exactly ``max_length`` tokens. The
    tokenizer's output is returned as-is.
    """
    first_segments = examples['sentence1']
    second_segments = examples['sentence2']
    return tok(
        first_segments,
        second_segments,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )

def evaluate_model(model, dl, device):
    """Return the classification accuracy of ``model`` over the batches of ``dl``.

    Switches the model to eval mode (it is left there; the training loops call
    ``model.train()`` at the start of each epoch) and runs inference without
    gradients. Accuracy is counted directly from argmax predictions rather
    than through ``evaluate.load("accuracy")``, which re-loads the metric
    module (possibly over the network) on every call.

    Returns:
        float accuracy in [0, 1]; ``nan`` if ``dl`` yields no examples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            ids = b['input_ids'].to(device)
            mask = b['attention_mask'].to(device)
            labels = b['labels'].to(device)
            logits = model(input_ids=ids, attention_mask=mask).logits
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            total += labels.numel()
    return correct / total if total else float('nan')


# ========================================================
# 6) Training Stages (using ER instead of MIR)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base (2 labels) while estimating entropy rates.

    All parameters are trained with Adam (lr 2e-5) under mixed precision, with
    a linear LR decay spanning all ``num_epochs`` epochs. Forward hooks capture
    adjacent-layer activations. After every step ``compute_batch_entropy``
    scores each layer pair, and the scores are averaged per epoch.

    Returns:
        (model, last_er): the fine-tuned model and the epoch-averaged ER dict
        of the final epoch (``None`` when ``num_epochs`` is 0).
    """
    print("=== Stage 1: Full Finetuning & ER Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    # Schedule length is tied to the epoch count actually run.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_er_hooks(model)
    last_er = None

    try:
        for epoch in range(num_epochs):
            er_sums, er_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward outside the autocast region, per the AMP recipe.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # batch-level ER
                batch_er = compute_batch_entropy(activations)
                for idx, v in batch_er.items():
                    er_sums[idx] += v
                    er_counts[idx] += 1

            # epoch-level ER
            epoch_er = {idx: er_sums[idx] / er_counts[idx]
                        for idx in er_sums if er_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Entropy Rate:", epoch_er)
            last_er = epoch_er
    finally:
        # Detach the hooks before evaluation, so dev batches are not captured,
        # and also if training raises.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune MRPC Acc: {acc:.4f}")
    return model, last_er


def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores, num_epochs=5):
    """Stage 2: prune the 4 highest-ER layers, then fine-tune the pruned model.

    After ``prune_er_layers`` every parameter is trained without autocast,
    using Adam (lr 1e-5) and a linear schedule spanning all ``num_epochs``
    epochs. Dev accuracy is printed after each epoch.

    Returns:
        the same model object, pruned and fine-tuned in place.
    """
    print("=== Stage 2: Prune (High‐ER) & Finetuning ===")
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print("Pruned layers (highest‐ER):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # The schedule used to be sized for 3 epochs while 5 were run, leaving
    # the learning rate at 0 for the last two epochs.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):
    """Stage 3: freeze the encoder and train LoRA adapters plus the classifier head.

    ``apply_lora_to_all_layers`` adds a rank-``r`` update to every remaining
    ``output.dense`` projection. All ``model.roberta`` weights are frozen, and
    only the classifier and the LoRA ``A``/``B`` matrices are optimised (Adam,
    lr 2e-5, linear decay over ``num_epochs`` epochs, mixed precision). Dev
    accuracy is printed after each epoch. The model is modified in place.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 enabled gradient checkpointing, which transformers runs
    # re-entrantly by default (verify for the installed version). With the
    # embeddings frozen, no checkpointed layer input requires grad, so the
    # LoRA parameters inside the layers would silently receive no gradient
    # (PyTorch's warning about this is hidden by the UserWarning filter).
    # Making the embedding output require grad restores gradient flow.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        model.gradient_checkpointing_disable()

    opt = torch.optim.Adam(list(model.classifier.parameters()) + lora_params, lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside the autocast region, per the AMP recipe.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}")


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Entry point for the MRPC experiment.

    Seeds the Python, NumPy and torch RNGs. Tokenises 1000 shuffled MRPC
    training pairs and the full validation split. Then runs the three stages
    in order: full fine-tuning with ER estimation, ER-based pruning plus
    fine-tuning, and LoRA-only fine-tuning.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load & preprocess MRPC subset
    train_ds = load_dataset("glue", "mrpc", split="train").shuffle(seed).select(range(1000))
    dev_ds = load_dataset("glue", "mrpc", split="validation")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def tokenize_split(ds):
        # Tokenise pairs, drop raw-text columns, rename the target to 'labels'
        # (the key the training loops read).
        encoded = ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True,
                         remove_columns=["sentence1", "sentence2", "idx"])
        return encoded.rename_column("label", "labels")

    train = tokenize_split(train_ds)
    dev = tokenize_split(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, er_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

if __name__ == "__main__":
    # Notebook cells execute with __name__ == "__main__", so running this cell
    # immediately launches all three training stages via main().
    main()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)
# ========================================================
def register_er_hooks(model):
    """Attach forward hooks that capture adjacent-layer FFN output activations.

    These are the same hooks as in the MIR variant. For every neighbouring
    encoder-layer pair ``(i, i+1)`` the (detached) output of
    ``layers[i].output.dense`` is written to ``activations[i]['curr_X']`` and
    that of ``layers[i+1].output.dense`` to ``activations[i]['curr_Y']``.

    Returns:
        (hooks, activations): the hook handles (two per pair; pass them to
        ``remove_hooks``) and the per-pair buffers read by
        ``compute_batch_entropy``.
    """
    encoder_layers = model.roberta.encoder.layer
    num_pairs = len(encoder_layers) - 1

    activations = {}
    for pair_idx in range(num_pairs):
        activations[pair_idx] = {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}

    def _make_hook(pair_idx, slot):
        # Factory so every closure captures its own pair index and slot name.
        def _hook(module, inp, out):
            activations[pair_idx][slot] = out.detach()
        return _hook

    hooks = []
    for pair_idx in range(num_pairs):
        lower = encoder_layers[pair_idx].output.dense
        upper = encoder_layers[pair_idx + 1].output.dense
        hooks.append(lower.register_forward_hook(_make_hook(pair_idx, 'curr_X')))
        hooks.append(upper.register_forward_hook(_make_hook(pair_idx, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` (e.g. from ``register_er_hooks``)."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """Approximate the per-layer-pair entropy-rate term for the latest batch.

    For each adjacent layer pair ``idx`` the estimator is
        H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)
    where ΔX / ΔY are the row-wise differences between the current and the
    previous batch's captured activations. Only the cosine term is returned;
    the additive constant (and hence ``sigma2``, kept for API compatibility)
    is the same for every layer and does not change the pruning order.

    Side effects: each buffer's ``curr_*`` tensors are shifted into ``prev_*``
    and ``curr_*`` is reset to ``None``.

    Returns:
        dict mapping pair index -> float score. Pairs without usable history
        (first batch, or a batch whose shape differs from the previous one)
        are omitted.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Always advance the history, whether or not a score is produced.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity checks: `None in (tensor, ...)` would call Tensor.__eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # A differently sized batch (e.g. a short final batch) cannot be
        # differenced against the previous one.
        if X_curr.shape != X_prev.shape or Y_curr.shape != Y_prev.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er_scores[idx] = 0.0
            continue

        # Flatten all non-batch dims; work in float32 so fp16 activations
        # captured under autocast are not reduced in half precision.
        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)
        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)

        # Rows 1..B-1 (B-1 terms) match the 1/(2(B-1)) normalisation.
        # The cosine must be squared, otherwise opposite-signed terms cancel.
        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
        er_scores[idx] = (cos.pow(2).sum() / (2 * (B - 1))).item()

    return er_scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high‑ER)
# ========================================================
class SkipFF(nn.Module):
    """Stand-in for a layer's ``output`` module that bypasses the feed-forward result.

    ``forward`` ignores ``hidden_states`` and returns ``input_tensor`` unchanged
    (``None`` if it is not given). Installed as ``layer.output`` by
    ``prune_er_layers``; presumably the encoder layer calls it as
    ``output(ffn_hidden, attention_output)`` so the layer collapses to its
    attention sub-block — confirm against the installed transformers version.
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_er_layers(model, er_scores, num_prune=4):
    """Disable the feed-forward block of the layers with the highest ER scores.

    ``er_scores`` maps a layer-pair index ``idx`` (pair ``(idx, idx+1)``) to
    its score; the score is attributed to the pair's second layer ``idx + 1``.
    The ``num_prune`` highest-scoring pairs are chosen (ties keep their
    original order), and each chosen layer gets an identity
    ``intermediate.dense`` and a ``SkipFF`` output module.

    Returns:
        list of pruned layer indices, highest score first.
    """
    encoder_layers = model.roberta.encoder.layer
    num_layers = len(encoder_layers)
    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)

    prune_idxs = []
    for pair_idx, _score in ranked[:num_prune]:
        target = pair_idx + 1
        if target < num_layers:
            prune_idxs.append(target)

    for target in prune_idxs:
        victim = encoder_layers[target]
        victim.intermediate.dense = nn.Identity()
        victim.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank additive update to a frozen weight matrix.

    Calling the module (no arguments) returns the effective weight
    ``W0 + (alpha / r) * B @ A``, shaped like ``W0``. ``W0`` is a detached copy
    held as a buffer (not trained). ``B`` starts as small Gaussian noise and
    ``A`` as zeros, so the initial update is exactly zero.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        out_features, in_features = W0.shape
        self.B = nn.Parameter(torch.randn(out_features, r) * 0.01)
        self.A = nn.Parameter(torch.zeros(r, in_features))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Wrap every un-pruned layer's ``output.dense`` projection with a LoRA update.

    Layers whose ``output`` has no ``dense`` attribute (replaced by ``SkipFF``
    in ``prune_er_layers``) are skipped. For the others, ``dense.forward`` is
    replaced so it computes ``F.linear(x, lora(), dense.bias)``, i.e. it uses
    the effective weight ``W0 + scaling * B @ A``.

    Each LoRA module is also registered as the child ``dense.lora``. Its
    parameters and ``W0`` buffer therefore appear in ``model.parameters()``
    and ``model.state_dict()`` and follow ``model.to(...)``. Otherwise the
    trained adapters would exist only in the returned dict and would be lost
    when the model is saved.

    Returns:
        dict mapping layer index -> LoRA module.
    """
    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(layer.output, 'dense'):
            continue
        dense = layer.output.dense
        W0 = dense.weight.data
        lora = LoRA(W0, r, alpha).to(W0.device)
        # nn.Module.__setattr__ registers this as a submodule of `dense`.
        dense.lora = lora

        def fwd(x, dense=dense, lora=lora):
            return F.linear(x, lora(), dense.bias)

        # Instance attribute shadows Linear.forward, so module hooks still fire.
        dense.forward = fwd
        loras[idx] = lora
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenise a batch of (sentence1, sentence2) pairs.

    Each pair is truncated and padded to exactly ``max_length`` tokens. The
    tokenizer's output is returned as-is.
    """
    first_segments = examples['sentence1']
    second_segments = examples['sentence2']
    return tok(
        first_segments,
        second_segments,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )

def evaluate_model(model, dl, device):
    """Return the classification accuracy of ``model`` over the batches of ``dl``.

    Switches the model to eval mode (it is left there; the training loops call
    ``model.train()`` at the start of each epoch) and runs inference without
    gradients. Accuracy is counted directly from argmax predictions rather
    than through ``evaluate.load("accuracy")``, which re-loads the metric
    module (possibly over the network) on every call.

    Returns:
        float accuracy in [0, 1]; ``nan`` if ``dl`` yields no examples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            ids = b['input_ids'].to(device)
            mask = b['attention_mask'].to(device)
            labels = b['labels'].to(device)
            logits = model(input_ids=ids, attention_mask=mask).logits
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            total += labels.numel()
    return correct / total if total else float('nan')


# ========================================================
# 6) Training Stages (using ER instead of MIR)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """Stage 1: fully fine-tune roberta-base (2 labels) while estimating entropy rates.

    All parameters are trained with Adam (lr 2e-5) under mixed precision, with
    a linear LR decay spanning all ``num_epochs`` epochs. Forward hooks capture
    adjacent-layer activations. After every step ``compute_batch_entropy``
    scores each layer pair, and the scores are averaged per epoch.

    Returns:
        (model, last_er): the fine-tuned model and the epoch-averaged ER dict
        of the final epoch (``None`` when ``num_epochs`` is 0).
    """
    print("=== Stage 1: Full Finetuning & ER Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    # Schedule length is tied to the epoch count actually run.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader) * num_epochs)
    scaler = GradScaler()

    hooks, activations = register_er_hooks(model)
    last_er = None

    try:
        for epoch in range(num_epochs):
            er_sums, er_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward outside the autocast region, per the AMP recipe.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # batch-level ER
                batch_er = compute_batch_entropy(activations)
                for idx, v in batch_er.items():
                    er_sums[idx] += v
                    er_counts[idx] += 1

            # epoch-level ER
            epoch_er = {idx: er_sums[idx] / er_counts[idx]
                        for idx in er_sums if er_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Entropy Rate:", epoch_er)
            last_er = epoch_er
    finally:
        # Detach the hooks before evaluation, so dev batches are not captured,
        # and also if training raises.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune SST2 Acc: {acc:.4f}")
    return model, last_er


def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):
    """Stage 2: prune the 4 highest-ER layers' FFNs, then fine-tune the rest.

    ``prune_er_layers`` mutates *model* in place. The model is then fine-tuned
    for 5 epochs with Adam (lr 2e-5) in full precision, printing dev accuracy
    after every epoch.

    Returns:
        The same (pruned, fine-tuned) model object.
    """
    print("=== Stage 2: Prune (High‐ER) & Finetuning ===")
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print("Pruned layers (highest‐ER):", prune_idxs)

    num_epochs = 5
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    # Size the linear decay to the epochs actually run: a shorter horizon
    # drives the LR to 0 early and the remaining epochs make no updates.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] SST2 Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """Stage 3: train only LoRA adapters on ``output.dense`` plus the classifier.

    Every ``model.roberta`` parameter is frozen. The LoRA A/B matrices and the
    classifier head are optimized with Adam (lr 2e-5) for 6 epochs under mixed
    precision, and dev accuracy is printed after every epoch.

    Returns:
        *model*, modified in place.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    num_epochs = 6
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for adapter in loras.values() for p in (adapter.A, adapter.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 enabled gradient checkpointing. With reentrant checkpointing
    # (the HF default), weights used inside a checkpointed layer get no
    # gradient when none of the layer's inputs require grad -- the case here
    # once all of model.roberta is frozen -- so LoRA A/B would silently never
    # train (torch's warning is a UserWarning, which this notebook filters).
    # Making the embedding output require grad restores the gradient path.
    model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast (PyTorch AMP recommendation).
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] SST2 Acc: {acc:.4f}")

    return model


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the SST-2 pipeline: full FT + ER, prune + FT, then LoRA-only FT."""
    seed = 42
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load & preprocess SST-2 subset (5000 shuffled train examples)
    train_ds = load_dataset("glue", "sst2", split="train").shuffle(seed).select(range(5000))
    dev_ds   = load_dataset("glue", "sst2", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(ex):
        # Fixed-length 64-token encoding of the single-sentence input.
        return tokenizer(ex["sentence"], truncation=True, padding='max_length', max_length=64)

    def prepare(ds):
        # Tokenize, rename the target column, and drop the raw text columns.
        tokenized = ds.map(encode, batched=True)
        tokenized = tokenized.rename_column("label", "labels")
        return tokenized.remove_columns(["sentence", "idx"])

    train, dev = prepare(train_ds), prepare(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, er_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

if __name__ == "__main__":
    # Run the pipeline when this cell/script executes as __main__.
    main()


In [None]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ----------- Entropy Rate / Hook Utilities (Theorem 2) -----------
def register_er_hooks(model):
    """Attach forward hooks recording adjacent layers' ``output.dense`` outputs.

    For every encoder-layer pair (i, i+1), layer i's ``output.dense`` output
    is stored (detached) in ``activations[i]['curr_X']`` and layer i+1's in
    ``activations[i]['curr_Y']``. The ``prev_*`` slots start as None.

    Returns:
        (hooks, activations): the hook handles (two per pair, X then Y) and
        the dict of per-pair buffers.
    """
    layers = model.roberta.encoder.layer
    n_pairs = len(layers) - 1
    activations = {}
    hooks = []

    def recorder(pair, slot):
        # Factory binds (pair, slot) now, so each hook writes its own cell.
        def _hook(module, inp, out):
            activations[pair][slot] = out.detach()
        return _hook

    for pair in range(n_pairs):
        activations[pair] = dict.fromkeys(('prev_X', 'prev_Y', 'curr_X', 'curr_Y'))
        hooks.append(layers[pair].output.dense.register_forward_hook(recorder(pair, 'curr_X')))
        hooks.append(layers[pair + 1].output.dense.register_forward_hook(recorder(pair, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Detach every handle in *hooks* (as returned by ``register_forward_hook``)."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """Score each adjacent layer pair with the cosine term of the ER estimate.

        H ≈ d/2 · ln(2πe σ²) + 1/(2(B-1)) · Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)

    ΔX / ΔY are the element-wise differences between the current and the
    previous batch's captured outputs, flattened per example. The additive
    constant is the same for every pair and is dropped, so *sigma2* is kept
    only for interface compatibility and is unused.

    Side effect: every buffer's ``curr_*`` tensors are moved into ``prev_*``
    and ``curr_*`` is reset to None.

    Args:
        activations: mapping idx -> {'prev_X','prev_Y','curr_X','curr_Y'}.
        sigma2: unused (see above).

    Returns:
        dict idx -> float, only for pairs that had a previous batch of the
        same shape to compare against.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Shift history up front; scoring below only reads the locals.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity check: `None in (tensors...)` would depend on Tensor.__eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # A differently-sized batch (e.g. a final partial batch) cannot be
        # differenced element-wise against the previous one.
        if X_curr.shape != X_prev.shape or Y_curr.shape != Y_prev.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er_scores[idx] = 0.0
            continue
        # float32: hooked outputs are fp16 under autocast, and the norms of
        # these long flattened vectors can overflow fp16.
        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)
        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)
        # Rows 1..B-1 as in the formula; the cosine is squared (cos²).
        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
        er_scores[idx] = cos.pow(2).sum().item() / (2 * (B - 1))
    return er_scores

def estimate_entropy_rate(model, loader, device, max_batches=10):
    """Average per-layer-pair ER scores over up to *max_batches* batches.

    Puts *model* in eval mode (it is left there), attaches ER hooks, runs
    forward passes under ``torch.no_grad`` (only the hooked activations are
    needed), and scores each batch with ``compute_batch_entropy``. Hooks are
    always removed, even on error.

    Returns:
        dict pair idx -> mean score over the batches that produced one.
    """
    from itertools import islice

    model.eval()
    hooks, activations = register_er_hooks(model)
    er_sum, er_count = {}, {}
    try:
        with torch.no_grad():
            for batch in islice(loader, max(0, max_batches)):
                model(input_ids=batch['input_ids'].to(device),
                      attention_mask=batch['attention_mask'].to(device))
                for idx, val in compute_batch_entropy(activations).items():
                    er_sum[idx] = er_sum.get(idx, 0.0) + val
                    er_count[idx] = er_count.get(idx, 0) + 1
    finally:
        remove_hooks(hooks)
    # Every key in er_sum has a count of at least 1.
    return {k: er_sum[k] / er_count[k] for k in er_sum}


# ----------- Pruning Utilities -----------
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's FFN output block that bypasses it.

    Ignores *hidden_states* and returns the residual *input_tensor* as is.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor
        return residual

def prune_er_layers(model, er_scores, num_prune=4):
    """Bypass the FFN of layers picked by the highest ER scores (in place).

    ``er_scores`` maps a pair index idx (layers idx, idx+1) to a score. For
    the ``num_prune`` highest-scoring pairs, layer idx+1 gets
    ``intermediate.dense = nn.Identity()`` and ``output = SkipFF()``.

    Returns:
        The pruned layer indices, in descending-score order.
    """
    encoder_layers = model.roberta.encoder.layer
    ranked = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)
    pruned = []
    for pair_idx, _score in ranked[:num_prune]:
        target = pair_idx + 1
        if target >= len(encoder_layers):
            continue
        pruned.append(target)
    for target in pruned:
        layer = encoder_layers[target]
        layer.intermediate.dense = nn.Identity()
        layer.output = SkipFF()
    return pruned

# ----------- Standard Preprocessing & Training -----------

def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ``'sentence'`` column of a batch to exactly *max_length* tokens.

    Longer inputs are truncated and shorter ones padded to ``max_length``.
    Returns whatever *tok* returns.
    """
    encode_kwargs = dict(truncation=True, padding='max_length', max_length=max_length)
    return tok(examples['sentence'], **encode_kwargs)

def evaluate_model(model, dl, device):
    """Return classification accuracy of *model* over every batch in *dl*.

    Switches the model to eval mode (callers re-enter train mode). Each batch
    must provide 'input_ids', 'attention_mask' and 'labels'; predictions are
    the argmax of ``out.logits``. Accuracy is computed directly instead of via
    ``evaluate.load("accuracy")``, which re-loaded the metric on every call.

    Returns:
        Accuracy as a float, or NaN if *dl* yields no examples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            labels = b['labels'].cpu()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device))
            preds = torch.argmax(out.logits, -1).cpu()
            correct += int((preds == labels).sum())
            total += labels.numel()
    return correct / total if total else float("nan")

def finetune_model(model, train_loader, dev_loader, device, epochs):
    """Fine-tune every parameter of *model* for *epochs* epochs with AMP.

    Adam (lr 2e-5) with a linear decay over ``len(train_loader) * epochs``
    steps; SST-2 dev accuracy is printed after every epoch.

    Returns:
        The same model object, trained in place.
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*epochs)
    scaler = GradScaler()
    for epoch in range(epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast (PyTorch AMP recommendation).
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Epoch {epoch+1}] SST-2 Acc: {acc:.4f}")
    return model

def main():
    """SST-2: full FT (6 ep) -> ER on dev -> prune 4 layers -> FT (5 ep)."""
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = load_dataset("glue", "sst2", split="train").shuffle(seed).select(range(5000))
    dev_ds   = load_dataset("glue", "sst2", split="validation")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                         batched=True, remove_columns=["sentence"]).rename_column("label", "labels")
    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),
                     batched=True, remove_columns=["sentence"]).rename_column("label", "labels")

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    print("\n=== Stage 1: Full Fine-Tuning (No Pruning) ===")
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2).to(device)
    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)
    acc_full = evaluate_model(model, dev_loader, device)
    print(f"\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}")

    print("\n=== Stage 2: Entropy Rate Pruning (Remove 4 highest-ER layers) ===")
    # ER is estimated from the first 10 dev batches (forward passes only).
    er_scores = estimate_entropy_rate(model, dev_loader, device, max_batches=10)
    print("Layer-wise Entropy Rate:", er_scores)
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print(f"Pruned layer indices: {prune_idxs}")

    print("\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===")
    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)
    acc_pruned = evaluate_model(model, dev_loader, device)
    print(f"\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}")

if __name__ == "__main__":
    # Run the pipeline when this cell/script executes as __main__.
    main()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)
# ========================================================
def register_er_hooks(model):
    """Attach forward hooks recording adjacent layers' ``output.dense`` outputs.

    Same hook placement as MIR: for every encoder-layer pair (i, i+1), layer
    i's ``output.dense`` output is stored (detached) in
    ``activations[i]['curr_X']`` and layer i+1's in ``activations[i]['curr_Y']``.

    Returns:
        (hooks, activations): the hook handles (two per pair, X then Y) and
        the dict of per-pair buffers.
    """
    layers = model.roberta.encoder.layer
    n_pairs = len(layers) - 1
    activations = {}
    hooks = []

    def recorder(pair, slot):
        # Factory binds (pair, slot) now, so each hook writes its own cell.
        def _hook(module, inp, out):
            activations[pair][slot] = out.detach()
        return _hook

    for pair in range(n_pairs):
        activations[pair] = dict.fromkeys(('prev_X', 'prev_Y', 'curr_X', 'curr_Y'))
        hooks.append(layers[pair].output.dense.register_forward_hook(recorder(pair, 'curr_X')))
        hooks.append(layers[pair + 1].output.dense.register_forward_hook(recorder(pair, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Detach every handle in *hooks* (as returned by ``register_forward_hook``)."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """Score each adjacent layer pair with the cosine term of the ER estimate.

        H ≈ d/2 · ln(2πe σ²) + 1/(2(B-1)) · Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)

    ΔX / ΔY are the element-wise differences between the current and the
    previous batch's captured outputs, flattened per example. The additive
    constant is the same for every pair and is dropped, so *sigma2* is kept
    only for interface compatibility and is unused.

    Side effect: every buffer's ``curr_*`` tensors are moved into ``prev_*``
    and ``curr_*`` is reset to None.

    Args:
        activations: mapping idx -> {'prev_X','prev_Y','curr_X','curr_Y'}.
        sigma2: unused (see above).

    Returns:
        dict idx -> float, only for pairs that had a previous batch of the
        same shape to compare against.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Shift history up front; scoring below only reads the locals.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity check: `None in (tensors...)` would depend on Tensor.__eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # A differently-sized batch (e.g. a final partial batch) cannot be
        # differenced element-wise against the previous one.
        if X_curr.shape != X_prev.shape or Y_curr.shape != Y_prev.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er_scores[idx] = 0.0
            continue
        # float32: hooked outputs are fp16 under autocast, and the norms of
        # these long flattened vectors can overflow fp16.
        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)
        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)
        # Rows 1..B-1 as in the formula; the cosine is squared (cos²).
        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
        er_scores[idx] = cos.pow(2).sum().item() / (2 * (B - 1))
    return er_scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high-ER)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's FFN output block that bypasses it.

    Ignores *hidden_states* and returns the residual *input_tensor* as is.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor
        return residual

def prune_er_layers(model, er_scores, num_prune=4):
    """Bypass the FFN of layers picked by the highest ER scores (in place).

    ``er_scores`` maps a pair index idx (layers idx, idx+1) to a score. For
    the ``num_prune`` highest-scoring pairs, layer idx+1 gets
    ``intermediate.dense = nn.Identity()`` and ``output = SkipFF()``.

    Returns:
        The pruned layer indices, in descending-score order.
    """
    encoder_layers = model.roberta.encoder.layer
    ranked = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)
    pruned = []
    for pair_idx, _score in ranked[:num_prune]:
        target = pair_idx + 1
        if target >= len(encoder_layers):
            continue
        pruned.append(target)
    for target in pruned:
        layer = encoder_layers[target]
        layer.intermediate.dense = nn.Identity()
        layer.output = SkipFF()
    return pruned


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank adapter whose forward returns ``W0 + (alpha / r) * B @ A``.

    ``W0`` is a detached copy of the base weight, registered as a buffer.
    ``B`` (out x r) starts as 0.01-scaled Gaussian noise and ``A`` (r x in)
    starts at zero, so the adapted weight initially equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen = W0.clone().detach()
        self.register_buffer("W0", frozen)
        n_rows, n_cols = W0.shape
        # B draws from the global RNG; A is deterministic zeros.
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route each encoder layer's ``output.dense`` through a LoRA-adapted weight.

    Layers whose ``output`` has no ``dense`` attribute (e.g. pruned layers)
    are skipped. The Linear's ``forward`` is replaced on the instance with
    ``F.linear(x, lora(), bias)``. The LoRA modules are not registered as
    submodules of *model*, so callers must take their parameters from the
    returned dict.

    Returns:
        dict mapping layer index -> LoRA module.
    """
    adapters = {}
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(layer.output, 'dense'):
            continue
        base_weight = layer.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

        # Defaults bind this iteration's layer/adapter (avoids late binding).
        def patched_forward(x, layer=layer, lora=adapter):
            return F.linear(x, lora(), layer.output.dense.bias)

        layer.output.dense.forward = patched_forward
        adapters[layer_idx] = adapter
    return adapters


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ('sentence1', 'sentence2') pair columns to *max_length* tokens.

    Pairs are truncated / padded to exactly ``max_length``. Returns whatever
    *tok* returns.
    """
    texts, pairs = examples['sentence1'], examples['sentence2']
    options = {'truncation': True, 'padding': 'max_length', 'max_length': max_length}
    return tok(texts, pairs, **options)

def evaluate_model(model, dl, device):
    """Return classification accuracy of *model* over every batch in *dl*.

    Switches the model to eval mode (callers re-enter train mode). Each batch
    must provide 'input_ids', 'attention_mask' and 'labels'; predictions are
    the argmax of ``out.logits``. Accuracy is computed directly instead of via
    ``evaluate.load("accuracy")``, which re-loaded the metric on every call.

    Returns:
        Accuracy as a float, or NaN if *dl* yields no examples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            labels = b['labels'].cpu()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device))
            preds = torch.argmax(out.logits, -1).cpu()
            correct += int((preds == labels).sum())
            total += labels.numel()
    return correct / total if total else float("nan")


# ========================================================
# 6) Training Stages (using ER instead of MIR)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fully fine-tune roberta-base while tracking per-pair entropy rate.

    Trains a fresh 2-label ``RobertaForSequenceClassification`` for 6 epochs
    with Adam (lr 2e-5), a linear-decay schedule, mixed precision and gradient
    checkpointing. ER hooks stay attached for the whole run; after every
    optimizer step ``compute_batch_entropy`` scores each adjacent layer pair
    and the scores are averaged per epoch.

    Returns:
        (model, last_er): the fine-tuned model and the final epoch's
        epoch-averaged ER dict (pair idx -> score).
    """
    print("=== Stage 1: Full Finetuning & ER Estimation ===")
    num_epochs = 6
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    hooks, activations = register_er_hooks(model)
    last_er = None
    try:
        for epoch in range(num_epochs):
            er_sums, er_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward runs outside autocast, as the PyTorch AMP docs
                # recommend; backward ops reuse the forward pass's dtypes.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # batch-level ER from the activations the hooks just captured
                batch_er = compute_batch_entropy(activations)
                for idx, v in batch_er.items():
                    er_sums[idx]   += v
                    er_counts[idx] += 1

            # epoch-level ER (mean over the batches that produced a score)
            epoch_er = {idx: er_sums[idx]/er_counts[idx]
                        for idx in er_sums if er_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Entropy Rate:", epoch_er)
            last_er = epoch_er
    finally:
        # Always detach the hooks, even if training raises.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune CoLA Acc: {acc:.4f}")
    return model, last_er


def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):
    """Stage 2: prune the 4 highest-ER layers' FFNs, then fine-tune the rest.

    ``prune_er_layers`` mutates *model* in place. The model is then fine-tuned
    for 5 epochs with Adam (lr 1e-5) in full precision, printing dev accuracy
    after every epoch.

    Returns:
        The same (pruned, fine-tuned) model object.
    """
    print("=== Stage 2: Prune (High‐ER) & Finetuning ===")
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print("Pruned layers (highest‐ER):", prune_idxs)

    num_epochs = 5
    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the linear decay to the epochs actually run: a shorter horizon
    # drives the LR to 0 early and the remaining epochs make no updates.
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):
    """Stage 3: train only LoRA adapters on ``output.dense`` plus the classifier.

    Every ``model.roberta`` parameter is frozen. The LoRA A/B matrices and the
    classifier head are optimized with Adam (lr 2e-5) for 6 epochs under mixed
    precision, and dev accuracy is printed after every epoch.

    Returns:
        *model*, modified in place.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    num_epochs = 6
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for adapter in loras.values() for p in (adapter.A, adapter.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 enabled gradient checkpointing. With reentrant checkpointing
    # (the HF default), weights used inside a checkpointed layer get no
    # gradient when none of the layer's inputs require grad -- the case here
    # once all of model.roberta is frozen -- so LoRA A/B would silently never
    # train (torch's warning is a UserWarning, which this notebook filters).
    # Making the embedding output require grad restores the gradient path.
    model.enable_input_require_grads()

    opt   = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast (PyTorch AMP recommendation).
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f}")

    return model


# ========================================================
# 7) Main Entrypoint
# ========================================================
def main():
    """Run the CoLA pipeline: full FT + ER, prune + FT, then LoRA-only FT."""
    seed = 42
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load & preprocess CoLA subset (5000 shuffled train examples)
    train_ds = load_dataset("glue", "cola", split="train").shuffle(seed).select(range(5000))
    dev_ds   = load_dataset("glue", "cola", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def encode(ex):
        # Fixed-length 64-token encoding of the single-sentence input.
        return tokenizer(ex["sentence"], truncation=True, padding='max_length', max_length=64)

    def prepare(ds):
        # Tokenize, rename the target column, and drop the raw text columns.
        tokenized = ds.map(encode, batched=True)
        tokenized = tokenized.rename_column("label", "labels")
        return tokenized.remove_columns(["sentence", "idx"])

    train, dev = prepare(train_ds), prepare(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=64)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, er_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)

if __name__ == "__main__":
    # Run the pipeline when this cell/script executes as __main__.
    main()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)
# ========================================================
def register_er_hooks(model):
    """Attach forward hooks recording adjacent layers' ``output.dense`` outputs.

    Same hook placement as MIR: for every encoder-layer pair (i, i+1), layer
    i's ``output.dense`` output is stored (detached) in
    ``activations[i]['curr_X']`` and layer i+1's in ``activations[i]['curr_Y']``.

    Returns:
        (hooks, activations): the hook handles (two per pair, X then Y) and
        the dict of per-pair buffers.
    """
    layers = model.roberta.encoder.layer
    n_pairs = len(layers) - 1
    activations = {}
    hooks = []

    def recorder(pair, slot):
        # Factory binds (pair, slot) now, so each hook writes its own cell.
        def _hook(module, inp, out):
            activations[pair][slot] = out.detach()
        return _hook

    for pair in range(n_pairs):
        activations[pair] = dict.fromkeys(('prev_X', 'prev_Y', 'curr_X', 'curr_Y'))
        hooks.append(layers[pair].output.dense.register_forward_hook(recorder(pair, 'curr_X')))
        hooks.append(layers[pair + 1].output.dense.register_forward_hook(recorder(pair, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Detach every handle in *hooks* (as returned by ``register_forward_hook``)."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """Score each adjacent layer pair with the cosine term of the ER estimate.

        H ≈ d/2 · ln(2πe σ²) + 1/(2(B-1)) · Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)

    ΔX / ΔY are the element-wise differences between the current and the
    previous batch's captured outputs, flattened per example. The additive
    constant is the same for every pair and is dropped, so *sigma2* is kept
    only for interface compatibility and is unused.

    Side effect: every buffer's ``curr_*`` tensors are moved into ``prev_*``
    and ``curr_*`` is reset to None.

    Args:
        activations: mapping idx -> {'prev_X','prev_Y','curr_X','curr_Y'}.
        sigma2: unused (see above).

    Returns:
        dict idx -> float, only for pairs that had a previous batch of the
        same shape to compare against.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Shift history up front; scoring below only reads the locals.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity check: `None in (tensors...)` would depend on Tensor.__eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # A differently-sized batch (e.g. a final partial batch) cannot be
        # differenced element-wise against the previous one.
        if X_curr.shape != X_prev.shape or Y_curr.shape != Y_prev.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er_scores[idx] = 0.0
            continue
        # float32: hooked outputs are fp16 under autocast, and the norms of
        # these long flattened vectors can overflow fp16.
        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)
        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)
        # Rows 1..B-1 as in the formula; the cosine is squared (cos²).
        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
        er_scores[idx] = cos.pow(2).sum().item() / (2 * (B - 1))
    return er_scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high-ER)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's FFN output block that bypasses it.

    Ignores *hidden_states* and returns the residual *input_tensor* as is.
    """

    def forward(self, hidden_states, input_tensor=None):
        residual = input_tensor
        return residual

def prune_er_layers(model, er_scores, num_prune=4):
    """Bypass the FFN of layers picked by the highest ER scores (in place).

    ``er_scores`` maps a pair index idx (layers idx, idx+1) to a score. For
    the ``num_prune`` highest-scoring pairs, layer idx+1 gets
    ``intermediate.dense = nn.Identity()`` and ``output = SkipFF()``.

    Returns:
        The pruned layer indices, in descending-score order.
    """
    encoder_layers = model.roberta.encoder.layer
    ranked = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)
    pruned = []
    for pair_idx, _score in ranked[:num_prune]:
        target = pair_idx + 1
        if target >= len(encoder_layers):
            continue
        pruned.append(target)
    for target in pruned:
        layer = encoder_layers[target]
        layer.intermediate.dense = nn.Identity()
        layer.output = SkipFF()
    return pruned


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank adapter whose forward returns ``W0 + (alpha / r) * B @ A``.

    ``W0`` is a detached copy of the base weight, registered as a buffer.
    ``B`` (out x r) starts as 0.01-scaled Gaussian noise and ``A`` (r x in)
    starts at zero, so the adapted weight initially equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        frozen = W0.clone().detach()
        self.register_buffer("W0", frozen)
        n_rows, n_cols = W0.shape
        # B draws from the global RNG; A is deterministic zeros.
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route each encoder layer's ``output.dense`` through a LoRA-adapted weight.

    Layers whose ``output`` has no ``dense`` attribute (e.g. pruned layers)
    are skipped. The Linear's ``forward`` is replaced on the instance with
    ``F.linear(x, lora(), bias)``. The LoRA modules are not registered as
    submodules of *model*, so callers must take their parameters from the
    returned dict.

    Returns:
        dict mapping layer index -> LoRA module.
    """
    adapters = {}
    for layer_idx, layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(layer.output, 'dense'):
            continue
        base_weight = layer.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

        # Defaults bind this iteration's layer/adapter (avoids late binding).
        def patched_forward(x, layer=layer, lora=adapter):
            return F.linear(x, lora(), layer.output.dense.bias)

        layer.output.dense.forward = patched_forward
        adapters[layer_idx] = adapter
    return adapters


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize the ('sentence1', 'sentence2') pair columns to *max_length* tokens.

    Pairs are truncated / padded to exactly ``max_length``. Returns whatever
    *tok* returns.
    """
    texts, pairs = examples['sentence1'], examples['sentence2']
    options = {'truncation': True, 'padding': 'max_length', 'max_length': max_length}
    return tok(texts, pairs, **options)

def evaluate_model(model, dl, device):
    """Return classification accuracy of *model* over every batch in *dl*.

    Switches the model to eval mode (callers re-enter train mode). Each batch
    must provide 'input_ids', 'attention_mask' and 'labels'; predictions are
    the argmax of ``out.logits``. Accuracy is computed directly instead of via
    ``evaluate.load("accuracy")``, which re-loaded the metric on every call.

    Returns:
        Accuracy as a float, or NaN if *dl* yields no examples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            labels = b['labels'].cpu()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device))
            preds = torch.argmax(out.logits, -1).cpu()
            correct += int((preds == labels).sum())
            total += labels.numel()
    return correct / total if total else float("nan")


# ========================================================
# 6) Training Stages (using ER instead of MIR)
# ========================================================
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fully fine-tune roberta-base while tracking per-pair entropy rate.

    Trains a fresh 2-label ``RobertaForSequenceClassification`` for 6 epochs
    with Adam (lr 2e-5), a linear-decay schedule, mixed precision and gradient
    checkpointing. ER hooks stay attached for the whole run; after every
    optimizer step ``compute_batch_entropy`` scores each adjacent layer pair
    and the scores are averaged per epoch.

    Returns:
        (model, last_er): the fine-tuned model and the final epoch's
        epoch-averaged ER dict (pair idx -> score).
    """
    print("=== Stage 1: Full Finetuning & ER Estimation ===")
    num_epochs = 6
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*num_epochs)
    scaler = GradScaler()

    hooks, activations = register_er_hooks(model)
    last_er = None
    try:
        for epoch in range(num_epochs):
            er_sums, er_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                # Backward runs outside autocast, as the PyTorch AMP docs
                # recommend; backward ops reuse the forward pass's dtypes.
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # batch-level ER from the activations the hooks just captured
                batch_er = compute_batch_entropy(activations)
                for idx, v in batch_er.items():
                    er_sums[idx]   += v
                    er_counts[idx] += 1

            # epoch-level ER (mean over the batches that produced a score)
            epoch_er = {idx: er_sums[idx]/er_counts[idx]
                        for idx in er_sums if er_counts[idx] > 0}
            print(f"[Epoch {epoch+1}] approx Entropy Rate:", epoch_er)
            last_er = epoch_er
    finally:
        # Always detach the hooks, even if training raises.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QNLI Acc: {acc:.4f}")
    return model, last_er


def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores,
                         num_epochs=5):
    """
    Stage 2: prune the four highest-ER layers, then fine-tune the rest.

    Pruning is delegated to ``prune_er_layers(..., num_prune=4)``. The
    remaining parameters are trained in full precision with Adam (lr 1e-5)
    and a linear-decay schedule, and dev accuracy is printed after every epoch.

    Args:
        model: model returned by Stage 1 (modified in place).
        er_scores: dict layer-pair index -> ER score.
        num_epochs: number of epochs (default 5); also sizes the LR schedule.

    Returns:
        The same, now pruned and fine-tuned, ``model`` object.
    """
    print("=== Stage 2: Prune (High‐ER) & Finetuning ===")
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print("Pruned layers (highest‐ER):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the schedule from the same epoch count as the loop; a shorter
    # schedule drives the LR to 0 before training finishes.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """
    Stage 3: freeze the backbone and train only LoRA factors and the head.

    ``apply_lora_to_all_layers`` reparameterises each remaining layer's
    output.dense weight. The whole ``model.roberta`` backbone is then frozen,
    and only the classifier head and the LoRA A/B factors are optimised
    (Adam lr 2e-5, linear decay, mixed precision). Dev accuracy is printed
    after every epoch.

    Args:
        r, alpha: LoRA rank and scaling numerator (scaling = alpha / r).
        num_epochs: number of epochs (default 6); also sizes the LR schedule.

    Returns:
        The same ``model`` object, carrying the trained adapters.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    # full_finetuning enables gradient checkpointing and nothing disables it.
    # With reentrant checkpointing (the Hugging Face default), parameters
    # inside a checkpointed layer only get gradients if some *input* of that
    # layer requires grad. With the whole backbone (embeddings included)
    # frozen, nothing does, so the LoRA factors would silently receive no
    # gradient. Making the embedding output require grad restores the flow.
    model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    # If you want to continue monitoring ER during LoRA, you can re-hook here.
    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast, as recommended by the AMP docs.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}")

    return model


# ========================================================
# 7) Main Entrypoint
# ========================================================

def main():
    """
    QNLI entry point.

    Seeds the Python, NumPy and torch RNGs, builds loaders for a shuffled
    5,000-example training subset and the full validation split
    (roberta-base tokenizer, fixed length 128), then runs full finetuning,
    high-ER pruning and LoRA finetuning in order.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load & preprocess QNLI
    from datasets import load_dataset
    train_ds = load_dataset("glue", "qnli", split="train").shuffle(seed).select(range(5000))
    dev_ds = load_dataset("glue", "qnli", split="validation")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    text_fields = ("question", "sentence")

    def preprocess(examples):
        first, second = (examples[field] for field in text_fields)
        return tokenizer(first, second,
                         truncation=True, padding='max_length', max_length=128)

    def to_model_inputs(ds):
        # tokenize -> rename the target column -> drop raw text and idx
        tokenized = ds.map(preprocess, batched=True)
        return tokenized.rename_column("label", "labels").remove_columns([*text_fields, "idx"])

    train, dev = to_model_inputs(train_ds), to_model_inputs(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, er_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)
# ========================================================
def register_er_hooks(model):
    """
    Hook the output.dense module of every adjacent encoder-layer pair.

    These are the same hooks as the MIR variant. For pair index i
    (0 <= i < num_layers - 1), a forward hook on layer i stores its detached
    output in ``activations[i]['curr_X']``, and a hook on layer i + 1 stores
    its detached output in ``activations[i]['curr_Y']``. The 'prev_*' slots
    start as None and are managed by ``compute_batch_entropy``.

    Returns:
        (hooks, activations): the hook handles (two per pair, X then Y) and
        the per-pair activation buffers.
    """
    encoder_layers = model.roberta.encoder.layer
    pair_indices = range(len(encoder_layers) - 1)
    activations = {
        pair: dict.fromkeys(('prev_X', 'prev_Y', 'curr_X', 'curr_Y'))
        for pair in pair_indices
    }

    def _capture(pair, slot):
        # Factory so each hook closes over its own pair/slot values.
        def hook(module, inp, out):
            activations[pair][slot] = out.detach()
        return hook

    hooks = []
    for pair in pair_indices:
        x_module = encoder_layers[pair].output.dense
        y_module = encoder_layers[pair + 1].output.dense
        hooks.append(x_module.register_forward_hook(_capture(pair, 'curr_X')))
        hooks.append(y_module.register_forward_hook(_capture(pair, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Remove every hook handle in ``hooks``; an empty iterable is a no-op."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """
    Approximate the conditional entropy-rate term for each adjacent layer pair.

    For pair ``idx`` the hooks store the current batch's activations in
    ``curr_X`` / ``curr_Y``; the previous call's are kept in ``prev_X`` /
    ``prev_Y``. With ΔX, ΔY the flattened current-minus-previous differences:

      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)

    Only the cosine-sum term is returned. The additive constant is the same
    for all layers, so it can be dropped for pruning.

    Args:
        activations: dict idx -> buffer with keys 'prev_X', 'prev_Y',
            'curr_X', 'curr_Y' (tensors or None). Mutated in place: current
            activations are shifted into the 'prev_*' slots and 'curr_*' is
            cleared.
        sigma2: unused; kept for interface compatibility (σ² only enters the
            dropped constant).

    Returns:
        dict idx -> float for every pair that had a shape-compatible previous
        batch (0.0 when B < 2). Pairs with no history yet, or whose shapes
        changed (e.g. a smaller final batch), are omitted.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Shift history unconditionally; the locals above keep what we need.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity test: `None in (tensor, ...)` would invoke tensor __eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # Differences are only defined between identically shaped batches.
        if X_prev.shape != X_curr.shape or Y_prev.shape != Y_curr.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er = 0.0
        else:
            # Flatten across all non-batch dims; upcast because AMP
            # activations may be fp16.
            dX = (X_curr - X_prev).reshape(B, -1).float()
            dY = (Y_curr - Y_prev).reshape(B, -1).float()
            # Row-wise cosine for samples 1..B-1, squared as the formula says.
            cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
            er = cos.pow(2).sum().item() / (2 * (B - 1))

        er_scores[idx] = er

    return er_scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high-ER)
# ========================================================
class SkipFF(nn.Module):
    """
    Replacement for a pruned layer's ``output`` module.

    Ignores ``hidden_states`` and returns ``input_tensor`` unchanged (None if
    it is not supplied). NOTE(review): in Hugging Face RoBERTa the second
    argument is presumably the attention output (the residual), which would
    make the layer's feed-forward sub-block an identity. Confirm against
    the transformers version in use.
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_er_layers(model, er_scores, num_prune=4):
    """
    Disable the FFN sub-block of the layers attached to the highest-ER pairs.

    ``er_scores`` maps a layer-pair index i (layers i and i+1) to its score.
    For each of the ``num_prune`` highest-scoring pairs, layer i + 1 is pruned
    in place: its ``intermediate.dense`` becomes ``nn.Identity()`` and its
    ``output`` becomes ``SkipFF()``.

    Returns:
        The pruned layer indices, ordered from highest to lowest ER.
    """
    encoder_layers = model.roberta.encoder.layer
    # Highest ER first; ties keep dict order because sorting is stable.
    ranked_pairs = sorted(er_scores, key=er_scores.get, reverse=True)

    prune_idxs = []
    for pair_idx in ranked_pairs[:num_prune]:
        target = pair_idx + 1
        if target < len(encoder_layers):
            prune_idxs.append(target)

    for target in prune_idxs:
        victim = encoder_layers[target]
        victim.intermediate.dense = nn.Identity()
        victim.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """
    Low-rank update of a frozen weight matrix.

    ``forward()`` takes no input and returns the effective weight
    ``W0 + (alpha / r) * (B @ A)``. ``W0`` is a detached copy of the given
    (L, M) matrix, stored as a buffer so it is not a trainable parameter.
    ``B`` is (L, r) with small Gaussian init (std 0.01), and ``A`` is (r, M)
    initialised to zeros, so the initial output equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Reparameterise every encoder layer's output.dense weight with a LoRA.

    Layers whose ``output`` has no ``dense`` attribute (e.g. outputs replaced
    by SkipFF during pruning) are skipped. For every other layer, the dense
    module's ``forward`` is monkey-patched to
    ``F.linear(x, lora(), dense.bias)``, so the effective weight comes from
    the LoRA's own copy plus its low-rank update.

    Each LoRA is also registered as a submodule (``dense.lora``). Its factors
    and base weight therefore appear in ``model.parameters()`` and
    ``model.state_dict()``, and move with ``model.to(...)``. Unregistered,
    they would live only in the returned dict.

    Returns:
        dict mapping encoder-layer index -> LoRA module.
    """
    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        dense = getattr(layer.output, 'dense', None)
        if dense is None:
            continue
        W0 = dense.weight.data
        lora = LoRA(W0, r, alpha).to(W0.device)
        dense.add_module("lora", lora)

        def fwd(x, dense=dense, lora=lora):
            # Default args bind this iteration's modules (avoid late binding).
            return F.linear(x, lora(), dense.bias)

        dense.forward = fwd
        loras[idx] = lora
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """
    Tokenise the 'sentence1' / 'sentence2' columns of ``examples`` as pairs.

    Calls ``tok`` with truncation enabled and padding to exactly
    ``max_length`` tokens, and returns whatever ``tok`` returns.
    """
    pair = (examples['sentence1'], examples['sentence2'])
    options = dict(truncation=True, padding='max_length', max_length=max_length)
    return tok(*pair, **options)

def evaluate_model(model, dl, device):
    """
    Return the classification accuracy of ``model`` over every batch of ``dl``.

    Switches the model to eval mode (and leaves it there), runs forward
    passes under ``torch.no_grad()``, takes the argmax over the logits, and
    scores the predictions with the Hugging Face ``evaluate`` "accuracy" metric.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    scores = metric.compute(predictions=predictions, references=references)
    return scores["accuracy"]


# ========================================================
# 6) Training Stages (using ER instead of MIR)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """
    Stage 1: fully fine-tune roberta-base (2 labels) while estimating ER.

    All parameters are trained with Adam (lr 2e-5), a linear-decay schedule
    and mixed precision. Forward hooks from ``register_er_hooks`` capture
    adjacent-layer activations so ``compute_batch_entropy`` can score every
    layer pair after each step; per-epoch mean scores are printed.

    Args:
        train_loader: DataLoader yielding dicts with 'input_ids',
            'attention_mask' and 'labels'.
        dev_loader: same format; used for one final accuracy report.
        device: torch device to train on.
        num_epochs: number of epochs (default 6); also sizes the LR schedule.

    Returns:
        (model, last_er): the fine-tuned model and a dict mapping layer-pair
        index -> mean ER score of the final epoch (None if num_epochs == 0).
    """
    print("=== Stage 1: Full Finetuning & ER Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_er_hooks(model)
    last_er = None
    try:
        for epoch in range(num_epochs):
            er_sums, er_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                # Only the forward pass belongs under autocast; PyTorch's AMP
                # docs recommend running backward outside the autocast region.
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Batch-level ER from the activations the hooks just captured.
                for idx, v in compute_batch_entropy(activations).items():
                    er_sums[idx] += v
                    er_counts[idx] += 1

            # Epoch-level ER: mean batch score per layer pair (every key in
            # er_sums was also counted at least once).
            epoch_er = {idx: er_sums[idx] / er_counts[idx] for idx in er_sums}
            print(f"[Epoch {epoch+1}] approx Entropy Rate:", epoch_er)
            last_er = epoch_er
    finally:
        # Always detach the hooks, and do it before evaluation so dev-set
        # forward passes do not write into the activation buffers.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune QQP Acc: {acc:.4f}")
    return model, last_er


def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores,
                         num_epochs=5):
    """
    Stage 2: prune the four highest-ER layers, then fine-tune the rest.

    Pruning is delegated to ``prune_er_layers(..., num_prune=4)``. The
    remaining parameters are trained in full precision with Adam (lr 1e-5)
    and a linear-decay schedule, and dev accuracy is printed after every epoch.

    Args:
        model: model returned by Stage 1 (modified in place).
        er_scores: dict layer-pair index -> ER score.
        num_epochs: number of epochs (default 5); also sizes the LR schedule.

    Returns:
        The same, now pruned and fine-tuned, ``model`` object.
    """
    print("=== Stage 2: Prune (High‐ER) & Finetuning ===")
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print("Pruned layers (highest‐ER):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the schedule from the same epoch count as the loop; a shorter
    # schedule drives the LR to 0 before training finishes.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """
    Stage 3: freeze the backbone and train only LoRA factors and the head.

    ``apply_lora_to_all_layers`` reparameterises each remaining layer's
    output.dense weight. The whole ``model.roberta`` backbone is then frozen,
    and only the classifier head and the LoRA A/B factors are optimised
    (Adam lr 2e-5, linear decay, mixed precision). Dev accuracy is printed
    after every epoch.

    Args:
        r, alpha: LoRA rank and scaling numerator (scaling = alpha / r).
        num_epochs: number of epochs (default 6); also sizes the LR schedule.

    Returns:
        The same ``model`` object, carrying the trained adapters.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    # full_finetuning enables gradient checkpointing and nothing disables it.
    # With reentrant checkpointing (the Hugging Face default), parameters
    # inside a checkpointed layer only get gradients if some *input* of that
    # layer requires grad. With the whole backbone (embeddings included)
    # frozen, nothing does, so the LoRA factors would silently receive no
    # gradient. Making the embedding output require grad restores the flow.
    model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    # If you want to continue monitoring ER during LoRA, you can re-hook here.
    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast, as recommended by the AMP docs.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}")

    return model


# ========================================================
# 7) Main Entrypoint
# ========================================================

def main():
    """
    QQP entry point.

    Seeds the Python, NumPy and torch RNGs, builds loaders for a shuffled
    5,000-example training subset and the full validation split
    (roberta-base tokenizer, fixed length 128), then runs full finetuning,
    high-ER pruning and LoRA finetuning in order.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load & preprocess QQP
    from datasets import load_dataset
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    train_ds = load_dataset("glue", "qqp", split="train").shuffle(seed).select(range(5000))
    dev_ds = load_dataset("glue", "qqp", split="validation")

    text_fields = ("question1", "question2")

    def preprocess(examples):
        first, second = (examples[field] for field in text_fields)
        return tokenizer(first, second,
                         truncation=True, padding="max_length", max_length=128)

    def to_model_inputs(ds):
        # tokenize -> rename the target column -> drop raw text and idx
        tokenized = ds.map(preprocess, batched=True)
        return tokenized.rename_column("label", "labels").remove_columns([*text_fields, "idx"])

    train, dev = to_model_inputs(train_ds), to_model_inputs(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, er_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
# ========================================================
# 1) Standard imports and warning suppression
# ========================================================
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
import evaluate
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# ========================================================
# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)
# ========================================================
def register_er_hooks(model):
    """
    Hook the output.dense module of every adjacent encoder-layer pair.

    These are the same hooks as the MIR variant. For pair index i
    (0 <= i < num_layers - 1), a forward hook on layer i stores its detached
    output in ``activations[i]['curr_X']``, and a hook on layer i + 1 stores
    its detached output in ``activations[i]['curr_Y']``. The 'prev_*' slots
    start as None and are managed by ``compute_batch_entropy``.

    Returns:
        (hooks, activations): the hook handles (two per pair, X then Y) and
        the per-pair activation buffers.
    """
    encoder_layers = model.roberta.encoder.layer
    pair_indices = range(len(encoder_layers) - 1)
    activations = {
        pair: dict.fromkeys(('prev_X', 'prev_Y', 'curr_X', 'curr_Y'))
        for pair in pair_indices
    }

    def _capture(pair, slot):
        # Factory so each hook closes over its own pair/slot values.
        def hook(module, inp, out):
            activations[pair][slot] = out.detach()
        return hook

    hooks = []
    for pair in pair_indices:
        x_module = encoder_layers[pair].output.dense
        y_module = encoder_layers[pair + 1].output.dense
        hooks.append(x_module.register_forward_hook(_capture(pair, 'curr_X')))
        hooks.append(y_module.register_forward_hook(_capture(pair, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Remove every hook handle in ``hooks``; an empty iterable is a no-op."""
    for handle in hooks:
        handle.remove()



def compute_batch_entropy(activations, sigma2=1.0):
    """
    Approximate the conditional entropy-rate term for each adjacent layer pair.

    For pair ``idx`` the hooks store the current batch's activations in
    ``curr_X`` / ``curr_Y``; the previous call's are kept in ``prev_X`` /
    ``prev_Y``. With ΔX, ΔY the flattened current-minus-previous differences:

      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)

    Only the cosine-sum term is returned. The additive constant is the same
    for all layers, so it can be dropped for pruning.

    Args:
        activations: dict idx -> buffer with keys 'prev_X', 'prev_Y',
            'curr_X', 'curr_Y' (tensors or None). Mutated in place: current
            activations are shifted into the 'prev_*' slots and 'curr_*' is
            cleared.
        sigma2: unused; kept for interface compatibility (σ² only enters the
            dropped constant).

    Returns:
        dict idx -> float for every pair that had a shape-compatible previous
        batch (0.0 when B < 2). Pairs with no history yet, or whose shapes
        changed (e.g. a smaller final batch), are omitted.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Shift history unconditionally; the locals above keep what we need.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity test: `None in (tensor, ...)` would invoke tensor __eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # Differences are only defined between identically shaped batches.
        if X_prev.shape != X_curr.shape or Y_prev.shape != Y_curr.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er = 0.0
        else:
            # Flatten across all non-batch dims; upcast because AMP
            # activations may be fp16.
            dX = (X_curr - X_prev).reshape(B, -1).float()
            dY = (Y_curr - Y_prev).reshape(B, -1).float()
            # Row-wise cosine for samples 1..B-1, squared as the formula says.
            cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
            er = cos.pow(2).sum().item() / (2 * (B - 1))

        er_scores[idx] = er

    return er_scores


# ========================================================
# 3) Pruning Utilities with SkipFF (prune high-ER)
# ========================================================
class SkipFF(nn.Module):
    """
    Replacement for a pruned layer's ``output`` module.

    Ignores ``hidden_states`` and returns ``input_tensor`` unchanged (None if
    it is not supplied). NOTE(review): in Hugging Face RoBERTa the second
    argument is presumably the attention output (the residual), which would
    make the layer's feed-forward sub-block an identity. Confirm against
    the transformers version in use.
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_er_layers(model, er_scores, num_prune=4):
    """
    Disable the FFN sub-block of the layers attached to the highest-ER pairs.

    ``er_scores`` maps a layer-pair index i (layers i and i+1) to its score.
    For each of the ``num_prune`` highest-scoring pairs, layer i + 1 is pruned
    in place: its ``intermediate.dense`` becomes ``nn.Identity()`` and its
    ``output`` becomes ``SkipFF()``.

    Returns:
        The pruned layer indices, ordered from highest to lowest ER.
    """
    encoder_layers = model.roberta.encoder.layer
    # Highest ER first; ties keep dict order because sorting is stable.
    ranked_pairs = sorted(er_scores, key=er_scores.get, reverse=True)

    prune_idxs = []
    for pair_idx in ranked_pairs[:num_prune]:
        target = pair_idx + 1
        if target < len(encoder_layers):
            prune_idxs.append(target)

    for target in prune_idxs:
        victim = encoder_layers[target]
        victim.intermediate.dense = nn.Identity()
        victim.output = SkipFF()
    return prune_idxs


# ========================================================
# 4) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """
    Low-rank update of a frozen weight matrix.

    ``forward()`` takes no input and returns the effective weight
    ``W0 + (alpha / r) * (B @ A)``. ``W0`` is a detached copy of the given
    (L, M) matrix, stored as a buffer so it is not a trainable parameter.
    ``B`` is (L, r) with small Gaussian init (std 0.01), and ``A`` is (r, M)
    initialised to zeros, so the initial output equals ``W0``.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        delta = self.B @ self.A
        return self.W0 + self.scaling * delta

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """
    Reparameterise every encoder layer's output.dense weight with a LoRA.

    Layers whose ``output`` has no ``dense`` attribute (e.g. outputs replaced
    by SkipFF during pruning) are skipped. For every other layer, the dense
    module's ``forward`` is monkey-patched to
    ``F.linear(x, lora(), dense.bias)``, so the effective weight comes from
    the LoRA's own copy plus its low-rank update.

    Each LoRA is also registered as a submodule (``dense.lora``). Its factors
    and base weight therefore appear in ``model.parameters()`` and
    ``model.state_dict()``, and move with ``model.to(...)``. Unregistered,
    they would live only in the returned dict.

    Returns:
        dict mapping encoder-layer index -> LoRA module.
    """
    loras = {}
    for idx, layer in enumerate(model.roberta.encoder.layer):
        dense = getattr(layer.output, 'dense', None)
        if dense is None:
            continue
        W0 = dense.weight.data
        lora = LoRA(W0, r, alpha).to(W0.device)
        dense.add_module("lora", lora)

        def fwd(x, dense=dense, lora=lora):
            # Default args bind this iteration's modules (avoid late binding).
            return F.linear(x, lora(), dense.bias)

        dense.forward = fwd
        loras[idx] = lora
    return loras


# ========================================================
# 5) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """
    Tokenise the 'sentence1' / 'sentence2' columns of ``examples`` as pairs.

    Calls ``tok`` with truncation enabled and padding to exactly
    ``max_length`` tokens, and returns whatever ``tok`` returns.
    """
    pair = (examples['sentence1'], examples['sentence2'])
    options = dict(truncation=True, padding='max_length', max_length=max_length)
    return tok(*pair, **options)

def evaluate_model(model, dl, device):
    """
    Return the classification accuracy of ``model`` over every batch of ``dl``.

    Switches the model to eval mode (and leaves it there), runs forward
    passes under ``torch.no_grad()``, takes the argmax over the logits, and
    scores the predictions with the Hugging Face ``evaluate`` "accuracy" metric.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            references.extend(batch['labels'].cpu().numpy())
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions.extend(logits.argmax(dim=-1).cpu().numpy())
    scores = metric.compute(predictions=predictions, references=references)
    return scores["accuracy"]


# ========================================================
# 6) Training Stages (using ER instead of MIR)
# ========================================================
def full_finetuning(train_loader, dev_loader, device, num_epochs=6):
    """
    Stage 1: fully fine-tune roberta-base (2 labels) while estimating ER.

    All parameters are trained with Adam (lr 2e-5), a linear-decay schedule
    and mixed precision. Forward hooks from ``register_er_hooks`` capture
    adjacent-layer activations so ``compute_batch_entropy`` can score every
    layer pair after each step; per-epoch mean scores are printed.

    Args:
        train_loader: DataLoader yielding dicts with 'input_ids',
            'attention_mask' and 'labels'.
        dev_loader: same format; used for one final accuracy report.
        device: torch device to train on.
        num_epochs: number of epochs (default 6); also sizes the LR schedule.

    Returns:
        (model, last_er): the fine-tuned model and a dict mapping layer-pair
        index -> mean ER score of the final epoch (None if num_epochs == 0).
    """
    print("=== Stage 1: Full Finetuning & ER Estimation ===")
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=2
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    hooks, activations = register_er_hooks(model)
    last_er = None
    try:
        for epoch in range(num_epochs):
            er_sums, er_counts = defaultdict(float), defaultdict(int)
            model.train()
            for b in train_loader:
                opt.zero_grad()
                # Only the forward pass belongs under autocast; PyTorch's AMP
                # docs recommend running backward outside the autocast region.
                with autocast():
                    out = model(input_ids=b['input_ids'].to(device),
                                attention_mask=b['attention_mask'].to(device),
                                labels=b['labels'].to(device))
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                # Batch-level ER from the activations the hooks just captured.
                for idx, v in compute_batch_entropy(activations).items():
                    er_sums[idx] += v
                    er_counts[idx] += 1

            # Epoch-level ER: mean batch score per layer pair (every key in
            # er_sums was also counted at least once).
            epoch_er = {idx: er_sums[idx] / er_counts[idx] for idx in er_sums}
            print(f"[Epoch {epoch+1}] approx Entropy Rate:", epoch_er)
            last_er = epoch_er
    finally:
        # Always detach the hooks, and do it before evaluation so dev-set
        # forward passes do not write into the activation buffers.
        remove_hooks(hooks)

    acc = evaluate_model(model, dev_loader, device)
    print(f"-> Full Finetune RTE Acc: {acc:.4f}")
    return model, last_er


def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores,
                         num_epochs=5):
    """
    Stage 2: prune the four highest-ER layers, then fine-tune the rest.

    Pruning is delegated to ``prune_er_layers(..., num_prune=4)``. The
    remaining parameters are trained in full precision with Adam (lr 1e-5)
    and a linear-decay schedule, and dev accuracy is printed after every epoch.

    Args:
        model: model returned by Stage 1 (modified in place).
        er_scores: dict layer-pair index -> ER score.
        num_epochs: number of epochs (default 5); also sizes the LR schedule.

    Returns:
        The same, now pruned and fine-tuned, ``model`` object.
    """
    print("=== Stage 2: Prune (High‐ER) & Finetuning ===")
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print("Pruned layers (highest‐ER):", prune_idxs)

    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Size the schedule from the same epoch count as the loop; a shorter
    # schedule drives the LR to 0 before training finishes.
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )

    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            out = model(input_ids=b['input_ids'].to(device),
                        attention_mask=b['attention_mask'].to(device),
                        labels=b['labels'].to(device))
            out.loss.backward()
            opt.step()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}")

    return model


def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,
                         num_epochs=6):
    """
    Stage 3: freeze the backbone and train only LoRA factors and the head.

    ``apply_lora_to_all_layers`` reparameterises each remaining layer's
    output.dense weight. The whole ``model.roberta`` backbone is then frozen,
    and only the classifier head and the LoRA A/B factors are optimised
    (Adam lr 2e-5, linear decay, mixed precision). Dev accuracy is printed
    after every epoch.

    Args:
        r, alpha: LoRA rank and scaling numerator (scaling = alpha / r).
        num_epochs: number of epochs (default 6); also sizes the LR schedule.

    Returns:
        The same ``model`` object, carrying the trained adapters.
    """
    print("=== Stage 3: LoRA Finetuning ===")
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model, r, alpha)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    # full_finetuning enables gradient checkpointing and nothing disables it.
    # With reentrant checkpointing (the Hugging Face default), parameters
    # inside a checkpointed layer only get gradients if some *input* of that
    # layer requires grad. With the whole backbone (embeddings included)
    # frozen, nothing does, so the LoRA factors would silently receive no
    # gradient. Making the embedding output require grad restores the flow.
    model.enable_input_require_grads()

    opt = torch.optim.Adam(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=0,
        num_training_steps=len(train_loader) * num_epochs,
    )
    scaler = GradScaler()

    # If you want to continue monitoring ER during LoRA, you can re-hook here.
    for epoch in range(num_epochs):
        model.train()
        for b in train_loader:
            opt.zero_grad()
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            # Backward outside autocast, as recommended by the AMP docs.
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}")

    return model


# ========================================================
# 7) Main Entrypoint
# ========================================================

def main():
    """
    RTE entry point.

    Seeds the Python, NumPy and torch RNGs, builds loaders for the full
    (shuffled) training split and the validation split (roberta-base
    tokenizer, fixed length 128), then runs full finetuning, high-ER pruning
    and LoRA finetuning in order.
    """
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load & preprocess RTE
    from datasets import load_dataset
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    train_ds = load_dataset("glue", "rte", split="train").shuffle(seed)
    dev_ds = load_dataset("glue", "rte", split="validation")

    text_fields = ("sentence1", "sentence2")

    def preprocess(examples):
        first, second = (examples[field] for field in text_fields)
        return tokenizer(first, second,
                         truncation=True, padding="max_length", max_length=128)

    def to_model_inputs(ds):
        # tokenize -> rename the target column -> drop raw text and idx
        tokenized = ds.map(preprocess, batched=True)
        return tokenized.rename_column("label", "labels").remove_columns([*text_fields, "idx"])

    train, dev = to_model_inputs(train_ds), to_model_inputs(dev_ds)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    model, er_scores = full_finetuning(train_loader, dev_loader, device)
    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)
    lora_only_finetuning(model, train_loader, dev_loader, device)


if __name__ == "__main__":
    main()


In [None]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import defaultdict
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import evaluate
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ========================================================
# 1) Entropy‐Rate / Hook Utilities (implements Theorem 2)
# ========================================================
def register_er_hooks(model):
    """
    Hook the output.dense module of every adjacent encoder-layer pair.

    For pair index i (0 <= i < num_layers - 1), a forward hook on layer i
    stores its detached output in ``activations[i]['curr_X']``, and a hook
    on layer i + 1 stores its detached output in ``activations[i]['curr_Y']``.
    The 'prev_*' slots start as None and are managed by
    ``compute_batch_entropy``.

    Returns:
        (hooks, activations): the hook handles (two per pair, X then Y) and
        the per-pair activation buffers.
    """
    encoder_layers = model.roberta.encoder.layer
    pair_indices = range(len(encoder_layers) - 1)
    activations = {
        pair: dict.fromkeys(('prev_X', 'prev_Y', 'curr_X', 'curr_Y'))
        for pair in pair_indices
    }

    def _capture(pair, slot):
        # Factory so each hook closes over its own pair/slot values.
        def hook(module, inp, out):
            activations[pair][slot] = out.detach()
        return hook

    hooks = []
    for pair in pair_indices:
        x_module = encoder_layers[pair].output.dense
        y_module = encoder_layers[pair + 1].output.dense
        hooks.append(x_module.register_forward_hook(_capture(pair, 'curr_X')))
        hooks.append(y_module.register_forward_hook(_capture(pair, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Remove every hook handle in ``hooks``; an empty iterable is a no-op."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """
    Approximate the conditional entropy-rate term for each adjacent layer pair.

    For pair ``idx`` the hooks store the current batch's activations in
    ``curr_X`` / ``curr_Y``; the previous call's are kept in ``prev_X`` /
    ``prev_Y``. With ΔX, ΔY the flattened current-minus-previous differences:

      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)

    Only the cosine-sum term is returned. The additive constant is the same
    for all layers, so it can be dropped for pruning.

    Args:
        activations: dict idx -> buffer with keys 'prev_X', 'prev_Y',
            'curr_X', 'curr_Y' (tensors or None). Mutated in place: current
            activations are shifted into the 'prev_*' slots and 'curr_*' is
            cleared.
        sigma2: unused; kept for interface compatibility (σ² only enters the
            dropped constant).

    Returns:
        dict idx -> float for every pair that had a shape-compatible previous
        batch (0.0 when B < 2). Pairs with no history yet, or whose shapes
        changed (e.g. a smaller final batch), are omitted.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']

        # Shift history unconditionally; the locals above keep what we need.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # Identity test: `None in (tensor, ...)` would invoke tensor __eq__.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # Differences are only defined between identically shaped batches.
        if X_prev.shape != X_curr.shape or Y_prev.shape != Y_curr.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er = 0.0
        else:
            # Flatten across all non-batch dims; upcast because AMP
            # activations may be fp16.
            dX = (X_curr - X_prev).reshape(B, -1).float()
            dY = (Y_curr - Y_prev).reshape(B, -1).float()
            # Row-wise cosine for samples 1..B-1, squared as the formula says.
            cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
            er = cos.pow(2).sum().item() / (2 * (B - 1))

        er_scores[idx] = er

    return er_scores

def estimate_entropy_rate(model, loader, device, max_batches=10):
    """
    Estimate per-layer-pair ER scores with inference-only forward passes.

    Hooks from ``register_er_hooks`` are attached for the duration of the
    call. Up to ``max_batches`` batches of ``loader`` are run through the
    model in eval mode under ``torch.no_grad()``, and each batch's
    ``compute_batch_entropy`` result is accumulated. The first batch only
    seeds the history, so it produces no score.

    The hooks are always removed and the model's previous train/eval mode is
    restored on exit, even if a forward pass raises.

    Returns:
        dict layer-pair index -> mean score over the batches that produced a
        non-zero score. Pairs with no such batch are omitted.
    """
    hooks, activations = register_er_hooks(model)
    was_training = model.training
    model.eval()
    er_sums = {idx: 0.0 for idx in activations}
    er_counts = {idx: 0 for idx in activations}
    try:
        with torch.no_grad():
            for num_batches, batch in enumerate(loader, start=1):
                model(input_ids=batch['input_ids'].to(device),
                      attention_mask=batch['attention_mask'].to(device))
                for idx, v in compute_batch_entropy(activations).items():
                    # 0.0 is the B < 2 placeholder; keep it out of the mean.
                    if v != 0.0:
                        er_sums[idx] += v
                        er_counts[idx] += 1
                if num_batches >= max_batches:
                    break
    finally:
        remove_hooks(hooks)
        model.train(was_training)
    return {idx: er_sums[idx] / er_counts[idx]
            for idx in er_sums if er_counts[idx] > 0}

# ========================================================
# 2) Pruning Utilities with SkipFF (prune high-ER)
# ========================================================
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's feed-forward output block that bypasses it.

    The first argument (the feed-forward activations) is ignored and the second
    argument is returned unchanged (``None`` if it is not supplied). Assumes the
    caller passes the residual/attention output as the second argument, as
    ``RobertaOutput.forward(hidden_states, input_tensor)`` does — confirm for
    other architectures.
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_er_layers(model, er_scores, num_prune=4):
    """Bypass the feed-forward block of the layers that follow the highest-ER pairs.

    ``er_scores`` maps a pair index ``i`` to a score. The ``num_prune``
    highest-scoring pairs are selected, and layer ``i + 1`` of each is pruned.
    Pruning sets ``intermediate.dense`` to ``nn.Identity()`` and ``output`` to
    ``SkipFF()``. Targets past the last encoder layer are ignored.

    Returns:
        list of pruned layer indices, highest score first.
    """
    encoder_layers = model.roberta.encoder.layer
    n_layers = len(encoder_layers)
    ranked = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)

    prune_idxs = []
    for pair_idx, _score in ranked[:num_prune]:
        target = pair_idx + 1
        if target < n_layers:
            prune_idxs.append(target)

    for target in prune_idxs:
        victim = encoder_layers[target]
        victim.intermediate.dense = nn.Identity()
        victim.output = SkipFF()
    return prune_idxs

# ========================================================
# 3) LoRA Modules (unchanged)
# ========================================================
class LoRA(nn.Module):
    """Low-rank adapter around a fixed weight matrix.

    The effective weight is ``W0 + (alpha / r) * (B @ A)``. ``W0`` is a detached
    copy of the given 2-D weight, stored as a buffer rather than a parameter.
    ``B`` (rows x r) starts as Gaussian noise scaled by 0.01 and ``A``
    (r x cols) starts at zero, so the adapter initially reproduces ``W0``
    exactly.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the adapted weight matrix."""
        low_rank_update = self.B @ self.A
        return self.W0 + self.scaling * low_rank_update

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route every encoder layer's ``output.dense`` projection through a LoRA adapter.

    Layers whose ``output`` has no ``dense`` attribute (e.g. a ``SkipFF`` left
    by pruning) are skipped. For the others, the projection's ``forward`` is
    replaced on the instance with ``F.linear(x, adapter(), bias)``. The adapter
    is built from a copy of the current weight and placed on that weight's
    device. The adapters are returned but not registered as submodules of
    ``model``.

    Returns:
        dict mapping encoder layer index -> its LoRA module.
    """
    adapters = {}
    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(enc_layer.output, 'dense'):
            continue
        base_weight = enc_layer.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

        # Default arguments bind this iteration's layer/adapter (no late binding);
        # the bias is looked up on the live module at call time.
        def lora_forward(x, enc_layer=enc_layer, adapter=adapter):
            return F.linear(x, adapter(), enc_layer.output.dense.bias)

        enc_layer.output.dense.forward = lora_forward
        adapters[layer_idx] = adapter
    return adapters

# ========================================================
# 4) Data + Eval Helpers
# ========================================================
def preprocess_function(examples, tok, max_length=64):
    """Tokenize sentence pairs with truncation and fixed-length padding.

    ``examples`` must provide 'sentence1' and 'sentence2'. ``tok`` is called as
    ``tok(first, second, truncation=True, padding='max_length',
    max_length=max_length)`` and its result is returned unchanged.
    """
    first = examples['sentence1']
    second = examples['sentence2']
    options = dict(truncation=True, padding='max_length', max_length=max_length)
    return tok(first, second, **options)

def evaluate_model(model, dl, device):
    """Return the classification accuracy of ``model`` on the batches of ``dl``.

    Each batch must provide 'input_ids', 'attention_mask' and 'labels'. The
    prediction is the argmax of ``model(...).logits`` over the last dimension.
    Accuracy is computed directly, instead of re-loading a metric script on
    every call (this runs once per epoch). The model is left in eval mode.

    Returns:
        float accuracy in [0, 1].

    Raises:
        ValueError: if ``dl`` yields no examples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for b in dl:
            out = model(
                input_ids=b['input_ids'].to(device),
                attention_mask=b['attention_mask'].to(device),
            )
            preds = torch.argmax(out.logits, -1).cpu().reshape(-1)
            labels = b['labels'].cpu().reshape(-1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    if total == 0:
        raise ValueError("evaluate_model: dataloader yielded no examples")
    return correct / total

# ========================================================
# 5) Standard fine-tuning loop (used by both fine-tuning stages in main)
# ========================================================
def finetune_model(model, train_loader, dev_loader, device, epochs):
    """Fine-tune all trainable parameters of ``model`` with mixed precision.

    Uses Adam (lr=2e-5) with a linear decay schedule over
    ``len(train_loader) * epochs`` steps and no warmup. A ``GradScaler`` handles
    loss scaling. Dev accuracy (``evaluate_model``) is printed after every
    epoch.

    Returns:
        The same model object, trained in place.
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(opt,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*epochs)
    scaler = GradScaler()
    for epoch in range(epochs):
        # evaluate_model switches to eval mode, so re-enter train mode each epoch.
        model.train()
        for b in train_loader:
            opt.zero_grad()
            # Only the forward pass runs under autocast; the PyTorch AMP docs
            # recommend running backward outside the autocast region.
            with autocast():
                out = model(input_ids=b['input_ids'].to(device),
                            attention_mask=b['attention_mask'].to(device),
                            labels=b['labels'].to(device))
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()
        acc = evaluate_model(model, dev_loader, device)
        print(f"[Epoch {epoch+1}] RTE Acc: {acc:.4f}")
    return model

# ========================================================
# 6) Main Entrypoint
# ========================================================
def main():
    """Run the RTE experiment in three stages.

    1. Fully fine-tune RoBERTa-base for 6 epochs.
    2. Estimate per-pair entropy rates on the dev set and prune the 4 layers
       with the highest ER.
    3. Fine-tune the pruned model for 5 more epochs.

    Dev accuracy is printed after stages 1 and 3.
    """
    seed = 42
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    raw_train = load_dataset("glue", "rte", split="train").shuffle(seed)
    raw_dev = load_dataset("glue", "rte", split="validation")

    def tokenize_pair(batch):
        return tokenizer(batch["sentence1"],
                         batch["sentence2"],
                         truncation=True,
                         padding="max_length",
                         max_length=128)

    def prepare(split):
        # Tokenize, rename the target column to what the model expects, and
        # drop the raw text/id columns.
        tokenized = split.map(tokenize_pair, batched=True)
        tokenized = tokenized.rename_column("label", "labels")
        return tokenized.remove_columns(["sentence1", "sentence2", "idx"])

    train, dev = prepare(raw_train), prepare(raw_dev)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)

    # Stage 1: full fine-tuning
    print("\n=== Stage 1: Full Fine-Tuning (No Pruning) ===")
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
    model = model.to(device)
    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)
    acc_full = evaluate_model(model, dev_loader, device)
    print(f"\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}")

    # Stage 2: ER-based pruning
    print("\n=== Stage 2: ER-Based Pruning (Remove 4 highest-ER layers) ===")
    er_scores = estimate_entropy_rate(model, dev_loader, device, max_batches=10)
    print("Estimated ER per layer pair:", er_scores)
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print(f"Pruning 4 layers with highest ER: {prune_idxs}")

    # Stage 3: fine-tune the pruned model
    print("\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===")
    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)
    acc_pruned = evaluate_model(model, dev_loader, device)
    print(f"\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}")

# Run main() when this code is executed as the top-level module.
if __name__ == "__main__":
    main()


In [None]:
import numpy as np
import random
import math
import warnings

# Monkey-patch numpy.array to ignore the copy argument (workaround for NumPy 2.0's
# stricter `copy=False` semantics). The guard makes this idempotent: re-running
# the cell would otherwise rebind `_np_array` to the already-patched function,
# and because `_patched_array` resolves `_np_array` at call time, every
# np.array call would then recurse forever.
if not getattr(np.array, "_ignores_copy_kwarg", False):
    _np_array = np.array

    def _patched_array(obj, *args, copy=False, **kwargs):
        """Call the original ``np.array`` with any ``copy`` keyword discarded."""
        return _np_array(obj, *args, **kwargs)

    _patched_array._ignores_copy_kwarg = True
    np.array = _patched_array

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
import evaluate

from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)
from datasets import load_dataset

from collections import defaultdict

# Suppress every UserWarning and FutureWarning process-wide.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# ─── 1) Entropy‐Rate Hooks (Theorem 2) ─────────────────────────────────────
def register_er_hooks(model):
    """Attach forward hooks that record activations for each adjacent layer pair.

    For pair ``p`` (0 <= p < n_layers - 1), every run of
    ``layer[p].output.dense`` stores its detached output in
    ``activations[p]['curr_X']``. Every run of ``layer[p+1].output.dense``
    stores its detached output in ``activations[p]['curr_Y']``. The 'prev_*'
    slots start as None and are managed by ``compute_batch_entropy``.

    Returns:
        (hooks, activations): the hook handles (two per pair) and the
        per-pair buffer dict.
    """
    layers = model.roberta.encoder.layer
    n_pairs = len(layers) - 1
    activations = {
        pair: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}
        for pair in range(n_pairs)
    }

    def make_recorder(pair, slot):
        # A factory binds pair/slot per hook, avoiding late-binding closures.
        def record(module, inp, out):
            activations[pair][slot] = out.detach()
        return record

    hooks = []
    for pair in range(n_pairs):
        hooks.append(layers[pair].output.dense.register_forward_hook(make_recorder(pair, 'curr_X')))
        hooks.append(layers[pair + 1].output.dense.register_forward_hook(make_recorder(pair, 'curr_Y')))
    return hooks, activations

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` by calling its ``remove()``, in order."""
    for handle in hooks:
        handle.remove()

def compute_batch_entropy(activations, sigma2=1.0):
    """Compute one entropy-rate (ER) score per hooked layer pair, then roll the buffers.

    For each pair ``idx``, the change between the previous and the current
    batch is taken: ``dX`` for the first layer and ``dY`` for the second, each
    flattened per sample. The squared cosine similarity between ``dY[i]`` and
    ``dX[i]`` is summed over samples ``1..B-1`` and divided by ``2 * (B - 1)``.

    Args:
        activations: dict idx -> {'prev_X', 'prev_Y', 'curr_X', 'curr_Y'}
            (as built by ``register_er_hooks``). Mutated in place: the
            'curr_*' tensors become 'prev_*' and 'curr_*' are reset to None,
            whether or not a score was produced.
        sigma2: unused; kept for interface compatibility.

    Returns:
        dict idx -> float ER. Pairs with a missing buffer, or whose tensors
        changed shape between batches, are omitted. A batch of size < 2 scores
        0.0.
    """
    er_scores = {}
    for idx, buf in activations.items():
        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']
        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']
        # Roll current -> previous for the next call.
        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr
        buf['curr_X'], buf['curr_Y'] = None, None

        # `is None` avoids the tensor == None comparison hidden inside `None in (...)`.
        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):
            continue
        # Differences need identical shapes; a smaller final batch or a
        # different sequence length would otherwise crash the subtraction.
        if X_prev.shape != X_curr.shape or Y_prev.shape != Y_curr.shape:
            continue

        B = X_curr.size(0)
        if B < 2:
            er_scores[idx] = 0.0
            continue
        # Upcast: activations recorded under autocast are fp16, and the
        # per-sample dot products over long flattened vectors can overflow.
        dX = (X_curr - X_prev).reshape(B, -1).float()
        dY = (Y_curr - Y_prev).reshape(B, -1).float()
        # Vectorised over samples 1..B-1 (one device sync instead of B-1).
        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)
        er_scores[idx] = cos.pow(2).sum().item() / (2 * (B - 1))
    return er_scores

# ─── 2) Pruning Utilities ─────────────────────────────────────────────────
class SkipFF(nn.Module):
    """Drop-in replacement for a layer's feed-forward output block that bypasses it.

    The first argument (the feed-forward activations) is ignored and the second
    argument is returned unchanged (``None`` if it is not supplied). Assumes the
    caller passes the residual/attention output as the second argument, as
    ``RobertaOutput.forward(hidden_states, input_tensor)`` does — confirm for
    other architectures.
    """

    def forward(self, hidden_states, input_tensor=None):
        passthrough = input_tensor
        return passthrough

def prune_er_layers(model, er_scores, num_prune=4):
    """Bypass the feed-forward block of the layers that follow the highest-ER pairs.

    ``er_scores`` maps a pair index ``i`` to a score. The ``num_prune``
    highest-scoring pairs are selected, and layer ``i + 1`` of each is pruned.
    Pruning sets ``intermediate.dense`` to ``nn.Identity()`` and ``output`` to
    ``SkipFF()``. Targets past the last encoder layer are ignored.

    Returns:
        list of pruned layer indices, highest score first.
    """
    encoder_layers = model.roberta.encoder.layer
    n_layers = len(encoder_layers)
    ranked = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)

    prune_idxs = []
    for pair_idx, _score in ranked[:num_prune]:
        target = pair_idx + 1
        if target < n_layers:
            prune_idxs.append(target)

    for target in prune_idxs:
        victim = encoder_layers[target]
        victim.intermediate.dense = nn.Identity()
        victim.output = SkipFF()
    return prune_idxs

# ─── 3) LoRA Modules ──────────────────────────────────────────────────────
class LoRA(nn.Module):
    """Low-rank adapter around a fixed weight matrix.

    The effective weight is ``W0 + (alpha / r) * (B @ A)``. ``W0`` is a detached
    copy of the given 2-D weight, stored as a buffer rather than a parameter.
    ``B`` (rows x r) starts as Gaussian noise scaled by 0.01 and ``A``
    (r x cols) starts at zero, so the adapter initially reproduces ``W0``
    exactly.
    """

    def __init__(self, W0, r=2, alpha=1.0):
        super().__init__()
        self.register_buffer("W0", W0.clone().detach())
        n_rows, n_cols = W0.shape
        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))
        self.A = nn.Parameter(torch.zeros(r, n_cols))
        self.scaling = alpha / r

    def forward(self):
        """Return the adapted weight matrix."""
        low_rank_update = self.B @ self.A
        return self.W0 + self.scaling * low_rank_update

def apply_lora_to_all_layers(model, r=2, alpha=1.0):
    """Route every encoder layer's ``output.dense`` projection through a LoRA adapter.

    Layers whose ``output`` has no ``dense`` attribute (e.g. a ``SkipFF`` left
    by pruning) are skipped. For the others, the projection's ``forward`` is
    replaced on the instance with ``F.linear(x, adapter(), bias)``. The adapter
    is built from a copy of the current weight and placed on that weight's
    device. The adapters are returned but not registered as submodules of
    ``model``.

    Returns:
        dict mapping encoder layer index -> its LoRA module.
    """
    adapters = {}
    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):
        if not hasattr(enc_layer.output, 'dense'):
            continue
        base_weight = enc_layer.output.dense.weight.data
        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)

        # Default arguments bind this iteration's layer/adapter (no late binding);
        # the bias is looked up on the live module at call time.
        def lora_forward(x, enc_layer=enc_layer, adapter=adapter):
            return F.linear(x, adapter(), enc_layer.output.dense.bias)

        enc_layer.output.dense.forward = lora_forward
        adapters[layer_idx] = adapter
    return adapters

# ─── 4) STS-B Evaluation with Flattened References ─────────────────────────
def evaluate_stsb(model, dataloader, device):
    """Score a single-output regression model with the GLUE ``stsb`` metric.

    Predictions are ``logits.squeeze(-1)`` per batch. References are the batch
    'labels', flattened so that both ``[5.0, ...]`` and ``[[5.0], ...]``
    layouts become plain floats (the first element is taken from a nested
    entry). The model is left in eval mode.

    Returns:
        the result dict of ``metric.compute`` for GLUE/stsb.
    """
    model.eval()
    metric = evaluate.load("glue", "stsb")
    predictions, references = [], []
    with torch.no_grad():
        for batch in dataloader:
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            flat = logits.squeeze(-1).cpu().tolist()
            # A single-sample batch squeezes to a scalar rather than a list.
            if isinstance(flat, list):
                predictions.extend(flat)
            else:
                predictions.append(flat)
            for ref in batch["labels"].cpu().tolist():
                nested = isinstance(ref, (list, tuple, np.ndarray))
                references.append(float(ref[0] if nested else ref))
    return metric.compute(predictions=predictions, references=references)

# ─── 5) Training Stages ────────────────────────────────────────────────────
def full_finetuning(train_loader, dev_loader, device):
    """Stage 1: fully fine-tune RoBERTa-base for STS-B regression while tracking ER.

    Loads ``roberta-base`` with one regression output and enables gradient
    checkpointing. Training runs for 6 epochs with AdamW (lr=2e-5), a linear
    decay schedule and mixed precision. ER hooks stay attached during training.
    After every training batch the per-pair scores from
    ``compute_batch_entropy`` are accumulated, and each epoch's mean is
    printed. The hooks are removed before dev evaluation, and also if training
    raises.

    Returns:
        (model, last_er): the trained model and the per-pair mean ER of the
        final epoch (``None`` if no epoch ran).
    """
    model = RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=1
    ).to(device)
    model.gradient_checkpointing_enable()
    opt = torch.optim.AdamW(model.parameters(), lr=2e-5)
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*6
    )
    scaler = GradScaler()

    hooks, activations = register_er_hooks(model)
    last_er = None
    try:
        for epoch in range(6):
            er_sums, er_counts = defaultdict(float), defaultdict(int)
            model.train()
            for batch in train_loader:
                opt.zero_grad()
                # Only the forward pass runs under autocast; the PyTorch AMP
                # docs recommend running backward outside the autocast region.
                with autocast():
                    out = model(
                        input_ids=batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        labels=batch["labels"].to(device),
                    )
                scaler.scale(out.loss).backward()
                scaler.step(opt)
                scaler.update()
                sched.step()

                for idx, v in compute_batch_entropy(activations).items():
                    er_sums[idx] += v
                    er_counts[idx] += 1

            epoch_er = {
                idx: er_sums[idx] / er_counts[idx]
                for idx in er_sums if er_counts[idx] > 0
            }
            print(f"[Epoch {epoch+1}] ER:", epoch_er)
            last_er = epoch_er
    finally:
        # Detach the ER hooks so dev-set forwards are not recorded and the
        # hooks never outlive this function, even on error.
        remove_hooks(hooks)

    metrics = evaluate_stsb(model, dev_loader, device)
    print(f"STS‑B Pearson: {metrics['pearson']:.4f}, Spearman: {metrics['spearmanr']:.4f}")
    return model, last_er

def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores, epochs=5):
    """Stage 2: prune the 4 highest-ER layers, then fine-tune the pruned model.

    Args:
        model: the stage-1 model; pruned and trained in place.
        train_loader, dev_loader: batches with 'input_ids', 'attention_mask'
            and 'labels'.
        device: device the batches are moved to.
        er_scores: per-pair ER scores passed to ``prune_er_layers``.
        epochs: number of fine-tuning epochs (default 5). The LR schedule
            spans exactly these epochs.

    Returns:
        The pruned, fine-tuned model. Dev Pearson is printed after every epoch.
    """
    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)
    print("Pruned layers:", prune_idxs)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # The linear decay must cover every epoch actually run; otherwise the
    # learning rate reaches 0 early and the remaining epochs do not train.
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader) * epochs
    )
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            out = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            )
            out.loss.backward()
            opt.step()
            sched.step()

        metrics = evaluate_stsb(model, dev_loader, device)
        print(f"[Prune Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}")
    return model

def lora_only_finetuning(model, train_loader, dev_loader, device):
    """Stage 3: train only LoRA adapters and the classification head for 6 epochs.

    Rank-2 adapters from ``apply_lora_to_all_layers`` wrap the FFN output
    projection of every layer that still has one, and all ``model.roberta``
    parameters are frozen. Training uses AdamW (lr=2e-5), a linear decay
    schedule and mixed precision. Dev Pearson is printed after every epoch.
    """
    torch.cuda.empty_cache()
    loras = apply_lora_to_all_layers(model)
    for p in model.roberta.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    lora_params = [p for lora in loras.values() for p in (lora.A, lora.B)]
    for p in lora_params:
        p.requires_grad = True

    # Stage 1 enables gradient checkpointing on this model. With every encoder
    # parameter frozen, the hidden states entering the checkpointed layers do
    # not require grad. Under reentrant checkpointing (the transformers default,
    # confirm for your version) their outputs then carry no graph, so the LoRA
    # factors inside those layers would get no gradients. torch's warning about
    # this is a UserWarning, which this file filters out. Making the embedding
    # output require grad restores the backward path.
    if getattr(model, "is_gradient_checkpointing", False) and hasattr(
        model, "enable_input_require_grads"
    ):
        model.enable_input_require_grads()

    opt = torch.optim.AdamW(
        list(model.classifier.parameters()) + lora_params,
        lr=2e-5
    )
    sched = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*6
    )
    scaler = GradScaler()

    for epoch in range(6):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            # Only the forward pass runs under autocast; backward runs outside it.
            with autocast():
                out = model(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                    labels=batch["labels"].to(device),
                )
            scaler.scale(out.loss).backward()
            scaler.step(opt)
            scaler.update()
            sched.step()

        metrics = evaluate_stsb(model, dev_loader, device)
        print(f"[LoRA Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}")

# ─── 6) Main Entrypoint ────────────────────────────────────────────────────
def main():
    """Run the STS-B experiment in three stages.

    1. Full fine-tuning with ER tracking.
    2. ER-based pruning followed by fine-tuning.
    3. LoRA-only fine-tuning of the pruned model.
    """
    seed = 42
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    raw_train = load_dataset("glue", "stsb", split="train").shuffle(seed)
    raw_dev = load_dataset("glue", "stsb", split="validation")

    def preprocess(ex):
        return tokenizer(
            ex["sentence1"], ex["sentence2"],
            truncation=True, padding="max_length", max_length=128
        )

    def prepare(split):
        # Tokenize, add a flat float "labels" column, drop the raw columns,
        # and expose the model inputs as torch tensors.
        split = split.map(preprocess, batched=True)
        split = split.map(lambda x: {"labels": float(x["label"])}, batched=False)
        split = split.remove_columns(["sentence1", "sentence2", "label", "idx"])
        split.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
        return split

    train_ds, dev_ds = prepare(raw_train), prepare(raw_dev)

    collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
    train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(dev_ds, batch_size=16, shuffle=False, collate_fn=collator)

    model, er_scores = full_finetuning(train_loader, dev_loader, device)
    pruned_model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)
    lora_only_finetuning(pruned_model, train_loader, dev_loader, device)

# Run main() when this code is executed as the top-level module.
if __name__ == "__main__":
    main()
