In [1]:
import torch, numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from datasets import load_dataset
from scipy.optimize import linear_sum_assignment

In [2]:
import torch.nn as nn


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [4]:
def extract_ffn_features_single_pass(model, texts, max_tokens=4000):
    """
    Extract FFN pre-activation features (c_fc outputs)
    for all layers from a SINGLE forward pass,
    excluding padding tokens.
    """
    model.eval()
    features = {i: [] for i in range(len(model.transformer.h))}
    hooks = []

    def make_hook(layer_idx):
        def hook(module, inp, out):
            # out: [batch, seq, d_ff]
            features[layer_idx].append(out.detach().cpu())
        return hook

    # Register hooks ONLY on c_fc
    for i, block in enumerate(model.transformer.h):
        hooks.append(
            block.mlp.c_fc.register_forward_hook(make_hook(i))
        )

    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64
    ).to(device)

    attention_mask = enc["attention_mask"]  # [B, T]

    with torch.no_grad():
        model(**enc)

    for h in hooks:
        h.remove()

    X = {}
    for i in features:
        # [B, T, d]
        # acts = torch.cat(features[i], dim=0)

        # # mask out padding tokens
        # mask = attention_mask.bool().unsqueeze(-1)  # [B, T, 1]
        # acts = acts[mask.expand_as(acts)].view(-1, acts.shape[-1])

        
        acts = torch.cat(features[i], dim=0)  # acts is on CPU

        mask = attention_mask.bool().unsqueeze(-1).cpu()  # <-- FIX

        acts = acts[mask.expand_as(acts)].view(-1, acts.shape[-1])

        if acts.shape[0] > max_tokens:
            acts = acts[:max_tokens]

        X[i] = acts.numpy()

    return X


In [5]:
from scipy.optimize import linear_sum_assignment
import numpy as np

def compute_adjacent_permutations(feature_dict, window_layers):
    """
    Compute permutations ONLY between anchor and adjacent layers
    inside a window.
    """
    anchor = window_layers[0]
    perms = {}

    X_anchor = feature_dict[anchor]

    for layer in window_layers[1:]:
        X_other = feature_dict[layer]

        # Pearson correlation (Eq. 1 in paper)
        C = np.corrcoef(X_anchor, X_other, rowvar=False)
        d = X_anchor.shape[1]
        C = C[:d, d:]

        _, col_ind = linear_sum_assignment(-C)
        perms[layer] = col_ind

    return perms


In [6]:
def sliding_windows(n_layers, k):
    """
    Generate sliding windows of k adjacent layers.
    """
    return [list(range(i, i + k)) for i in range(n_layers - k + 1)]


In [7]:
def permute_ffn(model, layer, perm):
    block = model.transformer.h[layer].mlp
    perm = torch.tensor(perm, dtype=torch.long, device=device)

    # GPT-2 weight shapes are transposed in your setup:
    # c_fc:   [768, 3072]
    # c_proj: [3072, 768]
    with torch.no_grad():
        block.c_fc.weight[:] = block.c_fc.weight[:, perm]
        block.c_fc.bias[:]   = block.c_fc.bias[perm]
        block.c_proj.weight[:] = block.c_proj.weight[perm, :]

In [8]:
import torch
import torch.nn as nn

def merge_k_layers(model, layers, perms):
    """
    Paper-faithful FFN merging with TRUE weight tying.
    - No in-place permutation
    - Logical permutation inside averaging
    - Shared nn.Parameter objects
    """

    anchor = layers[0]
    anchor_mlp = model.transformer.h[anchor].mlp

    device = anchor_mlp.c_fc.weight.device
    k = len(layers)

    # Start from anchor weights
    W_in  = anchor_mlp.c_fc.weight.data.clone()
    b_in  = anchor_mlp.c_fc.bias.data.clone()
    W_out = anchor_mlp.c_proj.weight.data.clone()
    b_out = anchor_mlp.c_proj.bias.data.clone()

    # Accumulate aligned weights (NO mutation)
    for layer in layers[1:]:
        mlp = model.transformer.h[layer].mlp
        perm = torch.tensor(perms[layer], device=device)

        # Logical permutation (as in Eq. 3–6)
        W_in  += mlp.c_fc.weight[:, perm]
        b_in  += mlp.c_fc.bias[perm]
        W_out += mlp.c_proj.weight[perm, :]
        b_out += mlp.c_proj.bias

    # Average
    W_in  /= k
    b_in  /= k
    W_out /= k
    b_out /= k

    # Create SHARED parameters (true tying)
    shared_c_fc_weight   = nn.Parameter(W_in)
    shared_c_fc_bias     = nn.Parameter(b_in)
    shared_c_proj_weight = nn.Parameter(W_out)
    shared_c_proj_bias   = nn.Parameter(b_out)

    # Tie all layers in the window
    for layer in layers:
        mlp = model.transformer.h[layer].mlp

        mlp.c_fc.weight = shared_c_fc_weight
        mlp.c_fc.bias   = shared_c_fc_bias

        mlp.c_proj.weight = shared_c_proj_weight
        mlp.c_proj.bias   = shared_c_proj_bias


