# Gauge Frame Semantic Analysis

This notebook tests whether the learned gauge frames φ encode semantic relationships between tokens.

**Hypothesis**: Semantically related tokens have similar gauge frames (φ_i ≈ φ_j), so the transport operator between them is close to the identity (Ω_ij ≈ I).

In [None]:
# Configuration - SET THESE PATHS
# Both values are plain strings that the loading cells wrap in pathlib.Path,
# so a relative path is resolved against the current working directory.
# JSON file describing the experiment (parsed with json.load).
EXPERIMENT_CONFIG_PATH = "runs/your_experiment/experiment_config.json"  # <-- CHANGE THIS
# PyTorch checkpoint to analyse (read with torch.load).
CHECKPOINT_PATH = "runs/your_experiment/best_model.pt"                   # <-- CHANGE THIS

In [None]:
import torch
import numpy as np
import json
from pathlib import Path
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, str(Path.cwd()))

# tiktoken is an optional dependency. `tokenizer` stays None when it is not
# installed; get_embeddings and the PCA cells check for that before using it.
tokenizer = None
try:
    import tiktoken
    tokenizer = tiktoken.get_encoding("gpt2")
except ImportError:
    print("Install tiktoken: pip install tiktoken")
else:
    print("Loaded GPT-2 tokenizer")

In [None]:
# Load experiment config
# The config is optional: when the file is missing, `config` falls back to an
# empty dict so the remaining cells can still run.
config_path = Path(EXPERIMENT_CONFIG_PATH)
# is_file() (rather than exists()) so that a directory at this path is reported
# as "not found" instead of failing inside open().
if config_path.is_file():
    # JSON is UTF-8 by specification; do not rely on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    print(f"Loaded config: {config_path}")
    print(f"  embed_dim: {config.get('embed_dim', 'N/A')}")
    print(f"  lambda_beta: {config.get('lambda_beta', 'N/A')}")
else:
    print(f"Config not found: {config_path}")
    config = {}

In [None]:
# Load model checkpoint
# Outcome of this cell: `mu_embed` and `phi_embed` are each either an embedding
# weight tensor taken from the checkpoint or None.
checkpoint_path = Path(CHECKPOINT_PATH)
if not checkpoint_path.exists():
    print(f"Checkpoint not found: {checkpoint_path}")
    mu_embed = None
    phi_embed = None
else:
    # NOTE: torch.load unpickles the file - only open checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    print(f"Loaded checkpoint: {checkpoint_path}")

    # Get state dict: prefer a nested 'model_state_dict', then 'state_dict',
    # otherwise treat the checkpoint itself as the state dict.
    state_dict = next(
        (checkpoint[k] for k in ('model_state_dict', 'state_dict') if k in checkpoint),
        checkpoint,
    )

    # Find embeddings by substring match on the parameter name; when several
    # names match, the last one encountered is kept.
    mu_embed = None
    phi_embed = None
    for param_name, tensor in state_dict.items():
        if 'weight' not in param_name:
            continue
        if 'mu_embed' in param_name:
            mu_embed = tensor
            print(f"  Found mu_embed: {tensor.shape}")
        if 'phi_embed' in param_name:
            phi_embed = tensor
            print(f"  Found phi_embed: {tensor.shape}")

    if mu_embed is None:
        print("  WARNING: No mu_embed found!")
    if phi_embed is None:
        print("  WARNING: No phi_embed found (model may use fixed gauge frames)")

## Test 1: Semantic Distance Comparison

Compare φ distances for related vs. unrelated pairs, first at the token level (letters, digits, punctuation) and then for word pairs.

In [None]:
def get_embeddings(word):
    """Look up the mu and phi embedding rows for ``word``.

    The word is encoded with the notebook-level ``tokenizer`` and only its
    first token id is used, so a word that splits into several BPE tokens is
    represented by its first piece.

    Returns:
        A ``(mu, phi)`` tuple. Both entries are ``None`` when no tokenizer is
        loaded or the word encodes to no tokens; an individual entry is
        ``None`` when its embedding table is missing or the token id is not
        smaller than the table length.
    """
    def _row(table, idx):
        # One row of `table`, or None when the table is absent / too short.
        if table is None or idx >= len(table):
            return None
        return table[idx]

    if tokenizer is None:
        return None, None

    ids = tokenizer.encode(word)
    if not ids:
        return None, None

    first_id = ids[0]
    return _row(mu_embed, first_id), _row(phi_embed, first_id)

