In [8]:
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2Model

def load_model_and_embeddings(model_name="gpt2"):
    """
    Load a GPT-2 model and return its tokenizer and input-embedding matrix.

    Args:
        model_name: Identifier passed unchanged to ``from_pretrained`` for
            both the tokenizer and the model.

    Returns:
        A ``(tokenizer, embeddings)`` tuple, where ``embeddings`` is the
        weight tensor of the model's input-embedding layer, of shape
        (vocab_size, emb_dim), detached from the autograd graph.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2Model.from_pretrained(model_name)
    # Go through the public accessor rather than reaching into model.wte.
    emb_layer = model.get_input_embeddings()   # nn.Embedding
    # .detach() (not .data): same storage and no gradient tracking, but
    # in-place modifications stay visible to autograd's correctness checks.
    embeddings = emb_layer.weight.detach()     # (vocab_size, emb_dim)
    return tokenizer, embeddings

def get_embedding(tokenizer, embeddings, word):
    """
    Look up the embedding vector for ``word``.

    The word is encoded with a leading space and without special tokens.
    When the tokenizer splits it into several sub-tokens, the rows for all
    of them are averaged into a single vector of size emb_dim.
    """
    token_ids = tokenizer.encode(
        word,
        add_prefix_space=True,      # encode the word as if it followed a space
        add_special_tokens=False,
    )
    # Rows selected by token id: (n_subtokens, emb_dim) -> mean over rows.
    return torch.mean(embeddings[token_ids], dim=0)

def find_closest(embeddings, query_vec, tokenizer, top_k=5):
    """
    Find the top_k nearest valid word tokens to the query vector.

    Similarity is the cosine similarity between ``query_vec`` and every row
    of ``embeddings``. Candidates are decoded with ``tokenizer.decode``,
    stripped of surrounding whitespace, and kept only if purely alphabetic.

    Args:
        embeddings: 2-D tensor of shape (vocab_size, emb_dim).
        query_vec: 1-D tensor of size emb_dim.
        tokenizer: Object whose ``decode([idx])`` returns the token string.
        top_k: Maximum number of results to return.

    Returns:
        A list of ``(token, score)`` tuples in descending score order. It
        may hold fewer than ``top_k`` items when too few of the inspected
        candidates are alphabetic. Because tokens are stripped, variants
        such as "jump" and " jump" can both appear as "jump".
    """
    # Without this guard the length check below only runs after an append,
    # so top_k=0 could still return one result.
    if top_k <= 0:
        return []

    emb_norm = F.normalize(embeddings, dim=1)
    q_norm = F.normalize(query_vec.unsqueeze(0), dim=1)
    # cosine sim across vocab
    sims = torch.mm(q_norm, emb_norm.t()).squeeze(0)

    # Fetch extra candidates as a buffer for filtering, but never more than
    # the vocabulary holds: torch.topk raises if k exceeds the tensor size.
    n_candidates = min(top_k + 20, sims.numel())
    vals, idxs = torch.topk(sims, n_candidates)

    results = []
    for score, idx in zip(vals.tolist(), idxs.tolist()):
        token = tokenizer.decode([idx]).strip()
        # filter out punctuation, empty strings, numerals, etc.
        if token.isalpha():
            results.append((token, score))
        if len(results) >= top_k:
            break
    return results

def analogy(tokenizer, embeddings, a, b, c, top_k=5):
    """
    Solve the analogy query ``a - b + c`` in embedding space.

    Each word is embedded with ``get_embedding``; the combined vector
    embedding[a] - embedding[b] + embedding[c] is then passed to
    ``find_closest``, whose ``(token, score)`` list is returned unchanged.
    The input words themselves are not excluded from the results.
    """
    vec_a, vec_b, vec_c = (
        get_embedding(tokenizer, embeddings, word) for word in (a, b, c)
    )
    target = vec_a - vec_b + vec_c
    return find_closest(embeddings, target, tokenizer, top_k)

In [4]:
# Load once per kernel; the cells below reuse this tokenizer/embedding pair.
tokenizer, embeddings = load_model_and_embeddings(model_name="gpt2")

In [35]:
# Each triple (a, b, c) is evaluated as the analogy query a - b + c.
tests = [
    ("king", "man", "woman"),
    ("man", "king", "queen"),
    ("walked", "walk", "jump"),
    ("go", "went", "run"),
    ("sang", "sing", "ring"),
    ("sing", "sang", "rang"),
]

for a, b, c in tests:
    print(f"\nAnalogy: {a} - {b} + {c}")
    # One line per analogy: token(similarity) for the five nearest neighbours.
    print(
        ", ".join(
            f"{tok}({sim:.4f})"
            for tok, sim in analogy(tokenizer, embeddings, a, b, c, top_k=5)
        )
    )


Analogy: king - man + woman
king(0.7758), queen(0.7085), princess(0.6046), Queen(0.5964), kings(0.5932)

Analogy: man - king + queen
man(0.6716), woman(0.6622), queen(0.5638), lady(0.4987), girl(0.4858)

Analogy: walked - walk + jump
jumped(0.7761), jump(0.7684), leapt(0.6663), jumps(0.6522), jump(0.6062)

Analogy: go - went + run
run(0.8232), go(0.5364), Run(0.5148), run(0.4910), runs(0.4873)

Analogy: sang - sing + ring
ring(0.7596), rings(0.6185), Ring(0.5679), sang(0.5266), rang(0.5237)

Analogy: sing - sang + rang
rang(0.7107), sing(0.5251), ringing(0.4633), ring(0.4274), rings(0.4176)
