In [1]:
import torch, numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from datasets import load_dataset
from scipy.optimize import linear_sum_assignment

In [2]:
import torch.nn as nn


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [4]:
def extract_ffn_features_single_pass(model, texts, max_tokens=4000):
    """
    Extract FFN pre-activation features (c_fc outputs)
    for all layers from a SINGLE forward pass,
    excluding padding tokens.
    """
    model.eval()
    features = {i: [] for i in range(len(model.transformer.h))}
    hooks = []

    def make_hook(layer_idx):
        def hook(module, inp, out):
            # out: [batch, seq, d_ff]
            features[layer_idx].append(out.detach().cpu())
        return hook

    # Register hooks ONLY on c_fc
    for i, block in enumerate(model.transformer.h):
        hooks.append(
            block.mlp.c_fc.register_forward_hook(make_hook(i))
        )

    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64
    ).to(device)

    attention_mask = enc["attention_mask"]  # [B, T]

    with torch.no_grad():
        model(**enc)

    for h in hooks:
        h.remove()

    X = {}
    for i in features:
        # [B, T, d]
        # acts = torch.cat(features[i], dim=0)

        # # mask out padding tokens
        # mask = attention_mask.bool().unsqueeze(-1)  # [B, T, 1]
        # acts = acts[mask.expand_as(acts)].view(-1, acts.shape[-1])

        
        acts = torch.cat(features[i], dim=0)  # acts is on CPU

        mask = attention_mask.bool().unsqueeze(-1).cpu()  # <-- FIX

        acts = acts[mask.expand_as(acts)].view(-1, acts.shape[-1])

        if acts.shape[0] > max_tokens:
            acts = acts[:max_tokens]

        X[i] = acts.numpy()

    return X


In [5]:
from scipy.optimize import linear_sum_assignment
import numpy as np

def compute_adjacent_permutations(feature_dict, window_layers):
    """
    Compute permutations ONLY between anchor and adjacent layers
    inside a window.
    """
    anchor = window_layers[0]
    perms = {}

    X_anchor = feature_dict[anchor]

    for layer in window_layers[1:]:
        X_other = feature_dict[layer]

        # Pearson correlation (Eq. 1 in paper)
        C = np.corrcoef(X_anchor, X_other, rowvar=False)
        d = X_anchor.shape[1]
        C = C[:d, d:]

        _, col_ind = linear_sum_assignment(-C)
        perms[layer] = col_ind

    return perms


In [6]:
def sliding_windows(n_layers, k):
    """
    Generate sliding windows of k adjacent layers.
    """
    return [list(range(i, i + k)) for i in range(n_layers - k + 1)]


In [7]:
def permute_ffn(model, layer, perm):
    block = model.transformer.h[layer].mlp
    perm = torch.tensor(perm, dtype=torch.long, device=device)

    # GPT-2 weight shapes are transposed in your setup:
    # c_fc:   [768, 3072]
    # c_proj: [3072, 768]
    with torch.no_grad():
        block.c_fc.weight[:] = block.c_fc.weight[:, perm]
        block.c_fc.bias[:]   = block.c_fc.bias[perm]
        block.c_proj.weight[:] = block.c_proj.weight[perm, :]

In [8]:
import torch
import torch.nn as nn