def distance(a, b):
    """Return the Euclidean (L2) distance between two tensors as a float.

    Yields ``nan`` when either argument is ``None`` so that callers can
    filter out pairs with a missing embedding.
    """
    if a is not None and b is not None:
        return (a - b).norm().item()
    return float('nan')

In [None]:
# IMPORTANT: BPE tokens != full words!
# Let's first explore what tokens we actually have
#
# This cell produces:
#   single_tokens  - {word: token_id} for test words that encode to one token
#   letter_tokens / digit_tokens / punct_tokens - lists of (token_id, char)
#     for single-character tokens among the first 256 token ids
import string

print("=" * 60)
print("EXPLORING BPE VOCABULARY")
print("=" * 60)

# Check which words are single tokens
test_words = ["cat", "dog", "the", "and", "run", "big", "man", "day",
              "kitten", "airplane", "democracy", "happy", "running"]

# Bound up front so that later cells find these names even when the tokenizer
# could not be loaded (the import cell sets tokenizer = None in that case).
single_tokens = {}
letter_tokens = []
digit_tokens = []
punct_tokens = []

if tokenizer is None:
    print("\nTokenizer not available - skipping vocabulary exploration")
else:
    print("\nSingle-token words:")
    for word in test_words:
        tokens = tokenizer.encode(word)
        if len(tokens) == 1:
            single_tokens[word] = tokens[0]
            print(f"  '{word}' -> token {tokens[0]}")
        else:
            decoded = [tokenizer.decode([t]) for t in tokens]
            print(f"  '{word}' -> MULTI-TOKEN: {tokens} = {decoded}")

    # For BPE, better to test token-level relationships:
    # - Single letters (a, b, c) vs digits (0, 1, 2) vs punctuation (!, ?, .)
    # - Common subwords that appear together in similar contexts

    print("\n" + "=" * 60)
    print("TOKEN-LEVEL ANALYSIS (more reliable for BPE)")
    print("=" * 60)

    # These are actual single tokens in GPT-2
    # Letters (lowercase single chars are tokens 64-89 approximately)
    # Digits are tokens 15-24
    # Punctuation varies

    for tid in range(256):  # First 256 are mostly single chars
        try:
            s = tokenizer.decode([tid])
        except Exception:
            # Best effort: skip any id the tokenizer cannot decode.
            continue
        if len(s) != 1:
            continue
        if s.isalpha():
            letter_tokens.append((tid, s))
        elif s.isdigit():
            digit_tokens.append((tid, s))
        elif s in string.punctuation:
            # Explicit ASCII punctuation only. The previous "not alnum and not
            # space" test also accepted control characters and the U+FFFD
            # replacement character that lone non-UTF-8 byte tokens decode to,
            # none of which are punctuation.
            punct_tokens.append((tid, s))

    print(f"Found {len(letter_tokens)} letter tokens, {len(digit_tokens)} digit tokens, {len(punct_tokens)} punct tokens")

In [None]:
# TOKEN-LEVEL DISTANCE ANALYSIS
# Compare: letters vs letters, letters vs digits, letters vs punctuation
import itertools
import random

print("=" * 60)
print("INTRA-CLASS vs INTER-CLASS DISTANCES")
print("=" * 60)

def get_embed_by_id(tid):
    """Return ``(mu, phi)`` embedding rows for token id ``tid``.

    An entry is ``None`` when its embedding table is missing or ``tid`` is not
    smaller than the table length.
    """
    mu = mu_embed[tid] if mu_embed is not None and tid < len(mu_embed) else None
    phi = phi_embed[tid] if phi_embed is not None and tid < len(phi_embed) else None
    return mu, phi

