In [None]:
# ============================================================
# Soft Shared FFN (Wc + Low-Rank Delta) for GPT-2
# No averaging, no MoE, no information loss
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import load_dataset
from scipy.optimize import linear_sum_assignment
from torch.optim import AdamW
from torch.utils.data import DataLoader

# -------------------------
# Config
# -------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RANK = 8
WINDOW_K = 4
LAMBDA = 1e-4
MAX_TOKENS = 4000
EVAL_TEXTS = 200

# -------------------------
# Tokenizer
# -------------------------
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# ============================================================
# Low-rank FFN module
# ============================================================

class LowRankLinear(nn.Module):
    """
    GPT-2 compatible low-rank linear:
    weight layout = [in_dim, out_dim]
    """
    def __init__(self, Wc, bias, rank):
        super().__init__()

        self.Wc = Wc          # shared base weight
        self.bias = bias

        in_dim, out_dim = Wc.shape   # ðŸ”‘ GPT-2 layout

        device = Wc.device

        # Low-rank delta in SAME layout as Wc
        self.A = nn.Parameter(torch.zeros(in_dim, rank, device=device))
        self.B = nn.Parameter(torch.zeros(rank, out_dim, device=device))

    def forward(self, x):
        W_eff = self.Wc + self.A @ self.B   # shape [in_dim, out_dim]
        return F.linear(x, W_eff.T, self.bias)


    

# ============================================================
# Feature extraction (c_fc activations)
# ============================================================

def extract_ffn_features_single_pass(model, texts):
    model.eval()
    features = {i: [] for i in range(len(model.transformer.h))}
    hooks = []

    def make_hook(i):
        def hook(_, __, out):
            features[i].append(out.detach().cpu())
        return hook

    for i, block in enumerate(model.transformer.h):
        hooks.append(block.mlp.c_fc.register_forward_hook(make_hook(i)))

    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64
    ).to(DEVICE)

    with torch.no_grad():
        model(**enc)

    for h in hooks:
        h.remove()

    X = {}
    for i in features:
        acts = torch.cat(features[i], dim=0)
        acts = acts.reshape(-1, acts.shape[-1])
        X[i] = acts[:MAX_TOKENS].numpy()

    return X

# ============================================================
# Hungarian alignment
# ============================================================

def compute_adjacent_permutations(feature_dict, window):
    anchor = window[0]
    perms = {}
    Xa = feature_dict[anchor]

    for layer in window[1:]:
        Xb = feature_dict[layer]
        C = np.corrcoef(Xa, Xb, rowvar=False)
        d = Xa.shape[1]
        C = C[:d, d:]
        _, col = linear_sum_assignment(-C)
        perms[layer] = col

    return perms

def sliding_windows(n, k):
    return [list(range(i, i + k)) for i in range(n - k + 1)]

# ============================================================
# Replace FFNs with soft shared base + low-rank delta
# ============================================================

def soft_share_ffn_lowrank(model, layers, perms, rank=RANK):
    anchor = layers[0]
    anchor_mlp = model.transformer.h[anchor].mlp

    device = anchor_mlp.c_fc.weight.device

    Wc_fc = nn.Parameter(anchor_mlp.c_fc.weight.data.clone().to(device))
    bc_fc = nn.Parameter(anchor_mlp.c_fc.bias.data.clone().to(device))
    Wc_proj = nn.Parameter(anchor_mlp.c_proj.weight.data.clone().to(device))
    bc_proj = nn.Parameter(anchor_mlp.c_proj.bias.data.clone().to(device))


    for layer in layers:
        mlp = model.transformer.h[layer].mlp
        perm = perms.get(layer, None)

        # c_fc
        mlp.c_fc = LowRankLinear(
        Wc=Wc_fc,
        bias=bc_fc,
        rank=rank
        )

        mlp.c_proj = LowRankLinear(
        Wc=Wc_proj,
        bias=bc_proj,
        rank=rank
        )

# ============================================================
# Loss helpers
# ============================================================

def lm_loss(model, batch):
    input_ids = batch["input_ids"].to(DEVICE)
    attention_mask = batch["attention_mask"].to(DEVICE)

    outputs = model(input_ids)
    logits = outputs.logits[:, :-1, :]
    labels = input_ids[:, 1:]

    mask = attention_mask[:, 1:].bool()
    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(
        dim=-1, index=labels.unsqueeze(-1)
    ).squeeze(-1)

    return -(token_log_probs[mask]).mean()