def merge_k_layers(model, layers, perms):
    """
    Paper-faithful FFN merging with TRUE weight tying.
    - No in-place permutation
    - Logical permutation inside averaging
    - Shared nn.Parameter objects
    """

    anchor = layers[0]
    anchor_mlp = model.transformer.h[anchor].mlp

    device = anchor_mlp.c_fc.weight.device
    k = len(layers)

    # Start from anchor weights
    W_in  = anchor_mlp.c_fc.weight.data.clone()
    b_in  = anchor_mlp.c_fc.bias.data.clone()
    W_out = anchor_mlp.c_proj.weight.data.clone()
    b_out = anchor_mlp.c_proj.bias.data.clone()

    # Accumulate aligned weights (NO mutation)
    for layer in layers[1:]:
        mlp = model.transformer.h[layer].mlp
        perm = torch.tensor(perms[layer], device=device)

        # Logical permutation (as in Eq. 3â€“6)
        W_in  += mlp.c_fc.weight[:, perm]
        b_in  += mlp.c_fc.bias[perm]
        W_out += mlp.c_proj.weight[perm, :]
        b_out += mlp.c_proj.bias

    # Average
    W_in  /= k
    b_in  /= k
    W_out /= k
    b_out /= k

    # Create SHARED parameters (true tying)
    shared_c_fc_weight   = nn.Parameter(W_in)
    shared_c_fc_bias     = nn.Parameter(b_in)
    shared_c_proj_weight = nn.Parameter(W_out)
    shared_c_proj_bias   = nn.Parameter(b_out)

    # Tie all layers in the window
    for layer in layers:
        mlp = model.transformer.h[layer].mlp

        mlp.c_fc.weight = shared_c_fc_weight
        mlp.c_fc.bias   = shared_c_fc_bias

        mlp.c_proj.weight = shared_c_proj_weight
        mlp.c_proj.bias   = shared_c_proj_bias


In [9]:
import math
import torch

def compute_perplexity(model, texts, batch_size=4, max_length=128):
    model.eval()
    total_nll = 0.0
    total_tokens = 0

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]

        enc = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[:, :-1, :]
            labels = input_ids[:, 1:]

        # mask padding tokens
        mask = attention_mask[:, 1:].bool()

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(
            dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        total_nll += -(token_log_probs[mask]).sum().item()
        total_tokens += mask.sum().item()

    return math.exp(total_nll / total_tokens)

# # merged_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
# # merge_k_layers(merged_model, window, perms)

# merged_ppl = compute_perplexity(model, texts)

# print(f"Merged (window={window}) PPL: {merged_ppl:.2f}")


In [10]:
dataset = load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="validation"
)
texts = [t for t in dataset["text"] if len(t.strip()) > 0][:200]
base_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
base_ppl = compute_perplexity(base_model, texts)

print(f"Baseline GPT-2 PPL: {base_ppl:.2f}")

Baseline GPT-2 PPL: 46.98


In [11]:
# print("Before recovery PPL:",
#           compute_perplexity(model, texts))   

In [12]:
def print_num_layers(model):
    num_layers = len(model.transformer.h)
    print(f"Number of transformer layers: {num_layers}")
print_num_layers(base_model)

Number of transformer layers: 12


In [13]:
train_dataset = load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="train[:2%]"
)

train_texts = [t for t in train_dataset["text"] if len(t.strip()) > 0]


In [14]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    enc = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128
    )
    return enc

train_loader = DataLoader(
    train_texts,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn
)


In [15]:
import torch.nn.functional as F
from torch.optim import AdamW

def recovery_finetune(
    model,
    train_loader,
    steps=1000,
    lr=1e-5,
    weight_decay=0.01,
    device="cuda"
):
    model.train()

    optimizer = AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    step = 0
    running_loss = 0.0

    for batch in train_loader:
        if step >= steps:
            break

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        outputs = model(input_ids)
        logits = outputs.logits[:, :-1, :]
        labels = input_ids[:, 1:]

        mask = attention_mask[:, 1:].bool()

        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(
            dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        loss = -(token_log_probs[mask]).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        step += 1

        if step % 100 == 0:
            print(f"[Recovery] step {step} | loss {running_loss / 100:.4f}")
            running_loss = 0.0

    model.eval()


In [16]:
base_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

# 1) Extract features ONCE (same forward pass)
feature_mats = extract_ffn_features_single_pass(
    base_model, texts
)

In [24]:


model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
print("Before merging PPL:",
          compute_perplexity(model, texts))

Before merging PPL: 46.97976365832357


In [17]:
best_ppl = float("inf")
best_window = None
best_model = None   # <-- store MODEL OBJECT, not state_dict

k = 4
windows = sliding_windows(len(base_model.transformer.h), k)

for window in windows:
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

    perms = compute_adjacent_permutations(feature_mats, window)
    merge_k_layers(model, window, perms)

    ppl = compute_perplexity(model, texts)
    print(f"Window {window} | PPL {ppl:.2f}")

    if ppl < best_ppl:
        best_ppl = ppl
        best_window = window
        best_model = model   # âœ… keep tied model alive


Window [0, 1, 2, 3] | PPL 95623.84
Window [1, 2, 3, 4] | PPL 3289.73
Window [2, 3, 4, 5] | PPL 4249.91
Window [3, 4, 5, 6] | PPL 963.31
Window [4, 5, 6, 7] | PPL 1823.28
Window [5, 6, 7, 8] | PPL 306.99
Window [6, 7, 8, 9] | PPL 212.60
Window [7, 8, 9, 10] | PPL 175.50
Window [8, 9, 10, 11] | PPL 352.87


In [18]:
l0 = best_model.transformer.h[best_window[0]].mlp
l1 = best_model.transformer.h[best_window[1]].mlp

print("Weights tied?",
      l0.c_fc.weight is l1.c_fc.weight,
      l0.c_proj.weight is l1.c_proj.weight)


Weights tied? True True


In [19]:
print("Before recovery PPL:",
      compute_perplexity(best_model, texts))

recovery_finetune(
    best_model,
    train_loader,
    steps=10,
    lr=1e-5,
    device=device
)

print("After recovery PPL:",
      compute_perplexity(best_model, texts))


Before recovery PPL: 175.4980100246964
After recovery PPL: 124.45890570165714


In [23]:
# ---- TEST EVALUATION ----
test_dataset = load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="test"
)

