In [1]:
%load_ext autoreload
%autoreload 2
%repo_root

In [2]:
from scripts.analysis_utils import *
from pathlib import Path

In [3]:
T = ToricEmbeddingAnalysis(
    run_path = "runs/T512-train-window-10-coord-weights-saveall",
    data_dir = "/cache/openwebtext/train",
    device = 'cuda' if t.cuda.is_available() else 'cpu',
)

Loaded idx2ivec (input vectors) with 50304 embeddings
Loaded idx2ovec (output vectors) with 50304 embeddings
Embedding dimension: 512
Vocab size: 50304
Created word2ivec dict with 50304 entries
Created word2ovec dict with 50304 entries
Loading model from runs/T512-train-window-10-coord-weights-saveall/sgns.pt...
Loaded coord_weights: shape=torch.Size([512]), mean=0.0279

Summary (Toric):
  Embedding dimension: 512
  Vocab size: 50304
  Coord weights available: True
  Using device: cuda


In [4]:
T.get_closest_words('man')

KeyboardInterrupt: 

In [None]:
v = T.word2ivec['man'] - T.word2ivec['woman'] + T.word2ivec['queen']
T.get_closest_words(v)

[(5.800957679748535, '<UNK>'),
 (-0.6364586353302002, 'The'),
 (-0.9490953683853149, 'the'),
 (-0.9854133725166321, 'queen'),
 (-1.019601821899414, 'a'),
 (-1.0426102876663208, 'A'),
 (-1.2364208698272705, 'an'),
 (-1.262532353401184, 'old'),
 (-1.3604296445846558, 'in'),
 (-1.4040664434432983, 'king')]

In [None]:
r = RealEmbeddingAnalysis(
    run_path = "runs/R512-train-window-10",
    data_dir = "/cache/openwebtext/train",
    device = 'cuda'
)

In [None]:
r.get_closest_words('man')

In [None]:
v = r.word2ivec['man'] - r.word2ivec['woman'] + r.word2ivec['queen']
r.get_closest_words(v)

More analogies

In [None]:
# Load BATS analogy data
def load_bats_file(filepath):
    """Load a BATS analogy file. Returns list of (word1, [word2_alternatives])"""
    pairs = []
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split('\t')
                word1 = parts[0]
                word2_alts = parts[1].split('/') if len(parts) > 1 else []
                pairs.append((word1, word2_alts))
    return pairs

def load_all_bats(bats_dir="data/BATS_3.0"):
    """Load all BATS categories into a dict"""
    bats = {}
    bats_path = Path(bats_dir)
    for category_dir in sorted(bats_path.iterdir()):
        if category_dir.is_dir():
            category_name = category_dir.name
            bats[category_name] = {}
            for file in sorted(category_dir.glob("*.txt")):
                relation_name = file.stem  # filename without extension
                bats[category_name][relation_name] = load_bats_file(file)
    return bats

bats = load_all_bats()
print("BATS categories:")
for cat, relations in bats.items():
    print(f"  {cat}: {list(relations.keys())}")