In [9]:
import math
import torch

def compute_perplexity(model, texts, batch_size=4, max_length=128):
    model.eval()
    total_nll = 0.0
    total_tokens = 0

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]

        enc = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[:, :-1, :]
            labels = input_ids[:, 1:]

        # mask padding tokens
        mask = attention_mask[:, 1:].bool()

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(
            dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        total_nll += -(token_log_probs[mask]).sum().item()
        total_tokens += mask.sum().item()

    return math.exp(total_nll / total_tokens)

# # merged_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
# # merge_k_layers(merged_model, window, perms)

# merged_ppl = compute_perplexity(model, texts)

# print(f"Merged (window={window}) PPL: {merged_ppl:.2f}")


In [10]:
dataset = load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="validation"
)
texts = [t for t in dataset["text"] if len(t.strip()) > 0][:200]
base_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
base_ppl = compute_perplexity(base_model, texts)

print(f"Baseline GPT-2 PPL: {base_ppl:.2f}")

Baseline GPT-2 PPL: 46.98


In [11]:
# print("Before recovery PPL:",
#           compute_perplexity(model, texts))   

In [12]:
def print_num_layers(model):
    num_layers = len(model.transformer.h)
    print(f"Number of transformer layers: {num_layers}")
print_num_layers(base_model)

Number of transformer layers: 12


In [None]:
# def count_unique_params(model):
#     seen = set()
#     total = 0
#     for p in model.parameters():
#         if id(p) not in seen:
#             seen.add(id(p))
#             total += p.numel()
#     return total
# def count_params_naive(model):
#     return sum(p.numel() for p in model.parameters())
# def print_param_stats(model, name="model"):
#     naive = count_params_naive(model)
#     unique = count_unique_params(model)
#     print(f"{name}")
#     print(f"  Naive params  : {naive:,}")
#     print(f"  Unique params : {unique:,}")
#     print(f"  Saved params  : {naive - unique:,}")
# print_param_stats(base_model, "Original GPT-2")
# print_param_stats(model, "Merged Model")

Original GPT-2
  Naive params  : 124,439,808
  Unique params : 124,439,808
  Saved params  : 0
Merged Model
  Naive params  : 119,717,376
  Unique params : 119,717,376
  Saved params  : 0


In [13]:
train_dataset = load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="train[:2%]"
)

train_texts = [t for t in train_dataset["text"] if len(t.strip()) > 0]


In [14]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    enc = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128
    )
    return enc

train_loader = DataLoader(
    train_texts,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn
)


In [15]:
import torch.nn.functional as F
from torch.optim import AdamW

