In [21]:
# NOTE nearest neighbor analysis
# ============================================================================
# SUPERCLASS COHERENCE SCORING
# ============================================================================
# For each model:
#   - Get 5 nearest neighbors for each CIFAR-100 word
#   - Count how many neighbors are in the SAME superclass
#   - Score = total count across all 100 words
# Max possible score = 100 words √ó 5 neighbors = 500 (if all neighbors are siblings)

import torch
import numpy as np
import os
import glob
from lab6 import SkipGramModel, find_similar_words
import torchvision

# Load CIFAR-100 class names
cifar100 = torchvision.datasets.CIFAR100(root='./data', download=True)
cifar_words = set(cifar100.classes)

# Define superclass structure
superclasses = {
    'aquatic_mammals': ['beaver', 'dolphin', 'otter', 'seal', 'whale'],
    'fish': ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],
    'flowers': ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
    'fruit_and_vegetables': ['apple', 'mushroom', 'orange', 'pear', 'pepper'],
    'household_electrical': ['clock', 'keyboard', 'lamp', 'telephone', 'television'],
    'household_furniture': ['bed', 'chair', 'couch', 'table', 'wardrobe'],
    'insects': ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
    'large_carnivores': ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
    'large_outdoor_things': ['bridge', 'castle', 'house', 'road', 'skyscraper'],
    'natural_scenes': ['cloud', 'forest', 'mountain', 'plain', 'sea'],
    'large_omnivores_herbivores': ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
    'medium_mammals': ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
    'invertebrates': ['crab', 'lobster', 'snail', 'spider', 'worm'],
    'people': ['baby', 'boy', 'girl', 'man', 'woman'],
    'reptiles': ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
    'small_mammals': ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
    'trees': ['maple', 'oak', 'palm', 'pine', 'willow'],
    'vehicles_1': ['bicycle', 'bus', 'motorcycle', 'pickup', 'train'],
    'vehicles_2': ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor'],
}

# Build reverse lookup: word -> set of siblings (same superclass, excluding self)
word_to_siblings = {}
for superclass, words in superclasses.items():
    for word in words:
        word_to_siblings[word] = set(words) - {word}

# Find all model files
model_files = sorted(glob.glob('EMB*.pth'))
print(f"Found {len(model_files)} models to evaluate")
print("=" * 80)

# Score each model
results = []

for model_path in model_files:
    print(f"\nEvaluating: {model_path}")
    
    # Load model
    checkpoint = torch.load(model_path)
    nodes = checkpoint['nodes']
    embedding_dim = checkpoint['embedding_dim']
    vocab_size = checkpoint['vocab_size']
    
    model = SkipGramModel(vocab_size, embedding_dim)
    model.load_state_dict(checkpoint['model_state_dict'])
    embeddings = model.get_embeddings()
    
    node_set = set(nodes)
    
    # Score: count neighbors in same superclass
    total_score = 0
    words_evaluated = 0
    
    for word in cifar_words:
        if word not in node_set:
            continue
        
        words_evaluated += 1
        neighbors = find_similar_words(word, nodes, embeddings, top_k=5)
        neighbor_words = set([w for w, s in neighbors])
        
        # Count how many neighbors are siblings (same superclass)
        siblings = word_to_siblings.get(word, set())
        sibling_neighbors = neighbor_words & siblings
        total_score += len(sibling_neighbors)
    
    # Calculate percentage (max = 5 neighbors √ó 100 words = 500)
    max_possible = words_evaluated * 5
    percentage = (total_score / max_possible * 100) if max_possible > 0 else 0
    
    results.append({
        'model': model_path,
        'score': total_score,
        'max_possible': max_possible,
        'percentage': percentage,
        'embedding_dim': embedding_dim,
    })
    
    print(f"  Score: {total_score}/{max_possible} ({percentage:.1f}%)")

print("\n" + "=" * 80)
print("EVALUATION COMPLETE")
print("=" * 80)


Found 14 models to evaluate

Evaluating: EMB128_NG10_CS2_BS64.pth
  Score: 78/500 (15.6%)

Evaluating: EMB256_NG10_CS2_BS64.pth
  Score: 80/500 (16.0%)

Evaluating: EMB32_NG10_CS2_BS64.pth
  Score: 96/500 (19.2%)

Evaluating: EMB64_NG10_CS1_BS64.pth
  Score: 30/500 (6.0%)

Evaluating: EMB64_NG10_CS2_BS128.pth
  Score: 82/500 (16.4%)

Evaluating: EMB64_NG10_CS2_BS256.pth
  Score: 46/500 (9.2%)

Evaluating: EMB64_NG10_CS2_BS32.pth
  Score: 71/500 (14.2%)

