# DistilBERT Boomerang Distillation Interpolation


*   We evaluate interpolation between DistilBERT and BERT using masked-LM pseudo-perplexity on the first 50 rows of the WikiText-2 (raw) test split; rows with no scorable tokens (e.g. blank lines) are skipped (note: the paper evaluates on the full WikiText test set)
*   Script takes ~5 minutes to run on a T4 GPU instance



In [None]:
import copy
from datasets import load_dataset
import torch
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers.modeling_outputs import MaskedLMOutput
import math
import matplotlib.pyplot as plt
import tqdm

## Preliminaries
Set the device, the model checkpoints, and the number of test-set rows to evaluate on (`N_TEST = None` uses the full split)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Evaluation budget: number of rows taken from the start of the test split
# (None evaluates the whole split).
N_TEST = 50
DISTIL_NAME = "distilbert-base-uncased"
BERT_NAME = "bert-base-uncased"
TEST_DATASET = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    split="test" if N_TEST is None else f"test[:{N_TEST}]",
)

### Helper function
Utility function for resetting the `layer_idx` attribute on grafted BERT layers so that it matches their new position in the patched model

In [None]:
def set_layer_idx_recursive(block, idx, attr="layer_idx"):
    """Assign ``idx`` to attribute ``attr`` on ``block`` and every module nested in it.

    ``block.modules()`` yields ``block`` itself followed by all of its
    descendants, so a single pass covers the whole subtree.
    """
    for submodule in block.modules():
        setattr(submodule, attr, idx)

### Evaluation
Functions to count model parameters and evaluate masked-LM pseudo-perplexity

In [None]:
def get_n_params(model):
    """Count a model's parameters.

    Returns a dict with two entries:
      * ``"non_embedding"``: parameters whose qualified name does not contain
        the substring ``"embed"``;
      * ``"total"``: every parameter.
    Counts come from ``named_parameters()``, which by default yields a shared
    (tied) tensor only once.
    """
    counts = {"non_embedding": 0, "total": 0}
    for param_name, param in model.named_parameters():
        size = param.numel()
        counts["total"] += size
        if "embed" not in param_name:
            counts["non_embedding"] += size
    return counts


@torch.no_grad()
def pseudo_perplexity_distilbert(
    model,
    tokenizer,
    text,
    max_length=512,
    chunk_size=64,  # mask this many positions per forward pass
):
    r"""Compute the masked-LM pseudo-perplexity of a single string.

    Each non-special token position i is masked in turn (one masked position
    per batch row) and scored by the model's log-probability of the original
    token given the rest of the sequence:

        Pseudo-PPL = exp( mean_i [ -log p(x_i | x_{\i}) ] )

    where x_{\i} is the sequence with position i replaced by the mask token.

    Args:
        model: callable accepting ``input_ids`` / ``attention_mask`` and
            returning an object whose ``.logits`` is [batch, seq_len, vocab].
            Inputs are placed on the device of the model's first parameter;
            callables without parameters fall back to the notebook-level
            ``device``.
        tokenizer: tokenizer exposing ``cls_token_id``, ``sep_token_id``,
            ``pad_token_id`` and ``mask_token_id``.
        text: string to score; truncated to ``max_length`` tokens.
        max_length: maximum number of tokens (including special tokens) kept.
        chunk_size: number of masked copies scored per forward pass; bounds
            peak memory.

    Returns:
        ``(pseudo_perplexity, token_count_used)``; ``(nan, 0)`` when the text
        has no scorable (non-special, non-padding) tokens.
    """
    # Run on whatever device the model lives on instead of assuming the
    # global default.
    try:
        dev = next(model.parameters()).device
    except (AttributeError, StopIteration):
        dev = device

    special_ids = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    }
    mask_id = tokenizer.mask_token_id

    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    input_ids = enc["input_ids"][0].to(dev)
    attn = enc["attention_mask"][0].to(dev)

    # Positions to score: attended tokens that are not special/pad tokens.
    positions = [
        i
        for i, (tid, keep) in enumerate(zip(input_ids.tolist(), attn.tolist()))
        if keep == 1 and tid not in special_ids
    ]
    if not positions:
        return float("nan"), 0

    total_nll = 0.0
    total_tokens = 0

    for start in range(0, len(positions), chunk_size):
        chunk = torch.tensor(positions[start : start + chunk_size], device=dev)
        bsz = chunk.numel()
        rows = torch.arange(bsz, device=dev)

        # One copy of the sequence per position; row r masks position chunk[r].
        batch_input = input_ids.unsqueeze(0).repeat(bsz, 1)
        batch_attn = attn.unsqueeze(0).repeat(bsz, 1)
        batch_input[rows, chunk] = mask_id

        logits = model(input_ids=batch_input, attention_mask=batch_attn).logits
        # Row r only contributes the prediction at its own masked position.
        log_probs = logits[rows, chunk, :].log_softmax(dim=-1)  # [bsz, vocab]

        # Score the *original* (unmasked) token at each masked position.
        gold = input_ids[chunk].unsqueeze(1)  # [bsz, 1]
        nll = -log_probs.gather(dim=1, index=gold).squeeze(1)  # [bsz]

        total_nll += nll.sum().item()
        total_tokens += bsz

    return math.exp(total_nll / total_tokens), total_tokens