def recovery_finetune(
    model,
    train_loader,
    steps=1000,
    lr=1e-5,
    weight_decay=0.01,
    device="cuda"
):
    model.train()

    optimizer = AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    step = 0
    running_loss = 0.0

    for batch in train_loader:
        if step >= steps:
            break

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        outputs = model(input_ids)
        logits = outputs.logits[:, :-1, :]
        labels = input_ids[:, 1:]

        mask = attention_mask[:, 1:].bool()

        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(
            dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        loss = -(token_log_probs[mask]).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        step += 1

        if step % 100 == 0:
            print(f"[Recovery] step {step} | loss {running_loss / 100:.4f}")
            running_loss = 0.0

    model.eval()


In [16]:
base_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

# 1) Extract features ONCE (same forward pass)
feature_mats = extract_ffn_features_single_pass(
    base_model, texts
)

In [17]:
best_ppl = float("inf")
best_window = None
best_state_dict = None

k = 2
windows = sliding_windows(len(base_model.transformer.h), k)

model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
print("Before merging PPL:",
          compute_perplexity(model, texts))

for window in windows:
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

    perms = compute_adjacent_permutations(feature_mats, window)
    merge_k_layers(model, window, perms)

    ppl = compute_perplexity(model, texts)
    print(f"Window {window} | PPL {ppl:.2f}")

    if ppl < best_ppl:
        best_ppl = ppl
        best_window = window
        best_state_dict = {
            k: v.detach().cpu().clone()
            for k, v in model.state_dict().items()
        }


best_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
best_model.load_state_dict(best_state_dict)

print(f"\nBEST WINDOW: {best_window}")
print(f"BEST PPL AFTER MERGING: {best_ppl:.2f}")


Before merging PPL: 46.97976365832357
Window [0, 1] | PPL 196.45
Window [1, 2] | PPL 101.00
Window [2, 3] | PPL 86.62
Window [3, 4] | PPL 74.25
Window [4, 5] | PPL 80.09
Window [5, 6] | PPL 92.55
Window [6, 7] | PPL 60.75
Window [7, 8] | PPL 58.43
Window [8, 9] | PPL 57.56
Window [9, 10] | PPL 61.13
Window [10, 11] | PPL 66.76

BEST WINDOW: [8, 9]
BEST PPL AFTER MERGING: 57.56


In [18]:
train_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
train_model.load_state_dict(best_state_dict)

<All keys matched successfully>

In [19]:
print("Before training PPL:",
          compute_perplexity(train_model, texts))

recovery_finetune(
    train_model,
    train_loader,
    steps=10,   # start small
    lr=1e-5,
    device=device
)

print("After training PPL:",
        compute_perplexity(train_model, texts))

Before training PPL: 57.563756951116204
After training PPL: 50.988636270115435


In [20]:
# ---- TEST EVALUATION ----
test_dataset = load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="test"
)

test_texts = [t for t in test_dataset["text"] if len(t.strip()) > 0][:500]

val_ppl = compute_perplexity(train_model, texts)
print(f"Validation PPL after recovery: {val_ppl:.2f}")

test_ppl = compute_perplexity(train_model, test_texts)
print(f"TEST PPL after recovery: {test_ppl:.2f}")


Validation PPL after recovery: 50.99
TEST PPL after recovery: 70.06


In [21]:
print("After training PPL:",
        compute_perplexity(train_model, train_texts))

After training PPL: 66.83528990070141