# Sample the same number of tokens from each class. n_samples never exceeds any
# class size, so sample() cannot fail; a dedicated seeded generator keeps the
# reported numbers reproducible without touching the global random state.
SAMPLE_SEED = 0
_rng = random.Random(SAMPLE_SEED)
n_samples = min(10, len(letter_tokens), len(digit_tokens), len(punct_tokens))
sample_letters = _rng.sample(letter_tokens, n_samples)
sample_digits = _rng.sample(digit_tokens, n_samples)
sample_punct = _rng.sample(punct_tokens, n_samples)

# Look every embedding up once instead of once per pair.
letter_embeds = [get_embed_by_id(tid) for tid, _ in sample_letters]
other_embeds = [get_embed_by_id(tid) for tid, _ in sample_digits + sample_punct]

# Compute intra-class distances (each unordered letter-letter pair once)
intra_mu = []
intra_phi = []
for (mu1, phi1), (mu2, phi2) in itertools.combinations(letter_embeds, 2):
    if mu1 is not None and mu2 is not None:
        intra_mu.append(distance(mu1, mu2))
    if phi1 is not None and phi2 is not None:
        intra_phi.append(distance(phi1, phi2))

# Compute inter-class distances (letter-digit, letter-punct)
inter_mu = []
inter_phi = []
for (mu1, phi1), (mu2, phi2) in itertools.product(letter_embeds, other_embeds):
    if mu1 is not None and mu2 is not None:
        inter_mu.append(distance(mu1, mu2))
    if phi1 is not None and phi2 is not None:
        inter_phi.append(distance(phi1, phi2))

# Statistics are only printed when both lists are non-empty; np.mean([]) would
# otherwise emit warnings and print nan.
if intra_mu and inter_mu:
    print(f"\nmu embeddings:")
    print(f"  Intra-class (letter-letter): {np.mean(intra_mu):.4f} +/- {np.std(intra_mu):.4f}")
    print(f"  Inter-class (letter-digit/punct): {np.mean(inter_mu):.4f} +/- {np.std(inter_mu):.4f}")
    print(f"  Ratio (inter/intra): {np.mean(inter_mu)/np.mean(intra_mu):.2f}x")
else:
    print("\nmu embeddings: no distances available (mu_embed missing or too few sampled tokens)")

if intra_phi and inter_phi:
    print(f"\nphi embeddings (gauge frames):")
    print(f"  Intra-class (letter-letter): {np.mean(intra_phi):.4f} +/- {np.std(intra_phi):.4f}")
    print(f"  Inter-class (letter-digit/punct): {np.mean(inter_phi):.4f} +/- {np.std(inter_phi):.4f}")
    print(f"  Ratio (inter/intra): {np.mean(inter_phi)/np.mean(intra_phi):.2f}x")

    # Class structure is declared when inter-class distances exceed
    # intra-class distances by more than 20%.
    if np.mean(inter_phi) > np.mean(intra_phi) * 1.2:
        print(f"\n  RESULT: phi shows class structure!")
    else:
        print(f"\n  RESULT: phi does NOT show clear class structure.")
else:
    print("\nphi embeddings (gauge frames): no distances available (phi_embed missing or too few sampled tokens)")

In [None]:
# Compute distances
# This cell needs `related_pairs` and `unrelated_pairs` (lists of (word, word)
# tuples). If no earlier cell defined them, fall back to small example lists so
# the cell does not stop with a NameError. Replace them with your own pairs.
# Keep in mind that get_embeddings() only uses the FIRST BPE token of a word.
if "related_pairs" not in globals():
    related_pairs = [
        ("cat", "dog"),
        ("cat", "kitten"),
        ("run", "running"),
        ("the", "and"),
    ]
if "unrelated_pairs" not in globals():
    unrelated_pairs = [
        ("cat", "democracy"),
        ("dog", "airplane"),
        ("happy", "the"),
        ("big", "day"),
    ]


