In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Registry of supported models: short key -> identifiers handed to
# `from_pretrained`. Every entry currently uses the same identifier for both
# the model and its tokenizer, so the table is built from (key, id) pairs.
MODEL_CONFIGS = {
    key: {"model_name": repo_id, "tokenizer_name": repo_id}
    for key, repo_id in (
        ("gpt2", "gpt2"),
        ("pythia1.4b", "EleutherAI/pythia-1.4b-v0"),
        ("gemma2b", "google/gemma-2-2b"),
        ("qwen2", "Qwen/Qwen2.5-1.5B-Instruct"),
        ("bert-base-uncased", "bert-base-uncased"),
        ("bert-large-uncased", "bert-large-uncased"),
        ("distilbert-base-uncased", "distilbert-base-uncased"),
    )
}

def load_model_and_embeddings(model_key: str):
    """Load the tokenizer and the input-embedding matrix for a configured model.

    Args:
        model_key: A key of ``MODEL_CONFIGS``.

    Returns:
        A ``(tokenizer, embeddings)`` tuple, where ``embeddings`` is the
        ``weight.data`` tensor of the model's input-embedding layer.

    Raises:
        ValueError: If ``model_key`` is not a key of ``MODEL_CONFIGS``.
    """
    cfg = MODEL_CONFIGS.get(model_key)
    if cfg is None:
        known_keys = ", ".join(MODEL_CONFIGS)
        raise ValueError(f"Unknown model '{model_key}'. Valid keys: {known_keys}")

    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_name"])
    # Only the embedding layer is kept; the rest of the model is not returned.
    input_embeddings = AutoModel.from_pretrained(cfg["model_name"]).get_input_embeddings()
    return tokenizer, input_embeddings.weight.data

def get_embedding(tokenizer, embeddings, word: str):
    """Return a single embedding vector for ``word``.

    ``word`` is split into tokens by ``tokenizer`` (without special tokens),
    the tokens are mapped to ids, and the corresponding rows of
    ``embeddings`` are averaged along dim 0.

    Args:
        tokenizer: Object providing ``tokenize(text, add_special_tokens=...)``
            and ``convert_tokens_to_ids(tokens)``.
        embeddings: Tensor indexed by token id along dim 0.
        word: Text to embed.

    Returns:
        The mean of the selected embedding rows.

    Raises:
        ValueError: If ``word`` produces no tokens. Averaging an empty
            selection would otherwise yield an all-NaN vector that silently
            corrupts every downstream similarity score.
    """
    # NOTE(review): ``word`` is tokenized exactly as given; for vocabularies
    # that distinguish word-initial pieces (e.g. leading-space variants) this
    # may not select the intended token -- confirm against the tokenizer used.
    tokens = tokenizer.tokenize(word, add_special_tokens=False)
    if not tokens:
        raise ValueError(f"Cannot embed {word!r}: tokenizer produced no tokens")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    return embeddings[ids].mean(dim=0)

def find_closest(embeddings, query_vec, tokenizer, top_k=5):
    """Return the vocabulary tokens most cosine-similar to ``query_vec``.

    Rows of ``embeddings`` and ``query_vec`` are L2-normalised, so the matrix
    product gives one cosine similarity per vocabulary row. Candidates are
    visited from most to least similar; each id is decoded and stripped, and
    only purely alphabetic strings are kept.

    Args:
        embeddings: 2-D tensor with one row per token id.
        query_vec: 1-D tensor with the same width as ``embeddings``.
        tokenizer: Object providing ``decode(list_of_ids)``.
        top_k: Maximum number of results to return.

    Returns:
        A list of at most ``top_k`` ``(token_text, cosine_similarity)``
        tuples in descending similarity order. It is shorter than ``top_k``
        only when the vocabulary holds fewer alphabetic tokens, and empty
        when ``top_k <= 0``.
    """
    if top_k <= 0:
        return []

    emb_norm = F.normalize(embeddings, dim=1)
    q_norm = F.normalize(query_vec.unsqueeze(0), dim=1)
    sims = torch.mm(q_norm, emb_norm.t()).squeeze(0)

    # Rank the whole vocabulary rather than a fixed `top_k + 20` window:
    # `topk` raises when k exceeds the vocabulary size, and a fixed margin can
    # run out of candidates once non-alphabetic tokens are filtered away.
    vals, idxs = torch.sort(sims, descending=True)

    results = []
    for score, idx in zip(vals.tolist(), idxs.tolist()):
        tok = tokenizer.decode([idx]).strip()
        if tok.isalpha():
            results.append((tok, score))
            if len(results) >= top_k:
                break
    return results