In [22]:
prompt = "India will become global leader in AI because 1. it has"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    sample_out = train_model.generate(**inputs, max_length=40)
print("\nGenerated output:")
print(tokenizer.decode(sample_out[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Generated output:
India will become global leader in AI because 1. it has the ability to make decisions in the world and 2. it has the ability to make decisions in the world.

The AI system is


In [23]:
l0 = model.transformer.h[window[0]].mlp
l1 = model.transformer.h[window[1]].mlp

print(l0.c_fc.weight is l1.c_fc.weight)   # MUST be True
print(l0.c_proj.weight is l1.c_proj.weight)  # MUST be True


True
True


In [25]:
def count_unique_params(model):
    seen = set()
    total = 0
    for p in model.parameters():
        if id(p) not in seen:
            seen.add(id(p))
            total += p.numel()
    return total
def count_params_naive(model):
    return sum(p.numel() for p in model.parameters())
def print_param_stats(model, name="model"):
    naive = count_params_naive(model)
    unique = count_unique_params(model)
    print(f"{name}")
    print(f"  Naive params  : {naive:,}")
    print(f"  Unique params : {unique:,}")
    print(f"  Saved params  : {naive - unique:,}")
print_param_stats(base_model, "Original GPT-2")
print_param_stats(model, "Merged Model")


Original GPT-2
  Naive params  : 124,439,808
  Unique params : 124,439,808
  Saved params  : 0
Merged Model
  Naive params  : 119,717,376
  Unique params : 119,717,376
  Saved params  : 0


In [26]:
best_ppl = float("inf")
best_window = None
best_model = None   # <-- store MODEL OBJECT, not state_dict

k = 2
windows = sliding_windows(len(base_model.transformer.h), k)

for window in windows:
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

    perms = compute_adjacent_permutations(feature_mats, window)
    merge_k_layers(model, window, perms)

    ppl = compute_perplexity(model, texts)
    print(f"Window {window} | PPL {ppl:.2f}")

    if ppl < best_ppl:
        best_ppl = ppl
        best_window = window
        best_model = model   # ✅ keep tied model alive


Window [0, 1] | PPL 196.45
Window [1, 2] | PPL 101.00
Window [2, 3] | PPL 86.62
Window [3, 4] | PPL 74.25
Window [4, 5] | PPL 80.09
Window [5, 6] | PPL 92.55
Window [6, 7] | PPL 60.75
Window [7, 8] | PPL 58.43
Window [8, 9] | PPL 57.56
Window [9, 10] | PPL 61.13
Window [10, 11] | PPL 66.76


In [27]:
l0 = best_model.transformer.h[best_window[0]].mlp
l1 = best_model.transformer.h[best_window[1]].mlp

print("Weights tied?",
      l0.c_fc.weight is l1.c_fc.weight,
      l0.c_proj.weight is l1.c_proj.weight)


Weights tied? True True


In [28]:
print("Before recovery PPL:",
      compute_perplexity(best_model, texts))

recovery_finetune(
    best_model,
    train_loader,
    steps=10,
    lr=1e-5,
    device=device
)

print("After recovery PPL:",
      compute_perplexity(best_model, texts))


Before recovery PPL: 57.563756951116204
After recovery PPL: 50.34575648285173


In [None]:
l0 = best_model.transformer.h[best_window[0]].mlp
l1 = best_model.transformer.h[best_window[1]].mlp
# print(l0.c_fc.weight)
# print(l1.c_fc.weight)
print(l0.c_fc.weight is l1.c_fc.weight)   # MUST be True
print(l0.c_proj.weight is l1.c_proj.weight)  # MUST be True

Parameter containing:
tensor([[-0.0113,  0.0835,  0.0336,  ..., -0.0634, -0.1078, -0.0648],
        [-0.0199, -0.0391,  0.0831,  ..., -0.1936,  0.0187,  0.0800],
        [ 0.1375, -0.0130,  0.0311,  ...,  0.0701, -0.0677,  0.1381],
        ...,
        [ 0.0741, -0.0061, -0.1858,  ...,  0.0620,  0.0780,  0.1476],
        [ 0.0216,  0.1608,  0.0733,  ..., -0.0876,  0.0063,  0.0154],
        [-0.0899,  0.0435,  0.0539,  ..., -0.0951,  0.1215,  0.0606]],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[-0.0113,  0.0835,  0.0336,  ..., -0.0634, -0.1078, -0.0648],
        [-0.0199, -0.0391,  0.0831,  ..., -0.1936,  0.0187,  0.0800],
        [ 0.1375, -0.0130,  0.0311,  ...,  0.0701, -0.0677,  0.1381],
        ...,
        [ 0.0741, -0.0061, -0.1858,  ...,  0.0620,  0.0780,  0.1476],
        [ 0.0216,  0.1608,  0.0733,  ..., -0.0876,  0.0063,  0.0154],
        [-0.0899,  0.0435,  0.0539,  ..., -0.0951,  0.1215,  0.0606]],
       device='cuda:0', requires_grad=True)


In [29]:
def count_unique_params(model):
    seen = set()
    total = 0
    for p in model.parameters():
        if id(p) not in seen:
            seen.add(id(p))
            total += p.numel()
    return total

def count_params_naive(model):
    return sum(p.numel() for p in model.parameters())

print("Original GPT-2")
print(" Naive :", count_params_naive(base_model))
print(" Unique:", count_unique_params(base_model))

print("\nMerged (tied) model")
print(" Naive :", count_params_naive(best_model))
print(" Unique:", count_unique_params(best_model))


Original GPT-2
 Naive : 124439808
 Unique: 124439808

Merged (tied) model
 Naive : 119717376
 Unique: 119717376


In [26]:
!pip install transformers datasets scipy numpy accelerate sentencepiece


Collecting accelerate
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Collecting sentencepiece
  Downloading sentencepiece-0.2.1-cp310-cp310-win_amd64.whl.metadata (10 kB)
Downloading accelerate-1.12.0-py3-none-any.whl (380 kB)
Downloading sentencepiece-0.2.1-cp310-cp310-win_amd64.whl (1.1 MB)
   ---------------------------------------- 0.0/1.1 MB ? eta -:--:--
   ------------------- -------------------- 0.5/1.1 MB 4.2 MB/s eta 0:00:01
   ---------------------------------------- 1.1/1.1 MB 3.6 MB/s  0:00:00
Installing collected packages: sentencepiece, accelerate

   -------------------- ------------------- 1/2 [accelerate]
   -------------------- ------------------- 1/2 [accelerate]
   -------------------- ------------------- 1/2 [accelerate]
   -------------------- ------------------- 1/2 [accelerate]
   ---------------------------------------- 2/2 [accelerate]

Successfully installed accelerate-1.12.0 sentencepiece-0.2.1


In [1]:
# ============================================================
# LLaMA FFN Neuron Alignment + Weight-Tied Merging
# Works for: LLaMA, TinyLlama, Mistral, Gemma, Qwen
# ============================================================

import torch
import torch.nn as nn
import numpy as np
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from scipy.optimize import linear_sum_assignment

# -------------------------
# Config
# -------------------------
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
MAX_TOKENS = 4000
K = 2   # merge window size

# -------------------------
# Load model & tokenizer
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto"
).eval()

# -------------------------
# Utilities
# -------------------------
def get_blocks(model):
    return model.model.layers

def get_ffn(block):
    return block.mlp  # gate_proj, up_proj, down_proj

# -------------------------
# FFN Feature Extraction (UP projection)
# -------------------------
def extract_ffn_features_single_pass(model, texts, max_tokens=MAX_TOKENS):
    blocks = get_blocks(model)
    features = {i: [] for i in range(len(blocks))}
    hooks = []

    def make_hook(i):
        def hook(_, __, out):
            features[i].append(out.detach().cpu())
        return hook

    for i, block in enumerate(blocks):
        hooks.append(
            block.mlp.up_proj.register_forward_hook(make_hook(i))
        )

    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64
    ).to(DEVICE)

    with torch.no_grad():
        model(**enc)

    for h in hooks:
        h.remove()

    X = {}
    for i in features:
        acts = torch.cat(features[i], dim=0)
        if acts.shape[0] > max_tokens:
            acts = acts[:max_tokens]
        X[i] = acts.numpy()

    return X