def _pair_distances(pairs):
    """Print and collect the mu/phi distances for every word pair.

    Returns:
        ``(mu_dists, phi_dists)`` - two lists of floats. A pair whose distance
        is nan (an embedding was unavailable) is still printed but is left out
        of the corresponding list.
    """
    mu_dists = []
    phi_dists = []
    for w1, w2 in pairs:
        mu1, phi1 = get_embeddings(w1)
        mu2, phi2 = get_embeddings(w2)

        mu_d = distance(mu1, mu2)
        phi_d = distance(phi1, phi2)

        if not np.isnan(mu_d):
            mu_dists.append(mu_d)
        if not np.isnan(phi_d):
            phi_dists.append(phi_d)

        print(f"{w1:12} - {w2:12}: mu_dist={mu_d:.4f}, phi_dist={phi_d:.4f}")
    return mu_dists, phi_dists


print("=" * 60)
print("RELATED PAIRS")
print("=" * 60)

related_mu_dists, related_phi_dists = _pair_distances(related_pairs)

print("\n" + "=" * 60)
print("UNRELATED PAIRS")
print("=" * 60)

unrelated_mu_dists, unrelated_phi_dists = _pair_distances(unrelated_pairs)

In [None]:
# Summary
# Reports mean related vs. unrelated distances and their ratio
# (unrelated / related) for mu and phi. A section is skipped when either of its
# distance lists is empty.
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)

if related_mu_dists and unrelated_mu_dists:
    mu_ratio = np.mean(unrelated_mu_dists) / np.mean(related_mu_dists)
    print(
        f"\nmu embeddings:",
        f"  Related mean:   {np.mean(related_mu_dists):.4f}",
        f"  Unrelated mean: {np.mean(unrelated_mu_dists):.4f}",
        f"  Ratio:          {mu_ratio:.2f}x",
        sep="\n",
    )

if related_phi_dists and unrelated_phi_dists:
    phi_ratio = np.mean(unrelated_phi_dists) / np.mean(related_phi_dists)
    print(
        f"\nphi embeddings (gauge frames):",
        f"  Related mean:   {np.mean(related_phi_dists):.4f}",
        f"  Unrelated mean: {np.mean(unrelated_phi_dists):.4f}",
        f"  Ratio:          {phi_ratio:.2f}x",
        sep="\n",
    )

    # Verdict: unrelated pairs must be more than 1.3x as far apart as related ones.
    print(
        f"\n  RESULT: phi DOES encode semantic relationships!"
        if phi_ratio > 1.3
        else f"\n  RESULT: phi does NOT clearly encode semantics."
    )

## Test 2: Clustering Visualization

In [None]:
from sklearn.decomposition import PCA

# Plot colours and token categories are bound at cell level (not inside the phi
# branch) because the mu PCA cell reuses both names, also for models without
# phi embeddings.
colors = {'letters': 'blue', 'digits': 'red', 'punct': 'green', 'space': 'orange'}
categories = {'letters': [], 'digits': [], 'punct': [], 'space': []}

# Categorize first 500 tokens (fewer when the embedding table is smaller).
_ref_table = phi_embed if phi_embed is not None else mu_embed
if tokenizer is not None and _ref_table is not None:
    for tid in range(min(500, len(_ref_table))):
        try:
            s = tokenizer.decode([tid])
        except Exception:
            # Best effort: an id the tokenizer cannot decode stays uncategorized.
            continue
        stripped = s.strip()
        # Byte tokens that are not valid UTF-8 on their own come back as the
        # U+FFFD replacement character, and control bytes as unprintable
        # characters. Neither is punctuation, so leave them out of the plot
        # instead of colouring them as 'punct'.
        if '\ufffd' in s or not stripped.isprintable():
            continue
        if stripped.isalpha():
            categories['letters'].append(tid)
        elif stripped.isdigit():
            categories['digits'].append(tid)
        elif stripped and not stripped.isalnum():
            categories['punct'].append(tid)
        elif s.startswith(' '):
            categories['space'].append(tid)