def pseudo_perplexity_corpus(model, tokenizer, texts=None):
    """Token-weighted pseudo-perplexity over a collection of texts.

    Computes ``exp( sum_i nll_i / sum_i tokens_i )``, where ``nll_i`` and
    ``tokens_i`` are the summed negative log-likelihood and the scored-token
    count of text ``i``. Each text is scored with
    ``pseudo_perplexity_distilbert``, and its summed NLL is recovered as
    ``log(ppl_i) * tokens_i``. Texts with no scorable tokens are skipped.

    Args:
        model, tokenizer: forwarded to ``pseudo_perplexity_distilbert``.
        texts: iterable of strings to score. Defaults to the ``"text"`` field
            of every row of the notebook-level ``TEST_DATASET``.

    Returns:
        The corpus pseudo-perplexity, or ``nan`` if no text had scorable tokens.
    """
    if texts is None:
        texts = [row["text"] for row in TEST_DATASET]

    total_nll, total_tok = 0.0, 0
    for text in tqdm.tqdm(texts):
        ppl, tok = pseudo_perplexity_distilbert(model, tokenizer, text)
        if tok > 0:
            # Convert back to summed NLL so texts are weighted by token count.
            total_nll += math.log(ppl) * tok
            total_tok += tok

    return math.exp(total_nll / total_tok) if total_tok else float("nan")

## Interpolation code
Model class for creating interpolated models: a truncated DistilBERT encoder followed by the corresponding trailing BERT layers and BERT's MLM head