def analogy_a_minus_b_plus_c(tokenizer, embeddings, a, b, c, top_k=5):
    """Rank vocabulary tokens against the analogy vector ``a - b + c``.

    Each of ``a``, ``b`` and ``c`` is embedded with ``get_embedding``; the
    combined vector is then passed to ``find_closest``.

    Args:
        tokenizer: Tokenizer forwarded to ``get_embedding`` and ``find_closest``.
        embeddings: Embedding matrix forwarded to both helpers.
        a, b, c: Words combined as ``a - b + c``.
        top_k: Forwarded to ``find_closest``.

    Returns:
        Whatever ``find_closest`` returns for the combined vector.
    """
    vec_a, vec_b, vec_c = (
        get_embedding(tokenizer, embeddings, term) for term in (a, b, c)
    )
    return find_closest(embeddings, vec_a - vec_b + vec_c, tokenizer, top_k)

In [19]:
# Pick which entry of MODEL_CONFIGS to explore in the cells below.
model_key = "qwen2"

tokenizer, embeddings = load_model_and_embeddings(model_key)

# Dim 0 of the matrix is the vocabulary, dim 1 the embedding width.
vocab_size, dim = embeddings.shape
print(f"Loaded '{model_key}': embedding dim = {dim}, vocab size = {vocab_size}")

Loaded 'qwen2': embedding dim = 1536, vocab size = 151936


In [20]:
# Each triple (a, b, c) is evaluated as the vector query a - b + c.
# NOTE(review): ("go", "went", "run") does not mirror the other tense triples
# (e.g. "go" - "went" + "ran" would target "run") -- confirm it is intentional.
tests = [
    ("king", "man", "woman"),
    ("man", "king", "queen"),
    ("walked", "walk", "jump"),
    ("go", "went", "run"),
    ("sang", "sing", "ring"),
    ("sing", "sang", "rang"),
]

for a, b, c in tests:
    header = f"\nAnalogy ({a} - {b} + {c}):"
    print(header)
    neighbours = analogy_a_minus_b_plus_c(tokenizer, embeddings, a, b, c, top_k=5)
    for tok, sim in neighbours:
        print(f"  {tok:<12} cosine_sim={sim:.4f}")


Analogy (king - man + woman):
  king         cosine_sim=0.6939
  KING         cosine_sim=0.4254
  woman        cosine_sim=0.4185
  King         cosine_sim=0.3925
  queen        cosine_sim=0.3881

Analogy (man - king + queen):
  man          cosine_sim=0.6176
  MAN          cosine_sim=0.4735
  queen        cosine_sim=0.4342
  Man          cosine_sim=0.4330
  man          cosine_sim=0.4305

Analogy (walked - walk + jump):
  jump         cosine_sim=0.7297
  jump         cosine_sim=0.5977
  Jump         cosine_sim=0.5395
  Jump         cosine_sim=0.5100
  jumping      cosine_sim=0.5046

Analogy (go - went + run):
  run          cosine_sim=0.6994
  run          cosine_sim=0.5766
  Run          cosine_sim=0.5133
  Run          cosine_sim=0.4999
  go           cosine_sim=0.4472

Analogy (sang - sing + ring):
  ring         cosine_sim=0.5789
  ring         cosine_sim=0.4186
  Ring         cosine_sim=0.3539
  ang          cosine_sim=0.3486
  Ring         cosine_sim=0.3337

Analogy (sing - sang