if phi_embed is not None and tokenizer is not None:
    # PCA on phi (needs a 2-D table with at least two columns)
    if phi_embed.ndim >= 2 and phi_embed.shape[1] >= 2:
        # detach()/cpu() make the conversion safe for tensors that track
        # gradients or live on another device; NumPy has no bfloat16 and half
        # precision brings nothing here, so those dtypes are widened first.
        phi_rows = phi_embed[:500].detach().cpu()
        if phi_rows.dtype in (torch.float16, torch.bfloat16):
            phi_rows = phi_rows.float()

        pca = PCA(n_components=2)
        phi_2d = pca.fit_transform(phi_rows.numpy())

        plt.figure(figsize=(10, 8))

        for cat, ids in categories.items():
            if ids:
                plt.scatter(phi_2d[ids, 0], phi_2d[ids, 1],
                           c=colors[cat], label=f"{cat} (n={len(ids)})", alpha=0.6)

        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title("PCA of phi embeddings (gauge frames)")
        plt.legend()
        plt.show()
    else:
        print("phi_embed has < 2 dimensions, skipping PCA plot")
else:
    print("No phi embeddings or tokenizer available")

In [None]:
# Same for mu embeddings
if mu_embed is not None and tokenizer is not None:
    # Reuse the token categories / colours of the phi PCA cell when they
    # exist. That cell may not have defined them (e.g. the model has no phi
    # embeddings), so build them here in that case instead of failing with a
    # NameError.
    if "categories" not in globals() or "colors" not in globals():
        colors = {'letters': 'blue', 'digits': 'red', 'punct': 'green', 'space': 'orange'}
        categories = {'letters': [], 'digits': [], 'punct': [], 'space': []}
        for tid in range(min(500, len(mu_embed))):
            try:
                s = tokenizer.decode([tid])
            except Exception:
                # Best effort: an id the tokenizer cannot decode stays uncategorized.
                continue
            stripped = s.strip()
            # Skip undecodable byte tokens (U+FFFD) and control characters;
            # they are not punctuation.
            if '\ufffd' in s or not stripped.isprintable():
                continue
            if stripped.isalpha():
                categories['letters'].append(tid)
            elif stripped.isdigit():
                categories['digits'].append(tid)
            elif stripped and not stripped.isalnum():
                categories['punct'].append(tid)
            elif s.startswith(' '):
                categories['space'].append(tid)

    if mu_embed.ndim >= 2 and mu_embed.shape[1] >= 2:
        # detach()/cpu() make the conversion safe for tensors that track
        # gradients or live on another device; NumPy has no bfloat16, so
        # reduced-precision dtypes are widened first.
        mu_rows = mu_embed[:500].detach().cpu()
        if mu_rows.dtype in (torch.float16, torch.bfloat16):
            mu_rows = mu_rows.float()

        pca = PCA(n_components=2)
        mu_2d = pca.fit_transform(mu_rows.numpy())

        plt.figure(figsize=(10, 8))

        for cat, ids in categories.items():
            # Categories may have been built against another table; keep only
            # ids that exist among the projected mu rows.
            ids = [i for i in ids if i < len(mu_2d)]
            if ids:
                plt.scatter(mu_2d[ids, 0], mu_2d[ids, 1],
                           c=colors[cat], label=f"{cat} (n={len(ids)})", alpha=0.6)

        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title("PCA of mu embeddings (belief means)")
        plt.legend()
        plt.show()
    else:
        print("mu_embed has < 2 dimensions, skipping PCA plot")
else:
    print("No mu embeddings or tokenizer available")

## Test 3: Transport Ablation

What happens to model quality if every transport operator Ω_ij is replaced by the identity?

In [None]:
# This would require modifying the model and re-running inference.
# To test, set use_identity_transport=True in attention config.
# Nothing is computed here; the cell only prints the manual procedure.
print(
    "To test transport ablation:",
    "1. Set use_identity_transport=True in model config",
    "2. Re-run evaluation on validation set",
    "3. Compare val PPL with vs without transport",
    "",
    "If val PPL is WORSE without transport -> transport carries information",
    "If val PPL is SAME without transport -> transport is not essential",
    sep="\n",
)