# -------------------------
# Hungarian Neuron Matching
# -------------------------
def compute_adjacent_permutations(feature_dict, window):
    anchor = window[0]
    perms = {}

    X_anchor = feature_dict[anchor]

    for layer in window[1:]:
        X_other = feature_dict[layer]
        C = np.corrcoef(X_anchor, X_other, rowvar=False)
        d = X_anchor.shape[1]
        C = C[:d, d:]
        _, col_ind = linear_sum_assignment(-C)
        perms[layer] = col_ind

    return perms

# -------------------------
# Sliding windows
# -------------------------
def sliding_windows(n_layers, k):
    return [list(range(i, i + k)) for i in range(n_layers - k + 1)]

# -------------------------
# LLaMA FFN Merge (TRUE weight tying)
# -------------------------
def merge_k_layers_llama(model, layers, perms):
    blocks = get_blocks(model)
    anchor = layers[0]
    anchor_mlp = blocks[anchor].mlp

    device = anchor_mlp.up_proj.weight.device
    k = len(layers)

    W_gate = anchor_mlp.gate_proj.weight.data.clone()
    W_up   = anchor_mlp.up_proj.weight.data.clone()
    W_down = anchor_mlp.down_proj.weight.data.clone()

    for layer in layers[1:]:
        mlp = blocks[layer].mlp
        perm = torch.tensor(perms[layer], device=device)

        W_gate += mlp.gate_proj.weight[perm, :]
        W_up   += mlp.up_proj.weight[perm, :]
        W_down += mlp.down_proj.weight[:, perm]

    W_gate /= k
    W_up   /= k
    W_down /= k

    shared_gate = nn.Parameter(W_gate)
    shared_up   = nn.Parameter(W_up)
    shared_down = nn.Parameter(W_down)

    for layer in layers:
        mlp = blocks[layer].mlp
        mlp.gate_proj.weight = shared_gate
        mlp.up_proj.weight   = shared_up
        mlp.down_proj.weight = shared_down