BATS categories:
  1_Inflectional_morphology: ['I01 [noun - plural_reg]', 'I02 [noun - plural_irreg]', 'I03 [adj - comparative]', 'I04 [adj - superlative]', 'I05 [verb_inf - 3pSg]', 'I06 [verb_inf - Ving]', 'I07 [verb_inf - Ved]', 'I08 [verb_Ving - 3pSg]', 'I09 [verb_Ving - Ved]', 'I10 [verb_3pSg - Ved]']
  2_Derivational_morphology: ['D01 [noun+less_reg]', 'D02 [un+adj_reg]', 'D03 [adj+ly_reg]', 'D04 [over+adj_reg]', 'D05 [adj+ness_reg]', 'D06 [re+verb_reg]', 'D07 [verb+able_reg]', 'D08 [verb+er_irreg]', 'D09 [verb+tion_irreg]', 'D10 [verb+ment_irreg]']
  3_Encyclopedic_semantics: ['E01 [country - capital]', 'E02 [country - language]', 'E03 [UK_city - county]', 'E04 [name - nationality]', 'E05 [name - occupation]', 'E06 [animal - young]', 'E07 [animal - sound]', 'E08 [animal - shelter]', 'E09 [things - color]', 'E10 [male - female]']
  4_Lexicographic_semantics: ['L01 [hypernyms - animals]', 'L02 [hypernyms - misc]', 'L03 [hyponyms - misc]', 'L04 [meronyms - substance]', 'L05 [meronym

In [None]:
# Sample data from each category
print("=== Sample pairs from each category ===\n")

for cat, relations in bats.items():
    print(f"{cat}:")
    for rel_name, pairs in list(relations.items())[:2]:  # first 2 relations per category
        print(f"  {rel_name}:")
        for word1, word2_alts in pairs[:3]:  # first 3 pairs
            print(f"    {word1} -> {word2_alts}")
    print()


=== Sample pairs from each category ===

1_Inflectional_morphology:
  I01 [noun - plural_reg]:
    album -> ['albums']
    application -> ['applications']
    area -> ['areas']
  I02 [noun - plural_irreg]:
    ability -> ['abilities']
    academy -> ['academies']
    activity -> ['activities']

2_Derivational_morphology:
  D01 [noun+less_reg]:
    arm -> ['armless']
    art -> ['artless']
    bone -> ['boneless']
  D02 [un+adj_reg]:
    able -> ['unable']
    acceptable -> ['unacceptable']
    affected -> ['unaffected']

3_Encyclopedic_semantics:
  E01 [country - capital]:
    abuja -> ['nigeria']
    amman -> ['jordan']
    ankara -> ['turkey']
  E02 [country - language]:
    andorra -> ['catalan']
    argentina -> ['spanish']
    australia -> ['english']

4_Lexicographic_semantics:
  L01 [hypernyms - animals]:
    allosaurus -> ['dinosaur', 'reptile', 'bird', 'archosaur', 'archosaurian', 'archosaurian_reptile', '']
    anaconda -> ['snake', 'reptile', 'boa', 'serpent', 'ophidian']
  

In [None]:
# Look at specific relation: male-female
male_female = bats['3_Encyclopedic_semantics']['E10 [male - female]']
print(f"Male-Female pairs ({len(male_female)} total):")
for word1, word2_alts in male_female[:10]:
    print(f"  {word1:15} -> {word2_alts}")


Male-Female pairs (50 total):
  actor           -> ['actress']
  batman          -> ['batwoman']
  boar            -> ['sow']
  boy             -> ['girl']
  brother         -> ['sister']
  buck            -> ['doe']
  bull            -> ['cow']
  businessman     -> ['businesswoman']
  chairman        -> ['chairwoman']
  dad             -> ['mom', 'mum']


In [None]:
# Test analogies: a:b :: c:d  =>  b - a + c ‚âà d
# Using first pair as reference (a, b), test on second pair (c, d)

def test_analogy(analysis, a, b, c, expected_d_list, n=10):
    """Test if b - a + c is close to any of expected_d_list"""
    try:
        v = analysis.word2ivec[b] - analysis.word2ivec[a] + analysis.word2ivec[c]
        closest = analysis.get_closest_words(v, n=n)
        closest_words = [w for _, w in closest]
        
        # Check if any expected answer is in top n
        for d in expected_d_list:
            if d in closest_words:
                rank = closest_words.index(d) + 1
                return True, rank, closest[:5]
        return False, -1, closest[:5]
    except KeyError as e:
        return None, -1, f"Missing word: {e}"

In [None]:
print("=== Analogy Tests (2 samples per category) ===\n")
print("Format: a:b :: c:? => top 5 closest to (b - a + c)\n")

for cat, relations in bats.items():
    print(f"üìÅ {cat}")
    for rel_name, pairs in list(relations.items())[:5]:  # first 2 relations
        print(f"  {rel_name}:")
        # Use first pair as reference (a, b)
        if len(pairs) < 2:
            continue
        a, b_list = pairs[0]
        b = b_list[0] if b_list else None
        if not b:
            continue
        
        # Test on next pair (c, d)
        c, d_list = pairs[1]
        
        result, rank, closest = test_analogy(R, a, b, c, d_list, 20)
        if result is None:
            print(f"    {a}:{b} :: {c}:? => {closest}")
        elif result:
            print(f"    {a}:{b} :: {c}:? => ‚úÖ found '{d_list[0]}' at rank {rank}")
            print(f"       top 5: {[w for _, w in closest]}")
        else:
            print(f"    {a}:{b} :: {c}:? => ‚ùå expected {d_list}")
            print(f"       top 5: {[w for _, w in closest]}")
    print()


=== Analogy Tests (2 samples per category) ===

Format: a:b :: c:? => top 5 closest to (b - a + c)

üìÅ 1_Inflectional_morphology
  I01 [noun - plural_reg]:
    album:albums :: application:? => ‚úÖ found 'applications' at rank 10
       top 5: ['<UNK>', 'the', 'web', 'application', 'Rails']
  I02 [noun - plural_irreg]:
    ability:abilities :: academy:? => ‚ùå expected ['academies']
       top 5: ['<UNK>', 'in', 'club', 'the', 'play']
  I03 [adj - comparative]:
    angry:angrier :: cheap:? => Missing word: 'angrier'
  I04 [adj - superlative]:
    able:ablest :: angry:? => Missing word: 'ablest'
  I05 [verb_inf - 3pSg]:
    accept:accepts :: achieve:? => ‚ùå expected ['achieves']
       top 5: ['<UNK>', 'in', 'that', 'not', 'they']

üìÅ 2_Derivational_morphology
  D01 [noun+less_reg]:
    arm:armless :: art:? => Missing word: 'armless'
  D02 [un+adj_reg]:
    able:unable :: acceptable:? => ‚ùå expected ['unacceptable']
       top 5: ['<UNK>', 'not', 'in', 'an', 'for']
  D03 [adj+ly_re