Evaluating: EMB64_NG10_CS2_BS64.pth
  Score: 90/500 (18.0%)

Evaluating: EMB64_NG10_CS3_BS64.pth
  Score: 80/500 (16.0%)

Evaluating: EMB64_NG10_CS4_BS64.pth
  Score: 78/500 (15.6%)

Evaluating: EMB64_NG15_CS2_BS64.pth
  Score: 91/500 (18.2%)

Evaluating: EMB64_NG20_CS2_BS64.pth
  Score: 89/500 (17.8%)

Evaluating: EMB64_NG20_CS3_BS32.pth
  Score: 93/500 (18.6%)

Evaluating: EMB64_NG5_CS2_BS64.pth
  Score: 67/500 (13.4%)

EVALUATION COMPLETE


In [22]:
# ============================================================================
# DISPLAY RANKINGS
# ============================================================================

import pandas as pd

# Sort by score (descending)
results_sorted = sorted(results, key=lambda x: x['score'], reverse=True)

print("=" * 80)
print("FINAL RANKINGS (by Superclass Coherence Score)")
print("=" * 80)
print(f"\n{'Rank':<6} {'Model':<35} {'Score':>8} {'Max':>6} {'%':>8}")
print("-" * 70)

for i, r in enumerate(results_sorted, 1):
    print(f"{i:<6} {r['model']:<35} {r['score']:>8} {r['max_possible']:>6} {r['percentage']:>7.1f}%")

# Highlight the winner
winner = results_sorted[0]
print(f"\n{'=' * 80}")
print(f"üèÜ BEST MODEL: {winner['model']}")
print(f"   Score: {winner['score']}/{winner['max_possible']} ({winner['percentage']:.1f}% of neighbors are superclass siblings)")
print(f"{'=' * 80}")


FINAL RANKINGS (by Superclass Coherence Score)

Rank   Model                                  Score    Max        %
----------------------------------------------------------------------
1      EMB32_NG10_CS2_BS64.pth                   96    500    19.2%
2      EMB64_NG20_CS3_BS32.pth                   93    500    18.6%
3      EMB64_NG15_CS2_BS64.pth                   91    500    18.2%
4      EMB64_NG10_CS2_BS64.pth                   90    500    18.0%
5      EMB64_NG20_CS2_BS64.pth                   89    500    17.8%
6      EMB64_NG10_CS2_BS128.pth                  82    500    16.4%
7      EMB256_NG10_CS2_BS64.pth                  80    500    16.0%
8      EMB64_NG10_CS3_BS64.pth                   80    500    16.0%
9      EMB128_NG10_CS2_BS64.pth                  78    500    15.6%
10     EMB64_NG10_CS4_BS64.pth                   78    500    15.6%
11     EMB64_NG10_CS2_BS32.pth                   71    500    14.2%
12     EMB64_NG5_CS2_BS64.pth                    67    500    13.

In [23]:
# ============================================================================
# SAVE RESULTS TO CSV
# ============================================================================

df = pd.DataFrame(results_sorted)
df['rank'] = range(1, len(df) + 1)
df = df[['rank', 'model', 'score', 'max_possible', 'percentage', 'embedding_dim']]

csv_path = 'superclass_coherence_scores.csv'
df.to_csv(csv_path, index=False)
print(f"‚úÖ Results saved to {csv_path}")
print("\n")
print(df.to_string(index=False))


‚úÖ Results saved to superclass_coherence_scores.csv


 rank                    model  score  max_possible  percentage  embedding_dim
    1  EMB32_NG10_CS2_BS64.pth     96           500        19.2             32
    2  EMB64_NG20_CS3_BS32.pth     93           500        18.6             64
    3  EMB64_NG15_CS2_BS64.pth     91           500        18.2             64
    4  EMB64_NG10_CS2_BS64.pth     90           500        18.0             64
    5  EMB64_NG20_CS2_BS64.pth     89           500        17.8             64
    6 EMB64_NG10_CS2_BS128.pth     82           500        16.4             64
    7 EMB256_NG10_CS2_BS64.pth     80           500        16.0            256
    8  EMB64_NG10_CS3_BS64.pth     80           500        16.0             64
    9 EMB128_NG10_CS2_BS64.pth     78           500        15.6            128
   10  EMB64_NG10_CS4_BS64.pth     78           500        15.6             64
   11  EMB64_NG10_CS2_BS32.pth     71           500        14.2             

In [27]:
# NEAREST NEIGHBORS FOR BEST MODEL - CIFAR100 words
# ============================================================================
# Show 5 nearest neighbors for ALL 100 CIFAR-100 words
# ============================================================================

# Use the best model from rankings (or change this to analyze a different model)
MODEL_TO_ANALYZE = winner['model']  # Best model from previous cell

