# Setup

In [None]:
import torch as t
import transformers

In [None]:
# Use the GPU when torch can see one; otherwise run everything on the CPU.
if t.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE  # echo the chosen device as the cell output

In [None]:
# Pretrained "gpt2" checkpoint with its LM head, moved onto DEVICE.
gpt = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
gpt = gpt.to(DEVICE)

In [None]:
# Tokenizer for the same "gpt2" checkpoint as the model; used to decode token ids.
tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained_model_name_or_path="gpt2")

# The bias term in the final LayerNorm encodes unigram statistics

In [None]:
def summarize_unigrams(gpt2, extra_tokens_to_show=None, top_k=10, tok=None):
    """Decode the tokens favoured by the final LayerNorm bias of a GPT-2 model.

    ``lm_head(transformer.ln_f.bias)`` is the part of the logits that the bias
    adds at every position regardless of the input. Its softmax is read here
    as an approximate unigram distribution.

    Args:
        gpt2: Model exposing ``lm_head`` and ``transformer.ln_f.bias``.
        extra_tokens_to_show: Optional iterable of token ids (list, tuple or
            1-D tensor). They are reported after the top-scoring ones.
        top_k: How many highest-scoring tokens to report. It is clamped to
            the vocabulary size.
        tok: Object with a ``decode(token_id)`` method. Defaults to the
            module-level ``tokenizer``.

    Returns:
        List of ``(decoded_text, probability)`` pairs. ``probability`` is a
        string with two significant digits (format spec ``0.2``).
    """
    if tok is None:
        tok = tokenizer
    with t.no_grad():
        unigrams = gpt2.lm_head(gpt2.transformer.ln_f.bias)
        probabilities = t.softmax(unigrams, dim=-1)
        # topk avoids sorting the whole vocabulary. Clamping k keeps small
        # vocabularies working, just as slicing a sorted index list did.
        k = min(top_k, unigrams.shape[-1])
        token_ids = t.topk(unigrams, k).indices.tolist()
    # Normalize extras to plain ints so tuples and tensors are accepted and
    # concatenation with the top-k list can neither fail nor broadcast.
    if extra_tokens_to_show is not None:
        token_ids += [int(i) for i in extra_tokens_to_show]
    return [(tok.decode(i), f"{probabilities[i].item():0.2}") for i in token_ids]

# Show the tokens the ln_f bias alone pushes towards, with their probabilities.
summarize_unigrams(gpt2=gpt)

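These look like plausible unigram statistics, but that doesn't tell me whether they're *exactly* unigram statistics. In fact, we should expect them to be at best a (hopefully good) low-rank approximation. The bias contributes `lm_head(ln_f.bias)` to the logits, and that vector lies in a subspace of dimension at most the hidden size (768 for GPT-2 small). Encoding arbitrary statistics over the 50,257-token vocabulary would need far more degrees of freedom than that.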

So the claim is instead something like "everything the model knows about unigram statistics, it knows via this bias term". I'm a bit confused about what it would actually mean for this to be true, though.

In [None]:
def gpt2_with_tweaked_token(token, tweak, gpt2: "transformers.GPT2LMHeadModel | None" = None):
    """Shift one token's logit within the ln_f-bias "unigram" term of GPT-2.

    The final LayerNorm bias ``b`` contributes ``lm_head(b)`` to the logits at
    every position. This adds ``tweak`` to entry ``token`` of that vector and
    then chooses the new bias whose contribution is the least-squares fit to
    the tweaked target. An exact fit generally does not exist because the
    bias has far fewer dimensions than the vocabulary has entries.

    Args:
        token: Vocabulary index whose logit contribution is shifted.
        tweak: Amount added to that logit, in logit (unnormalized log-prob)
            space.
        gpt2: Model to modify. When ``None``, a fresh pretrained "gpt2" is
            loaded and moved to ``DEVICE``.

    Returns:
        The same model object. Its ``transformer.ln_f.bias`` has been
        overwritten in place, so passing a shared model mutates it for every
        other user of that model.
    """
    if gpt2 is None:
        gpt2 = transformers.GPT2LMHeadModel.from_pretrained("gpt2").to(DEVICE)
    bias = gpt2.transformer.ln_f.bias
    weight = gpt2.lm_head.weight
    # lstsq has no half/bfloat16 kernels, so solve in at least float32.
    solve_dtype = t.promote_types(weight.dtype, t.float32)
    # Only the numbers are needed. Tracking gradients through a vocab-sized
    # matmul and lstsq would just waste memory.
    with t.no_grad():
        unigrams = gpt2.lm_head(bias).to(solve_dtype)
        unigrams[token] += tweak
        new_bias = t.linalg.lstsq(weight.to(solve_dtype), unigrams).solution
        assert new_bias.shape == bias.shape
        # Copy in place: the existing Parameter (and anything referencing it)
        # sees the update, and its dtype/device are preserved.
        bias.copy_(new_bias)
    return gpt2

This is a cute way of adjusting unigram statistics, but the text generated by the tweaked model often ends up degenerate.

The approach above assumes that minimizing least-squares error (in logit space) is a good way to avoid corrupting the unigram statistics we don't want to change. That assumption is somewhat suspect, but it doesn't bear on whether the bias really holds the model's entire knowledge of unigram statistics.