# -------------------------
# Perplexity
# -------------------------
def compute_perplexity(model, texts, batch_size=4):
    model.eval()
    total_nll = 0.0
    total_tokens = 0

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        enc = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        ).to(DEVICE)

        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[:, :-1, :]
            labels = input_ids[:, 1:]

        mask = attention_mask[:, 1:].bool()
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(
            dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        total_nll += -(token_log_probs[mask]).sum().item()
        total_tokens += mask.sum().item()

    return math.exp(total_nll / total_tokens)

# ============================================================
# Main Experiment
# ============================================================

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
texts = [t for t in dataset["text"] if len(t.strip()) > 0][:200]

base_ppl = compute_perplexity(model, texts)
print(f"Baseline PPL: {base_ppl:.2f}")

# Extract features ONCE
feature_mats = extract_ffn_features_single_pass(model, texts)

best_ppl = float("inf")
best_window = None
best_state = None

windows = sliding_windows(len(get_blocks(model)), K)

for window in windows:
    test_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=DTYPE,
        device_map="auto"
    )

    perms = compute_adjacent_permutations(feature_mats, window)
    merge_k_layers_llama(test_model, window, perms)

    ppl = compute_perplexity(test_model, texts)
    print(f"Window {window} | PPL {ppl:.2f}")

    if ppl < best_ppl:
        best_ppl = ppl
        best_window = window
        best_state = {
            k: v.detach().cpu().clone()
            for k, v in test_model.state_dict().items()
        }

print("\n==============================")
print(f"BEST WINDOW: {best_window}")
print(f"BEST PPL   : {best_ppl:.2f}")
print("==============================")

# -------------------------
# Sanity: weight tying
# -------------------------
final_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto"
)
final_model.load_state_dict(best_state)

l0 = final_model.model.layers[best_window[0]].mlp
l1 = final_model.model.layers[best_window[1]].mlp

print("Weight tied?",
      l0.up_proj.weight is l1.up_proj.weight,
      l0.down_proj.weight is l1.down_proj.weight)

# -------------------------
# Generation demo
# -------------------------
prompt = "India will become global leader in AI because"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    out = final_model.generate(**inputs, max_length=40)

print("\nGenerated text:")
print(tokenizer.decode(out[0], skip_special_tokens=True))