test_texts = [t for t in test_dataset["text"] if len(t.strip()) > 0][:500]

val_ppl = compute_perplexity(best_model, texts)
print(f"Validation PPL after recovery: {val_ppl:.2f}")

test_ppl = compute_perplexity(best_model, test_texts)
print(f"TEST PPL after recovery: {test_ppl:.2f}")

Validation PPL after recovery: 124.46
TEST PPL after recovery: 169.64


In [20]:
l0 = best_model.transformer.h[best_window[0]].mlp
l1 = best_model.transformer.h[best_window[1]].mlp
# print(l0.c_fc.weight)
# print(l1.c_fc.weight)
print(l0.c_fc.weight is l1.c_fc.weight)   # MUST be True
print(l0.c_proj.weight is l1.c_proj.weight)  # MUST be True

True
True


In [21]:
def count_unique_params(model):
    seen = set()
    total = 0
    for p in model.parameters():
        if id(p) not in seen:
            seen.add(id(p))
            total += p.numel()
    return total

def count_params_naive(model):
    return sum(p.numel() for p in model.parameters())

print("Original GPT-2")
print(" Naive :", count_params_naive(base_model))
print(" Unique:", count_unique_params(base_model))

print("\nMerged (tied) model")
print(" Naive :", count_params_naive(best_model))
print(" Unique:", count_unique_params(best_model))


Original GPT-2
 Naive : 124439808
 Unique: 124439808

Merged (tied) model
 Naive : 110272512
 Unique: 110272512