In [None]:
class DistilThenInterpolateForMaskedLM(nn.Module):
    """Interpolated model: truncated DistilBERT + trailing BERT layers + BERT MLM head.

    The last ``n_layers_removed`` DistilBERT transformer layers are dropped.
    The wrapper then appends deep copies of every BERT encoder layer from
    ``teacher_layers[n_keep]`` onward. Here ``n_keep`` is the number of
    DistilBERT layers kept, and ``teacher_layers[i]`` is the BERT layer used
    to initialise DistilBERT layer ``i``. BERT's MLM head produces the logits.

    ``n_layers_removed=0`` keeps DistilBERT's full encoder and appends nothing.
    ``n_layers_removed=len(teacher_layers)`` keeps no DistilBERT layer and
    runs every BERT encoder layer on top of DistilBERT's embeddings.

    Raises:
        ValueError: if ``n_layers_removed`` is outside ``[0, len(teacher_layers)]``.
    """

    def __init__(
        self,
        distil_name=DISTIL_NAME,
        base_name=BERT_NAME,
        n_layers_removed=2,
        device="cuda",
    ):
        super().__init__()
        self.device = device

        # layers used to initialize distilBERT (see https://github.com/huggingface/transformers-research-projects/blob/362a490dc36e91359fe76a7a707dc29e663196b2/distillation/scripts/extract_distilbert.py#L55C9-L55C43)
        teacher_layers = [0, 2, 4, 7, 9, 11]
        if not 0 <= n_layers_removed <= len(teacher_layers):
            raise ValueError(
                f"n_layers_removed must be between 0 and {len(teacher_layers)}, "
                f"got {n_layers_removed}"
            )

        # DistilBERT with MLM head (only its encoder is used in forward)
        self.distil_mlm = AutoModelForMaskedLM.from_pretrained(distil_name).to(device)
        # BERT with MLM head (its encoder layers AND its MLM head are used)
        self.bert_mlm = AutoModelForMaskedLM.from_pretrained(base_name).to(device)

        # Keep the first n_keep DistilBERT layers. Use an explicit end index:
        # ``layer[:-0]`` would be ``layer[:0]`` and drop every layer.
        n_keep = len(teacher_layers) - n_layers_removed
        self.distil_mlm.distilbert.transformer.layer = (
            self.distil_mlm.distilbert.transformer.layer[:n_keep]
        )
        n_distil = len(self.distil_mlm.distilbert.transformer.layer)

        # Append BERT layers starting at the teacher layer that initialised the
        # first removed DistilBERT layer. Nothing is appended only when nothing
        # was removed. Note that start index 0 (all layers removed) is valid and
        # means "append every BERT layer".
        bert_layers = self.bert_mlm.bert.encoder.layer  # ModuleList[BertLayer]
        if n_layers_removed == 0:
            selected = []
        else:
            start = teacher_layers[n_keep]
            selected = [copy.deepcopy(layer) for layer in bert_layers[start:]]
        self.append_layers = nn.ModuleList(selected)

        # Renumber appended layers so their layer index continues after the
        # kept DistilBERT layers.
        for offset, block in enumerate(self.append_layers):
            set_layer_idx_recursive(block, n_distil + offset)

        self.distil_mlm.config.n_layers = n_distil + len(self.append_layers)
        self.config = self.distil_mlm.config
        # NOTE(review): this flag is not read anywhere in this notebook; kept
        # as-is in case external code inspects it.
        self.config.tie_weights_ = False

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run truncated DistilBERT, the appended BERT layers, then BERT's MLM head.

        Args:
            input_ids: token ids, [batch, seq_len].
            attention_mask: HF-style attention mask with the same shape as
                ``input_ids``; defaults to all ones.
            labels: optional MLM targets; positions equal to -100 are ignored
                by the loss.

        Returns:
            ``MaskedLMOutput`` with ``logits`` [batch, seq_len, vocab_size] and
            ``loss`` (``None`` when ``labels`` is not given).
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # 1) DistilBERT encoder (embeddings + kept layers)
        distil_out = self.distil_mlm.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = distil_out.last_hidden_state  # [B, T, 768]

        # 2) Append selected BERT encoder layers. BertLayer expects the
        #    broadcastable "extended" mask rather than the raw 0/1 mask.
        if len(self.append_layers) > 0:
            ext_mask = self.bert_mlm.bert.get_extended_attention_mask(
                attention_mask, input_shape=input_ids.shape, device=input_ids.device
            )
            for layer in self.append_layers:
                out = layer(
                    hidden_states, attention_mask=ext_mask, output_attentions=False
                )
                # Layers may return a tuple (hidden_states, ...) or a tensor.
                hidden_states = out[0] if isinstance(out, (tuple, list)) else out

        # 3) Use **BERT's MLM head** on the final hidden states
        logits = self.bert_mlm.cls(hidden_states)  # [B, T, vocab_size]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return MaskedLMOutput(
            loss=loss, logits=logits, hidden_states=None, attentions=None
        )

## Run evaluation

Code to evaluate DistilBERT, the interpolated models, naively layer-pruned BERT models, and the full BERT model



In [None]:
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_NAME)
# Both checkpoints are later scored with BERT's tokenizer (this assumes the
# uncased DistilBERT vocabulary matches BERT's).
distilbert, bert = (
    AutoModelForMaskedLM.from_pretrained(checkpoint).eval().to(device)
    for checkpoint in (DISTIL_NAME, BERT_NAME)
)
# BERT layer used to initialize each distilBERT layer (see https://github.com/huggingface/transformers-research-projects/blob/362a490dc36e91359fe76a7a707dc29e663196b2/distillation/scripts/extract_distilbert.py#L55C9-L55C43)
teacher_layers = [0, 2, 4, 7, 9, 11]