`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Baseline PPL: 14.78


Some parameters are on the meta device because they were offloaded to the cpu.


ValueError: m has more than 2 dimensions

In [3]:
# ============================================================
# LLaMA / TinyLlama FFN Neuron Alignment + Weight-Tied Merging
# FINAL DEVICE-SAFE VERSION
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from scipy.optimize import linear_sum_assignment

# -------------------------
# Config
# -------------------------
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MAX_TOKENS = 4000
K = 2
EVAL_TEXTS = 200

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# -------------------------
# Tokenizer
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# -------------------------
# Model loaders
# -------------------------
def load_model_cpu():
    """Used ONLY for feature extraction"""
    return AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.float32,
        device_map=None
    ).eval()

def load_model_eval():
    """Used for merging + perplexity"""
    return AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=DTYPE,
        device_map="auto"
    ).eval()

# -------------------------
# Helpers
# -------------------------
def get_blocks(model):
    return model.model.layers

# -------------------------
# FFN Feature Extraction (CPU ONLY)
# -------------------------
def extract_ffn_features_single_pass(model, texts, max_tokens=MAX_TOKENS):
    blocks = get_blocks(model)
    features = {i: [] for i in range(len(blocks))}
    hooks = []

    def make_hook(i):
        def hook(_, __, out):
            features[i].append(out.detach())
        return hook

    for i, block in enumerate(blocks):
        hooks.append(
            block.mlp.up_proj.register_forward_hook(make_hook(i))
        )

    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64
    )  # ❗ stays on CPU

    with torch.no_grad():
        model(**enc)

    for h in hooks:
        h.remove()

    X = {}
    for i in features:
        acts = torch.cat(features[i], dim=0)     # [B, T, d_ff]
        acts = acts.reshape(-1, acts.shape[-1])  # [B*T, d_ff]

        if acts.shape[0] > max_tokens:
            acts = acts[:max_tokens]

        X[i] = acts.numpy()

    return X

# -------------------------
# Hungarian matching
# -------------------------
def compute_adjacent_permutations(feature_dict, window):
    anchor = window[0]
    perms = {}

    Xa = feature_dict[anchor]

    for layer in window[1:]:
        Xb = feature_dict[layer]

        C = np.corrcoef(Xa, Xb, rowvar=False)
        d = Xa.shape[1]
        C = C[:d, d:]

        _, col_ind = linear_sum_assignment(-C)
        perms[layer] = col_ind

    return perms

# -------------------------
# Sliding windows
# -------------------------
def sliding_windows(n_layers, k):
    return [list(range(i, i + k)) for i in range(n_layers - k + 1)]

# -------------------------
# LLaMA FFN merge (true tying)
# -------------------------
def merge_k_layers_llama(model, layers, perms):
    blocks = get_blocks(model)
    anchor = layers[0]
    anchor_mlp = blocks[anchor].mlp

    device = anchor_mlp.up_proj.weight.device
    k = len(layers)

    W_gate = anchor_mlp.gate_proj.weight.data.clone()
    W_up   = anchor_mlp.up_proj.weight.data.clone()
    W_down = anchor_mlp.down_proj.weight.data.clone()

    for layer in layers[1:]:
        mlp = blocks[layer].mlp
        perm = torch.tensor(perms[layer], device=device)

        W_gate += mlp.gate_proj.weight[perm, :]
        W_up   += mlp.up_proj.weight[perm, :]
        W_down += mlp.down_proj.weight[:, perm]

    W_gate /= k
    W_up   /= k
    W_down /= k

    shared_gate = nn.Parameter(W_gate)
    shared_up   = nn.Parameter(W_up)
    shared_down = nn.Parameter(W_down)

    for layer in layers:
        mlp = blocks[layer].mlp
        mlp.gate_proj.weight = shared_gate
        mlp.up_proj.weight   = shared_up
        mlp.down_proj.weight = shared_down

# -------------------------
# Perplexity
# -------------------------
def compute_perplexity(model, texts, batch_size=4):
    model.eval()
    total_nll = 0.0
    total_tokens = 0

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]

        enc = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        ).to(DEVICE)

        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[:, :-1, :]
            labels = input_ids[:, 1:]

        mask = attention_mask[:, 1:].bool()

        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(
            dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        total_nll += -(token_log_probs[mask]).sum().item()
        total_tokens += mask.sum().item()

    return math.exp(total_nll / total_tokens)

# ============================================================
# Main
# ============================================================

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
texts = [t for t in dataset["text"] if len(t.strip()) > 0][:EVAL_TEXTS]

print("Loading CPU model for feature extraction...")
cpu_model = load_model_cpu()

print("Extracting FFN features (CPU)...")
feature_mats = extract_ffn_features_single_pass(cpu_model, texts)

print("Loading eval model...")
base_model = load_model_eval()
base_ppl = compute_perplexity(base_model, texts)
print(f"Baseline PPL: {base_ppl:.2f}")

best_ppl = float("inf")
best_window = None
best_state = None

windows = sliding_windows(len(get_blocks(base_model)), K)

for window in windows:
    test_model = load_model_eval()
    perms = compute_adjacent_permutations(feature_mats, window)
    merge_k_layers_llama(test_model, window, perms)

    ppl = compute_perplexity(test_model, texts)
    print(f"Window {window} | PPL {ppl:.2f}")

    if ppl < best_ppl:
        best_ppl = ppl
        best_window = window
        best_state = {
            k: v.detach().cpu().clone()
            for k, v in test_model.state_dict().items()
        }

print("\n==============================")
print(f"BEST WINDOW: {best_window}")
print(f"BEST PPL   : {best_ppl:.2f}")
print("==============================")

# -------------------------
# Sanity check: weight tying
# -------------------------
final_model = load_model_eval()
final_model.load_state_dict(best_state)

l0 = final_model.model.layers[best_window[0]].mlp
l1 = final_model.model.layers[best_window[1]].mlp

print("Weights tied?",
      l0.up_proj.weight is l1.up_proj.weight,
      l0.down_proj.weight is l1.down_proj.weight)

# -------------------------
# Generation demo
# -------------------------
prompt = "India will become global leader in AI because"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    out = final_model.generate(**inputs, max_length=40)

print("\nGenerated text:")
print(tokenizer.decode(out[0], skip_special_tokens=True))


Loading CPU model for feature extraction...
Extracting FFN features (CPU)...
Loading eval model...
Baseline PPL: 14.78


Some parameters are on the meta device because they were offloaded to the cpu.


Window [0, 1] | PPL 2306.26


NotImplementedError: Cannot copy out of meta tensor; no data!