In [22]:
prompt = "India will become global leader in AI because 1. it has"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    sample_out = best_model.generate(**inputs, max_length=40)
print("\nGenerated output:")
print(tokenizer.decode(sample_out[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Generated output:
India will become global leader in AI because 1. it has the ability to be the world AI leader. 2. it has the ability to be the world AI leader. 3. it has the ability


In [3]:
# ============================================================
# Soft Shared FFN (Wc + Low-Rank Delta) for GPT-2
# No averaging, no MoE, no information loss
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import load_dataset
from scipy.optimize import linear_sum_assignment
from torch.optim import AdamW
from torch.utils.data import DataLoader

# -------------------------
# Config
# -------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RANK = 8
WINDOW_K = 4
LAMBDA = 1e-4
MAX_TOKENS = 4000
EVAL_TEXTS = 200

# -------------------------
# Tokenizer
# -------------------------
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# ============================================================
# Low-rank FFN module
# ============================================================

class LowRankLinear(nn.Module):
    """
    GPT-2 compatible low-rank linear:
    weight layout = [in_dim, out_dim]
    """
    def __init__(self, Wc, bias, rank):
        super().__init__()

        self.Wc = Wc          # shared base weight
        self.bias = bias

        in_dim, out_dim = Wc.shape   # ðŸ”‘ GPT-2 layout

        device = Wc.device

        # Low-rank delta in SAME layout as Wc
        self.A = nn.Parameter(torch.zeros(in_dim, rank, device=device))
        self.B = nn.Parameter(torch.zeros(rank, out_dim, device=device))

    def forward(self, x):
        W_eff = self.Wc + self.A @ self.B   # shape [in_dim, out_dim]
        return F.linear(x, W_eff.T, self.bias)


    

# ============================================================
# Feature extraction (c_fc activations)
# ============================================================

def extract_ffn_features_single_pass(model, texts):
    model.eval()
    features = {i: [] for i in range(len(model.transformer.h))}
    hooks = []

    def make_hook(i):
        def hook(_, __, out):
            features[i].append(out.detach().cpu())
        return hook

    for i, block in enumerate(model.transformer.h):
        hooks.append(block.mlp.c_fc.register_forward_hook(make_hook(i)))

    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64
    ).to(DEVICE)

    with torch.no_grad():
        model(**enc)

    for h in hooks:
        h.remove()

    X = {}
    for i in features:
        acts = torch.cat(features[i], dim=0)
        acts = acts.reshape(-1, acts.shape[-1])
        X[i] = acts[:MAX_TOKENS].numpy()

    return X

# ============================================================
# Hungarian alignment
# ============================================================

def compute_adjacent_permutations(feature_dict, window):
    anchor = window[0]
    perms = {}
    Xa = feature_dict[anchor]

    for layer in window[1:]:
        Xb = feature_dict[layer]
        C = np.corrcoef(Xa, Xb, rowvar=False)
        d = Xa.shape[1]
        C = C[:d, d:]
        _, col = linear_sum_assignment(-C)
        perms[layer] = col

    return perms

def sliding_windows(n, k):
    return [list(range(i, i + k)) for i in range(n - k + 1)]

# ============================================================
# Replace FFNs with soft shared base + low-rank delta
# ============================================================

def soft_share_ffn_lowrank(model, layers, perms, rank=RANK):
    anchor = layers[0]
    anchor_mlp = model.transformer.h[anchor].mlp

    device = anchor_mlp.c_fc.weight.device

    Wc_fc = nn.Parameter(anchor_mlp.c_fc.weight.data.clone().to(device))
    bc_fc = nn.Parameter(anchor_mlp.c_fc.bias.data.clone().to(device))
    Wc_proj = nn.Parameter(anchor_mlp.c_proj.weight.data.clone().to(device))
    bc_proj = nn.Parameter(anchor_mlp.c_proj.bias.data.clone().to(device))


    for layer in layers:
        mlp = model.transformer.h[layer].mlp
        perm = perms.get(layer, None)

        # c_fc
        mlp.c_fc = LowRankLinear(
        Wc=Wc_fc,
        bias=bc_fc,
        rank=rank
        )

        mlp.c_proj = LowRankLinear(
        Wc=Wc_proj,
        bias=bc_proj,
        rank=rank
        )

# ============================================================
# Loss helpers
# ============================================================

def lm_loss(model, batch):
    input_ids = batch["input_ids"].to(DEVICE)
    attention_mask = batch["attention_mask"].to(DEVICE)

    outputs = model(input_ids)
    logits = outputs.logits[:, :-1, :]
    labels = input_ids[:, 1:]

    mask = attention_mask[:, 1:].bool()
    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(
        dim=-1, index=labels.unsqueeze(-1)
    ).squeeze(-1)

    return -(token_log_probs[mask]).mean()

def consensus_loss(model, layers, lam=LAMBDA):
    loss = 0.0
    for l in layers:
        mlp = model.transformer.h[l].mlp
        loss += torch.norm(mlp.c_fc.A @ mlp.c_fc.B, p="fro")**2
        loss += torch.norm(mlp.c_proj.A @ mlp.c_proj.B, p="fro")**2
    return lam * loss

# ============================================================
# Perplexity
# ============================================================

def compute_perplexity(model, texts, batch_size=4):
    model.eval()
    total_nll = 0
    total_tokens = 0

    for i in range(0, len(texts), batch_size):
        enc = tokenizer(
            texts[i:i + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        ).to(DEVICE)

        with torch.no_grad():
            logits = model(enc["input_ids"]).logits[:, :-1, :]
            labels = enc["input_ids"][:, 1:]
            mask = enc["attention_mask"][:, 1:].bool()

            log_probs = F.log_softmax(logits, dim=-1)
            lp = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

        total_nll += -(lp[mask]).sum().item()
        total_tokens += mask.sum().item()

    return math.exp(total_nll / total_tokens)

# ============================================================
# Main
# ============================================================

print("Loading data...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
texts = [t for t in dataset["text"] if len(t.strip()) > 0][:EVAL_TEXTS]

print("Loading base model...")
base_model = GPT2LMHeadModel.from_pretrained("gpt2").to(DEVICE)

print("Extracting FFN features...")
features = extract_ffn_features_single_pass(base_model, texts)

windows = sliding_windows(len(base_model.transformer.h), WINDOW_K)

best_ppl = float("inf")
best_window = None

print("\nSearching best window...")
for w in windows:
    test_model = GPT2LMHeadModel.from_pretrained("gpt2").to(DEVICE)
    perms = compute_adjacent_permutations(features, w)
    soft_share_ffn_lowrank(test_model, w, perms)

    ppl = compute_perplexity(test_model, texts)
    print(f"Window {w} | PPL {ppl:.2f}")

    if ppl < best_ppl:
        best_ppl = ppl
        best_window = w

print("\nBest window:", best_window)

# ============================================================
# Fine-tuning
# ============================================================

train_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:2%]")
train_texts = [t for t in train_ds["text"] if len(t.strip()) > 0]

def collate_fn(batch):
    return tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=128)

train_loader = DataLoader(train_texts, batch_size=4, shuffle=True, collate_fn=collate_fn)

model = GPT2LMHeadModel.from_pretrained("gpt2").to(DEVICE)
perms = compute_adjacent_permutations(features, best_window)
soft_share_ffn_lowrank(model, best_window, perms)

optimizer = AdamW(model.parameters(), lr=1e-5)

print("\nStarting fine-tuning...")
model.train()
for step, batch in enumerate(train_loader):
    if step >= 300:
        break

    loss = lm_loss(model, batch)
    loss += consensus_loss(model, best_window)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % 50 == 0:
        print(f"Step {step} | Loss {loss.item():.4f}")

# ============================================================
# Final evaluation
# ============================================================

val_ppl = compute_perplexity(model, texts)
print(f"\nFinal Validation PPL: {val_ppl:.2f}")

prompt = "India will become global leader in AI because"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    out = model.generate(**inputs, max_length=40)

print("\nGenerated text:")
print(tokenizer.decode(out[0], skip_special_tokens=True))


Loading data...
Loading base model...
Extracting FFN features...

Searching best window...
Window [0, 1, 2, 3] | PPL 3260.79
Window [1, 2, 3, 4] | PPL 395.99
Window [2, 3, 4, 5] | PPL 128.06
Window [3, 4, 5, 6] | PPL 102.01
Window [4, 5, 6, 7] | PPL 92.25
Window [5, 6, 7, 8] | PPL 91.09
Window [6, 7, 8, 9] | PPL 92.88
Window [7, 8, 9, 10] | PPL 98.46
Window [8, 9, 10, 11] | PPL 350.21

Best window: [5, 6, 7, 8]

Starting fine-tuning...
Step 0 | Loss 6.0397
Step 50 | Loss 4.5996
Step 100 | Loss 4.4414


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Final Validation PPL: 50.86

Generated text:
India will become global leader in AI because of the fact of its ability to communicate with the human.

The AI is a global leader in the world's AI. It is the most advanced AI


In [5]:
test_dataset = load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="test"
)

test_texts = [t for t in test_dataset["text"] if len(t.strip()) > 0][:500]
test_ppl = compute_perplexity(model, test_texts)
print(f"TEST PPL after recovery: {test_ppl:.2f}")

TEST PPL after recovery: 68.58