print(f"Loading model: {MODEL_TO_ANALYZE}")
checkpoint = torch.load(MODEL_TO_ANALYZE)
nodes = checkpoint['nodes']
embedding_dim = checkpoint['embedding_dim']
vocab_size = checkpoint['vocab_size']

model = SkipGramModel(vocab_size, embedding_dim)
model.load_state_dict(checkpoint['model_state_dict'])
embeddings = model.get_embeddings()
node_set = set(nodes)

print("=" * 100)
print(f"NEAREST NEIGHBORS FOR ALL CIFAR-100 WORDS ({MODEL_TO_ANALYZE})")
print("=" * 100)

# Sort CIFAR words alphabetically
cifar_words_sorted = sorted(cifar_words)

for word in cifar_words_sorted:
    if word in node_set:
        neighbors = find_similar_words(word, nodes, embeddings, top_k=5)
        neighbor_str = ", ".join([f"{w} ({s:.3f})" for w, s in neighbors])
        print(f"{word:<20} ‚Üí {neighbor_str}")
    else:
        print(f"{word:<20} ‚Üí ‚ùå NOT IN VOCABULARY")

print(f"\n‚úÖ Displayed neighbors for {len([w for w in cifar_words_sorted if w in node_set])}/100 CIFAR-100 words")

Loading model: EMB32_NG10_CS2_BS64.pth
NEAREST NEIGHBORS FOR ALL CIFAR-100 WORDS (EMB32_NG10_CS2_BS64.pth)
apple                ‚Üí open (0.927), field (0.905), vase (0.903), slice (0.900), pot (0.900)
aquarium_fish        ‚Üí trout (0.941), flatfish (0.920), spider (0.888), snake (0.870), skis (0.863)
baby                 ‚Üí eyes (0.960), little (0.945), colored (0.942), head (0.942), young (0.938)
bear                 ‚Üí clouds (0.876), elephant (0.846), mouse (0.844), very (0.841), dirt (0.841)
beaver               ‚Üí skunk (0.954), possum (0.951), crocodile (0.945), raccoon (0.943), otter (0.941)
bed                  ‚Üí front (0.903), holding (0.884), right (0.879), full (0.875), for (0.873)
bee                  ‚Üí rabbit (0.918), caterpillar (0.914), lion (0.908), spider (0.908), frisbee (0.903)
beetle               ‚Üí caterpillar (0.936), cockroach (0.929), worm (0.924), spider (0.909), snail (0.898)
bicycle              ‚Üí spoon (0.957), writing (0.944), giraffes (0.937),

In [25]:
# NEAREST NEIGHBORS FOR BEST MODEL - REMAINING VG WORDS (non-CIFAR)
# ============================================================================
# Show 5 nearest neighbors for all ~455 Visual Genome words NOT in CIFAR-100
# ============================================================================

# Get all words that are NOT in CIFAR-100
vg_only_words = sorted([w for w in nodes if w not in cifar_words])

print("=" * 100)
print(f"NEAREST NEIGHBORS FOR REMAINING {len(vg_only_words)} VG WORDS ({MODEL_TO_ANALYZE})")
print("=" * 100)

for word in vg_only_words:
    neighbors = find_similar_words(word, nodes, embeddings, top_k=5)
    neighbor_str = ", ".join([f"{w} ({s:.3f})" for w, s in neighbors])
    print(f"{word:<20} ‚Üí {neighbor_str}")

print(f"\n‚úÖ Displayed neighbors for {len(vg_only_words)} Visual Genome words (non-CIFAR)")

NEAREST NEIGHBORS FOR REMAINING 423 VG WORDS (EMB32_NG10_CS2_BS64.pth)
a                    ‚Üí is (0.986), in (0.978), of (0.978), the (0.968), and (0.967)
above                ‚Üí by (0.911), behind (0.909), over (0.902), scene (0.901), around (0.900)
against              ‚Üí big (0.983), from (0.983), through (0.981), at (0.976), has (0.973)
air                  ‚Üí rock (0.895), distance (0.894), house (0.888), view (0.873), door (0.866)
airplane             ‚Üí wire (0.895), tray (0.892), tennis (0.892), middle (0.883), his (0.872)
along                ‚Üí waves (0.906), ocean (0.901), boat (0.892), beach (0.886), pair (0.881)
an                   ‚Üí sitting (0.962), brown (0.950), black (0.946), has (0.936), from (0.936)
and                  ‚Üí is (0.973), in (0.972), big (0.967), a (0.967), to (0.966)
animal               ‚Üí ears (0.933), coat (0.923), ear (0.916), fur (0.898), walking (0.891)
are                  ‚Üí in (0.968), brown (0.968), around (0.967), has (0.964), an