In [None]:
n_params_lst_bert = []
perplexity_dct_bert = {"zero-shot": [], "naive pruned": []}

# --- DistilBERT: first (smallest) point of the zero-shot curve -------------
print("Evaluating DistilBERT")
n_params_lst_bert.append(get_n_params(distilbert)["total"] / 1e6)
perplexity_dct_bert["zero-shot"].append(
    pseudo_perplexity_corpus(distilbert, bert_tokenizer)
)

# --- Interpolated models: drop the last n DistilBERT layers and append the
# corresponding BERT layers. (n=1 would give 5 + 1 = 6 layers, the same
# depth as DistilBERT, so the sweep starts at n=2.)
print("Evaluating Interpolated Models")
for n in range(2, 6):
    interp_model = DistilThenInterpolateForMaskedLM(
        n_layers_removed=n, device=device
    ).eval()
    # The wrapper also holds a complete BERT (`bert_mlm`), so count only the
    # truncated DistilBERT plus the appended layers.
    n_params_lst_bert.append(
        (
            get_n_params(interp_model.distil_mlm)["total"]
            + get_n_params(interp_model.append_layers)["total"]
        )
        / 1e6
    )
    perplexity_dct_bert["zero-shot"].append(
        pseudo_perplexity_corpus(interp_model, bert_tokenizer)
    )
    del interp_model

# --- Naive pruning: starting from BERT, for n = 1..5 delete the layers strictly
# between teacher_layers[-(n + 1)] and teacher_layers[-n]. Each pass deletes
# lower indices than the previous one, so earlier deletions never shift the
# indices still to be removed. Depths visited: 11, 10, 8, 7, 6 layers.
print("Evaluating Naive Pruned Models")
pruned_model = copy.deepcopy(bert).eval()
for n in range(1, 6):
    for i in range(teacher_layers[-n] - 1, teacher_layers[-(n + 1)], -1):
        del pruned_model.bert.encoder.layer[i]
    pruned_model.config.num_hidden_layers = len(pruned_model.bert.encoder.layer)
    perplexity_dct_bert["naive pruned"].append(
        pseudo_perplexity_corpus(pruned_model, bert_tokenizer)
    )

# Reverse to ascending depth (6, 7, 8, 10, 11) so the series lines up with the
# zero-shot series and its parameter counts.
perplexity_dct_bert["naive pruned"] = perplexity_dct_bert["naive pruned"][::-1]
del pruned_model

# --- BERT: shared last point of both curves; evaluate it once -------------
print("Evaluating BERT")
n_params_lst_bert.append(get_n_params(bert)["total"] / 1e6)
bert_ppl = pseudo_perplexity_corpus(bert, bert_tokenizer)
for series in perplexity_dct_bert.values():
    series.append(bert_ppl)

## Plotting

Plot performance of interpolated models compared to naive layer pruning

In [None]:
mname = "BERT"
plt.figure(figsize=(8, 6))

# One curve per model family, plotted against the shared parameter counts.
for metric, metric_lst in perplexity_dct_bert.items():
    plt.plot(
        n_params_lst_bert,
        metric_lst,
        marker="o",
        linestyle="-",
        label=f"{metric.title()} Models",
    )

# Highlight the endpoints. DistilBERT's score is only in the zero-shot series
# (index 0); the naive-pruned series starts at a 6-layer pruned BERT instead.
# BERT's score is the last entry of both series.
zero_shot = perplexity_dct_bert["zero-shot"]
plt.plot(
    n_params_lst_bert[0],
    zero_shot[0],
    marker="o",
    color="orange",
    label=f"Distil{mname}",
)
plt.plot(
    n_params_lst_bert[-1], zero_shot[-1], marker="o", color="red", label=f"{mname}"
)

plt.xlabel("Parameter count (millions)")
plt.ylabel("MLM Pseudo-perplexity")
plt.title("MLM Pseudo-perplexity vs. Model Size")
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.legend(loc="best")
plt.show()