def consensus_loss(model, layers, lam=LAMBDA):
    loss = 0.0
    for l in layers:
        mlp = model.transformer.h[l].mlp
        loss += torch.norm(mlp.c_fc.A @ mlp.c_fc.B, p="fro")**2
        loss += torch.norm(mlp.c_proj.A @ mlp.c_proj.B, p="fro")**2
    return lam * loss

# ============================================================
# Perplexity
# ============================================================

def compute_perplexity(model, texts, batch_size=4):
    model.eval()
    total_nll = 0
    total_tokens = 0

    for i in range(0, len(texts), batch_size):
        enc = tokenizer(
            texts[i:i + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        ).to(DEVICE)

        with torch.no_grad():
            logits = model(enc["input_ids"]).logits[:, :-1, :]
            labels = enc["input_ids"][:, 1:]
            mask = enc["attention_mask"][:, 1:].bool()

            log_probs = F.log_softmax(logits, dim=-1)
            lp = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

        total_nll += -(lp[mask]).sum().item()
        total_tokens += mask.sum().item()

    return math.exp(total_nll / total_tokens)

# ============================================================
# Main
# ============================================================

print("Loading data...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
texts = [t for t in dataset["text"] if len(t.strip()) > 0][:EVAL_TEXTS]

print("Loading base model...")
base_model = GPT2LMHeadModel.from_pretrained("gpt2").to(DEVICE)

print("Extracting FFN features...")
features = extract_ffn_features_single_pass(base_model, texts)

windows = sliding_windows(len(base_model.transformer.h), WINDOW_K)

best_ppl = float("inf")
best_window = None

print("\nSearching best window...")
for w in windows:
    test_model = GPT2LMHeadModel.from_pretrained("gpt2").to(DEVICE)
    perms = compute_adjacent_permutations(features, w)
    soft_share_ffn_lowrank(test_model, w, perms)

    ppl = compute_perplexity(test_model, texts)
    print(f"Window {w} | PPL {ppl:.2f}")

    if ppl < best_ppl:
        best_ppl = ppl
        best_window = w

print("\nBest window:", best_window)

# ============================================================
# Fine-tuning
# ============================================================

train_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:2%]")
train_texts = [t for t in train_ds["text"] if len(t.strip()) > 0]

def collate_fn(batch):
    return tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=128)

train_loader = DataLoader(train_texts, batch_size=4, shuffle=True, collate_fn=collate_fn)

model = GPT2LMHeadModel.from_pretrained("gpt2").to(DEVICE)
perms = compute_adjacent_permutations(features, best_window)
soft_share_ffn_lowrank(model, best_window, perms)

optimizer = AdamW(model.parameters(), lr=1e-5)

print("\nStarting fine-tuning...")
model.train()
for step, batch in enumerate(train_loader):
    if step >= 300:
        break

    loss = lm_loss(model, batch)
    loss += consensus_loss(model, best_window)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % 50 == 0:
        print(f"Step {step} | Loss {loss.item():.4f}")

# ============================================================
# Final evaluation
# ============================================================

val_ppl = compute_perplexity(model, texts)
print(f"\nFinal Validation PPL: {val_ppl:.2f}")

prompt = "India will become global leader in AI because"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    out = model.generate(**inputs, max_length=40)

print("\nGenerated text:")
print(tokenizer.decode(out[0], skip_special_tokens=True))


Loading data...
Loading base model...
Extracting FFN features...

Searching best window...
Window [0, 1, 2, 3] | PPL 3260.79
Window [1, 2, 3, 4] | PPL 395.99
Window [2, 3, 4, 5] | PPL 128.06
Window [3, 4, 5, 6] | PPL 102.01
Window [4, 5, 6, 7] | PPL 92.25
Window [5, 6, 7, 8] | PPL 91.09
Window [6, 7, 8, 9] | PPL 92.88
Window [7, 8, 9, 10] | PPL 98.46
Window [8, 9, 10, 11] | PPL 350.21

Best window: [5, 6, 7, 8]

Starting fine-tuning...
Step 0 | Loss 6.0397
Step 50 | Loss 4.5996
Step 100 | Loss 4.4414


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Final Validation PPL: 50.86

Generated text:
India will become global leader in AI because of the fact of its ability to communicate with the human.

The AI is a global leader in the world's AI. It is the most advanced AI


In [None]:
test_dataset = load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="test"
)

test_texts = [t for t in test_dataset["text"] if len(t.strip()) > 0][:500]
test_ppl = compute_perplexity(model, test_texts)
print(f"TEST PPL after recovery: {test_ppl:.2f}")

TEST PPL after recovery: 68.58
