## Word2GM Visualization and Analysis

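Create visualizations of the trained Word2GM embeddings, including t-SNE plots of the mixture components, plus interactive analysis tools. Run this section after the training and evaluation cells below: it uses `model`, `config`, `words`, `word_to_id`, and `id_to_word`, which are defined there.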

In [None]:
# Optional: Advanced Visualization (requires sklearn)
# NOTE: this cell uses `model`, `config`, `words` and `id_to_word`, which are
# defined by the training/evaluation cells further down; run those first.


def _project_2d(embeddings, perplexity=30, random_state=42):
    """Project row vectors to 2-D: PCA (at most 50 dims) followed by t-SNE.

    PCA's ``n_components`` may not exceed ``min(n_samples, n_features)`` and
    t-SNE requires ``perplexity < n_samples``, so both are clamped to the
    number of rows actually available instead of failing on small inputs.

    Args:
        embeddings: 2-D array, one row per word.
        perplexity: Upper bound on the t-SNE perplexity.
        random_state: Seed for t-SNE.

    Returns:
        Array of shape ``(n_samples, 2)``.
    """
    n_samples, n_features = embeddings.shape
    reduced = PCA(n_components=min(50, n_samples, n_features)).fit_transform(embeddings)
    tsne = TSNE(n_components=2, random_state=random_state,
                perplexity=min(perplexity, n_samples - 1))
    return tsne.fit_transform(reduced)


try:
    from sklearn.manifold import TSNE
    from sklearn.decomposition import PCA
except ImportError:
    # Only the sklearn imports are guarded, so unrelated ImportErrors are not
    # misreported as "scikit-learn missing".
    print("Scikit-learn not available. Skipping t-SNE visualization.")
    print("To enable visualization, install scikit-learn: pip install scikit-learn")
else:
    print("Creating t-SNE visualization of Word2GM embeddings...")

    # Select a subset of words for visualization: the first 1000 vocabulary IDs
    # (presumably the most frequent words if IDs are frequency-ordered — verify).
    # Only IDs that map to a word are kept, so labels and rows stay aligned.
    viz_size = min(1000, len(words))
    viz_word_ids = [i for i in range(viz_size) if i in id_to_word]
    viz_words = [id_to_word[i] for i in viz_word_ids]

    # One representative embedding per word from model.get_word_embedding
    # (described by the plot title as the mixture-weighted mean).
    viz_embeddings = np.array([model.get_word_embedding(i) for i in viz_word_ids])

    if len(viz_embeddings) > 10:  # Only proceed if we have enough words
        print(f"Computing t-SNE for {len(viz_embeddings)} words...")
        embeddings_2d = _project_2d(viz_embeddings)

        # Plot t-SNE, colored by position in the ID order
        plt.figure(figsize=(14, 10))
        scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                              alpha=0.6, s=20, c=range(len(embeddings_2d)), cmap='viridis')

        # Annotate only the first 50 words to keep the plot readable
        for idx, label in enumerate(viz_words[:50]):
            plt.annotate(label, (embeddings_2d[idx, 0], embeddings_2d[idx, 1]),
                         fontsize=8, alpha=0.7)

        plt.title('t-SNE Visualization of Word2GM Embeddings\n(Mixture-weighted means)')
        plt.xlabel('t-SNE 1')
        plt.ylabel('t-SNE 2')
        plt.colorbar(scatter)
        plt.tight_layout()
        plt.show()

        # If we have multiple mixture components, visualize them separately
        if config.num_mixtures > 1:
            print("Creating component-specific visualizations...")

            comp_word_ids = viz_word_ids[:200]  # Use fewer words for component viz
            # num_mixtures > 1 here, so subplots() returns an array of Axes
            _, axes = plt.subplots(1, config.num_mixtures,
                                   figsize=(6 * config.num_mixtures, 6))

            for comp, ax in enumerate(axes):
                comp_embeddings = np.array(
                    [model.get_word_embedding(i, component=comp) for i in comp_word_ids]
                )
                if len(comp_embeddings) <= 10:
                    continue

                # Fresh PCA + t-SNE per component (not the PCA fitted above)
                comp_2d = _project_2d(comp_embeddings)
                ax.scatter(comp_2d[:, 0], comp_2d[:, 1],
                           alpha=0.6, s=30, c=range(len(comp_2d)), cmap='plasma')
                ax.set_title(f'Component {comp} Embeddings')
                ax.set_xlabel('t-SNE 1')
                ax.set_ylabel('t-SNE 2')

            plt.tight_layout()
            plt.show()

    print("✓ Visualization complete")

# Interactive word exploration function
def explore_word(word):
    """Print a word's mixture parameters and its nearest neighbors.

    For each mixture component this shows the component weight, the norm of
    its mean, a variance summary and its top-5 neighbors; afterwards the
    word's overall top-10 neighbors are listed.

    Relies on the notebook globals ``model``, ``config``, ``tf``,
    ``word_to_id``, ``id_to_word`` and ``find_nearest_neighbors`` (defined in
    the evaluation section), so run those cells first.

    Args:
        word: Vocabulary word to inspect. If it is unknown, up to 10 vocabulary
            words sharing its first three characters are suggested instead.
    """
    from itertools import islice

    if word not in word_to_id:
        print(f"Word '{word}' not found in vocabulary.")
        prefix = word[:3]
        # islice stops the scan once 10 matches are found
        available = list(islice((w for w in word_to_id if w.startswith(prefix)), 10))
        if available:
            print(f"Similar words available: {', '.join(available)}")
        return

    word_id = word_to_id[word]
    print(f"\nExploring word: '{word}' (ID: {word_id})")
    print("=" * 40)

    # Mixture parameters for a batch of one word; [0] drops the batch axis.
    # Named `variances` so the builtin vars() is not shadowed.
    mus, variances, weights = model.get_word_distributions(tf.constant([word_id]))
    mus, variances, weights = mus[0], variances[0], weights[0]

    print(f"Mixture weights: {weights.numpy()}")
    print(f"Number of components: {config.num_mixtures}")

    # Component analysis
    for comp in range(config.num_mixtures):
        print(f"\nComponent {comp} (weight: {weights[comp]:.3f}):")
        print(f"  Mean norm: {tf.norm(mus[comp]):.4f}")
        if config.spherical:
            # Spherical covariance: report the first variance entry
            print(f"  Variance: {variances[comp, 0]:.4f}")
        else:
            print(f"  Mean variance: {tf.reduce_mean(variances[comp]):.4f}")

        # Find neighbors for this component
        neighbors = find_nearest_neighbors(model, word, word_to_id, id_to_word, k=5, component=comp)
        if neighbors:
            print("  Nearest neighbors:")
            for rank, (neighbor, score) in enumerate(neighbors, start=1):
                print(f"    {rank}. {neighbor} ({score:.4f})")

    # Overall neighbors
    print("\nOverall nearest neighbors:")
    neighbors = find_nearest_neighbors(model, word, word_to_id, id_to_word, k=10)
    if not neighbors:
        print("  (none found)")
    for rank, (neighbor, score) in enumerate(neighbors, start=1):
        print(f"  {rank:2d}. {neighbor} ({score:.4f})")

# Examples of interactive exploration (run these in separate cells if desired)
usage_text = "\n".join([
    "\nInteractive Word Exploration",
    "=" * 30,
    "You can explore any word using the explore_word() function.",
    "Example usage:",
    "  explore_word('bank')",
    "  explore_word('rock')",
    "  explore_word('spring')",
])
print(usage_text)

# Demo with the first common word that is present in the vocabulary (if any)
demo_words = ['the', 'and', 'of', 'to', 'a', 'in', 'that', 'is', 'was', 'he']
demo_word = next((candidate for candidate in demo_words if candidate in word_to_id), None)

if demo_word is not None:
    print(f"\nDemo exploration of '{demo_word}':")
    explore_word(demo_word)

# Word2GM Training & Evaluation

**GPU-friendly TensorFlow port of Word2GM (Word to Gaussian Mixture) embeddings**

This notebook demonstrates training and evaluation of the Word2GM model — a neural embedding approach that represents each word as a Gaussian Mixture Model instead of a single point vector.

## Background

Word2GM is based on the paper ["Multimodal Word Distributions"](https://arxiv.org/abs/1704.08424) by Athiwaratkun and Wilson (ACL 2017). The key innovation is representing words as **Gaussian mixture distributions** rather than point vectors, enabling:

- **Multimodal representations**: Words like "bank" can have separate components for financial and geographical meanings
- **Uncertainty modeling**: Capture confidence and variability in word meanings
- **Richer semantic relationships**: Better capture entailment, similarity, and polysemy

## Architecture Overview

Each word `w` is represented as a Gaussian Mixture Model with `K` components:
- **Means (μ)**: `K × d` dimensional centers
- **Covariances (Σ)**: diagonal (`K × d`) or spherical (a single variance per component) covariances
- **Mixture weights (π)**: `K` dimensional probability weights

**Training**: Max-margin objective using Expected Likelihood Kernel similarity between word distributions.

## Pipeline Workflow

1. **Load Training Data**: TFRecord triplets from data preparation pipeline
2. **Model Training**: Word2GM with configurable mixture components
3. **Evaluation**: Nearest neighbors, word similarity, polysemy analysis
4. **Visualization**: t-SNE plots of mixture components

## Environment Setup and GPU Configuration

Configure the environment for optimal GPU usage during Word2GM training.

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import warnings
from pathlib import Path

# Setup project path (override with the WORD2GM_PROJECT_ROOT environment variable)
project_root = Path(os.environ.get('WORD2GM_PROJECT_ROOT', '/scratch/edk202/word2gm-fast'))
os.chdir(project_root)
src_path = project_root / 'src'
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

# Configure TensorFlow for GPU usage with memory growth
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TF logging

# Undo only an explicit CPU-only constraint ("" or "-1" hides all GPUs).
# A real device list, e.g. one assigned by a cluster scheduler, is kept so
# this process does not see GPUs allocated to other jobs. This must run
# before TensorFlow initializes CUDA.
cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES')
if cuda_visible is not None and cuda_visible.strip() in ('', '-1'):
    del os.environ['CUDA_VISIBLE_DEVICES']

print("Environment configured for GPU training")

## Import Required Libraries and Modules

In [None]:
# Import TensorFlow with GPU memory growth enabled
from word2gm_fast.utils.tf_silence import import_tensorflow_silently
tf = import_tensorflow_silently(gpu_memory_growth=True)

# Ensure memory growth is enabled on every visible GPU. The helper above was
# already asked to do this; TensorFlow raises RuntimeError if the GPUs have
# been initialized in the meantime, so that case is reported, not fatal.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    for device in physical_devices:
        try:
            tf.config.experimental.set_memory_growth(device, True)
        except RuntimeError as err:
            print(f"Memory growth unchanged for {device.name}: {err}")
    print(f"Configured memory growth for {len(physical_devices)} GPU(s)")

# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import json
from typing import Dict, List, Tuple, Optional

# Word2GM modules
from word2gm_fast.models.word2gm_model import Word2GMModel
from word2gm_fast.models.config import Word2GMConfig
from word2gm_fast.dataprep.tfrecord_io import load_triplets_from_tfrecord, load_vocab_from_tfrecord
from word2gm_fast.training.training_utils import train_step, log_training_metrics, summarize_dataset_pipeline

# Report the TensorFlow build in use and confirm the imports above succeeded
print(f"TensorFlow version: {tf.__version__}", "All modules imported successfully", sep="\n")

## Verify GPU Availability

Check for available GPUs and print device information to ensure GPU resources are accessible for training.

In [None]:
# Check GPU availability and run a quick GPU-vs-CPU matmul timing


def _time_matmul(device, x, y):
    """Return wall-clock seconds for one ``x @ y`` on ``device``.

    Inputs are first copied to ``device`` so the timing excludes transfers.
    A warm-up run absorbs one-off setup cost. Eager ops may run
    asynchronously, so each result is reduced to a scalar and fetched with
    ``.numpy()`` to wait for completion before the clock is read.
    """
    with tf.device(device):
        x, y = tf.identity(x), tf.identity(y)
        tf.reduce_sum(tf.matmul(x, y)).numpy()  # warm-up
        start = time.perf_counter()
        tf.reduce_sum(tf.matmul(x, y)).numpy()
        return time.perf_counter() - start


print("GPU Device Information:")
print("=" * 50)

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"Found {len(gpus)} GPU(s):")
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu.name}")
        # Device details are best-effort; not every platform provides them.
        try:
            gpu_details = tf.config.experimental.get_device_details(gpu)
        except Exception:
            gpu_details = {}
        if 'device_name' in gpu_details:
            print(f"    Device: {gpu_details['device_name']}")
    print()

    # Test GPU computation
    print("Testing GPU computation...")
    a = tf.random.normal([1000, 1000])
    b = tf.random.normal([1000, 1000])
    gpu_time = _time_matmul('/GPU:0', a, b)
    print(f"  GPU matrix multiply (1000x1000): {gpu_time:.4f}s")

    # Compare with CPU
    cpu_time = _time_matmul('/CPU:0', a, b)
    print(f"  CPU matrix multiply (1000x1000): {cpu_time:.4f}s")
    if gpu_time > 0:
        print(f"  GPU speedup: {cpu_time / gpu_time:.1f}x")
else:
    print("No GPUs found. Running on CPU.")

print("=" * 50)

## Load Training Data

Load TFRecord artifacts generated by the data preparation pipeline for Word2GM training.

In [None]:
# Configuration - Update these paths to match your processed data
corpus_dir = "/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data"

# Choose a year with processed artifacts for training (using 1700 as it exists)
year = "1700"  # We have artifacts for this year
artifacts_dir = f"{corpus_dir}/{year}_artifacts"

print(f"Loading training data from: {artifacts_dir}")

triplets_path = f"{artifacts_dir}/triplets.tfrecord"
vocab_path = f"{artifacts_dir}/vocab.tfrecord"

# Verify files exist. Stop here if they do not: every later cell needs
# `vocab_table`, `vocab_size` and `dataset`, and would otherwise fail with a
# much less helpful NameError.
missing_paths = [p for p in (triplets_path, vocab_path) if not os.path.exists(p)]
if missing_paths:
    print("❌ TFRecord files not found!")
    print("Please run the data preparation pipeline first.")
    print("Expected files:")
    print(f"  {triplets_path}")
    print(f"  {vocab_path}")
    raise FileNotFoundError(f"Missing TFRecord artifact(s): {', '.join(missing_paths)}")

print("✓ TFRecord files found")

# Load vocabulary
print("Loading vocabulary...")
vocab_table = load_vocab_from_tfrecord(vocab_path)
vocab_size = int(vocab_table.size())
print(f"  Vocabulary size: {vocab_size:,} words")

# Load training triplets
print("Loading training triplets...")
dataset = load_triplets_from_tfrecord(triplets_path)

# Inspect dataset structure
print("Dataset pipeline structure:")
summarize_dataset_pipeline(dataset)

# Take one small batch to verify the data format; each element is unpacked
# as (word IDs, positive IDs, negative IDs).
sample_batch = next(iter(dataset.batch(5)))
word_ids, pos_ids, neg_ids = sample_batch
print("\nSample batch shapes:")
print(f"  Word IDs: {word_ids.shape}")
print(f"  Positive IDs: {pos_ids.shape}")
print(f"  Negative IDs: {neg_ids.shape}")
print(f"  Sample values: {word_ids[:3].numpy()}")

## Create Word2GM Model

Configure and initialize the Word2GM model with Gaussian mixture components.

In [None]:
# Model configuration (matching original Word2GM paper settings)
config = Word2GMConfig(
    vocab_size=vocab_size,
    embedding_size=50,        # Embedding dimension
    num_mixtures=2,           # Number of Gaussian components per word
    spherical=True,           # Use spherical (not diagonal) covariances
    learning_rate=0.05,       # Initial learning rate
    batch_size=128,           # Training batch size
    epochs_to_train=5,        # Number of training epochs (reduced for demo)
    adagrad=True,             # Use Adagrad optimizer (False: Nesterov SGD below)
    var_scale=0.05,           # Variance scale for initialization
    normclip=True,            # Enable gradient/parameter clipping
    norm_cap=5.0,             # Norm clipping threshold
    lower_sig=0.05,           # Lower bound for variances
    upper_sig=1.0,            # Upper bound for variances
    wout=False                # Separate output embeddings disabled (True enables them)
)

print("Model Configuration:")
print("=" * 40)
print(f"Vocabulary size: {config.vocab_size:,}")
print(f"Embedding size: {config.embedding_size}")
print(f"Mixture components: {config.num_mixtures}")
print(f"Covariance type: {'Spherical' if config.spherical else 'Diagonal'}")
print(f"Learning rate: {config.learning_rate}")
print(f"Batch size: {config.batch_size}")
print(f"Training epochs: {config.epochs_to_train}")
print()

# Create model
model = Word2GMModel(config)

# Create optimizer
if config.adagrad:
    optimizer = tf.keras.optimizers.Adagrad(learning_rate=config.learning_rate)
else:
    optimizer = tf.keras.optimizers.SGD(learning_rate=config.learning_rate, momentum=0.9, nesterov=True)

print(f"Model created with {config.num_mixtures} mixture components per word")
# NOTE(review): Keras count_params() raises ValueError on an unbuilt model;
# this assumes Word2GMModel creates its variables in __init__ — confirm.
print(f"Total parameters: {model.count_params():,}")
print(f"Optimizer: {'Adagrad' if config.adagrad else 'SGD'}")

# Print model summary for first few words
print("\nModel structure (first 3 words):")
sample_word_ids = tf.constant([0, 1, 2])
# `variances` (not `vars`) so the builtin vars() is not shadowed notebook-wide
mus, variances, weights = model.get_word_distributions(sample_word_ids)
print(f"  Means shape: {mus.shape}")
print(f"  Variances shape: {variances.shape}")
print(f"  Weights shape: {weights.shape}")

## Train Word2GM Model

Train the model using GPU-accelerated operations with the max-margin objective.

In [None]:
# Prepare dataset for training
batch_size = config.batch_size
train_dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Training metrics: one average loss per epoch
training_losses = []
start_time = time.time()

print("Starting Word2GM training...")
print("=" * 50)

for epoch in range(config.epochs_to_train):
    epoch_start = time.time()
    epoch_loss = 0.0
    num_batches = 0

    print(f"Epoch {epoch + 1}/{config.epochs_to_train}")

    for batch_idx, (word_ids, pos_ids, neg_ids) in enumerate(train_dataset):
        # One optimization step; the returned gradients are not needed here
        loss, _ = train_step(
            model, optimizer, word_ids, pos_ids, neg_ids,
            normclip=config.normclip,
            norm_cap=config.norm_cap,
            lower_sig=config.lower_sig,
            upper_sig=config.upper_sig,
            wout=config.wout
        )

        epoch_loss += loss
        num_batches += 1

        # Print progress every 100 batches (skipping the very first batch)
        if batch_idx > 0 and batch_idx % 100 == 0:
            avg_loss = epoch_loss / num_batches
            print(f"  Batch {batch_idx}: loss = {loss:.6f}, avg = {avg_loss:.6f}")

    # Epoch summary; max() guards against an empty dataset
    avg_loss = epoch_loss / max(1, num_batches)
    epoch_time = time.time() - epoch_start
    training_losses.append(float(avg_loss))

    print(f"  Epoch {epoch + 1} complete:")
    print(f"    Average loss: {avg_loss:.6f}")
    print(f"    Time: {epoch_time:.1f}s")
    print(f"    Batches processed: {num_batches}")

    # Log model statistics.
    # NOTE(review): exp(logsigmas) is reported as σ here but as "Variances" in
    # the evaluation cell; confirm which quantity the model stores.
    mean_mu_norm = tf.reduce_mean(tf.norm(model.mus, axis=-1))
    mean_sigma = tf.reduce_mean(tf.exp(model.logsigmas))
    print(f"    Mean μ norm: {mean_mu_norm:.4f}")
    print(f"    Mean σ: {mean_sigma:.4f}")
    print()

total_time = time.time() - start_time
print(f"Training complete! Total time: {total_time:.1f}s")
if training_losses:  # empty when epochs_to_train == 0
    print(f"Final loss: {training_losses[-1]:.6f}")

# Plot training loss
plt.figure(figsize=(10, 6))
plt.plot(training_losses, 'b-', linewidth=2)
plt.title('Word2GM Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.show()

# Save the trained model.
# NOTE(review): Keras 3 save_weights() requires a filename ending in
# ".weights.h5"; this extension-less path assumes TF2/Keras 2 checkpoint
# format — confirm against the installed Keras version.
model_save_path = f"{artifacts_dir}/word2gm_model"
print(f"Saving trained model to: {model_save_path}")
model.save_weights(model_save_path)
print("✓ Model saved successfully")

## Evaluate Trained Model

Analyze the trained Word2GM model by examining word representations and finding nearest neighbors.

In [None]:
# Build word <-> ID lookups from the vocabulary table
print("Model Evaluation")
print("=" * 40)

# Export the table once and convert each whole tensor to NumPy, instead of
# calling .numpy() per element (one eager op per vocabulary entry).
vocab_keys, vocab_values = vocab_table.export()
words = [key.decode('utf-8') for key in vocab_keys.numpy()]
word_ids = [int(value) for value in vocab_values.numpy()]

# Create word-to-id and id-to-word mappings
word_to_id = dict(zip(words, word_ids))
id_to_word = dict(zip(word_ids, words))

print(f"Vocabulary loaded: {len(words):,} words")

# Analyze mixture components for sample words
def analyze_word_mixtures(model, word_ids, id_to_word_map, num_words=10):
    """Print mixture weights and a short preview of each component for some words.

    Args:
        model: Model exposing ``get_word_distributions(ids)``, unpacked here
            as ``(means, variances, weights)``.
        word_ids: Sequence of vocabulary IDs; only the first ``num_words``
            are analyzed.
        id_to_word_map: Mapping from ID to word, used for display; unknown
            IDs print as ``<UNK_id>``.
        num_words: Maximum number of words to analyze.

    Uses the notebook globals ``tf`` and ``config`` (``num_mixtures``,
    ``spherical``).
    """
    word_ids = list(word_ids)[:num_words]
    if not word_ids:
        # tf.constant([]) would be float32 and cannot be used as an ID batch
        print("\nMixture Analysis: no word IDs given.")
        return

    # `variances` (not `vars`) to avoid shadowing the builtin
    mus, variances, weights = model.get_word_distributions(tf.constant(word_ids))

    print(f"\nMixture Analysis for {len(word_ids)} words:")
    for i, word_id in enumerate(word_ids):
        word = id_to_word_map.get(word_id, f"<UNK_{word_id}>")
        print(f"\nWord: '{word}' (ID: {word_id})")
        print(f"  Mixture weights: {weights[i].numpy()}")
        print("  Component means (first 5 dims):")
        for k in range(config.num_mixtures):
            mean_preview = mus[i, k, :5].numpy()
            # Spherical: show the first variance entry; diagonal: first 5 dims
            if config.spherical:
                var_preview = variances[i, k, 0].numpy()
            else:
                var_preview = variances[i, k, :5].numpy()
            print(f"    Component {k}: μ={mean_preview} σ²={var_preview}")

# Function to find nearest neighbors
def find_nearest_neighbors(model, query_word, word_to_id_map, id_to_word_map, k=10, component=None):
    """Return up to ``k`` nearest neighbors of ``query_word`` as ``(word, score)`` pairs.

    The search itself is delegated to
    ``model.get_nearest_neighbors(query_id, k=k, component=component)``, whose
    ``(id, score)`` pairs are translated back to words; IDs missing from
    ``id_to_word_map`` are shown as ``<UNK_id>``.

    Returns an empty list, after printing a message, when the word is not in
    the vocabulary or when the model call raises.
    """
    known = query_word in word_to_id_map
    if not known:
        print(f"Word '{query_word}' not found in vocabulary")
        return []

    query_id = word_to_id_map[query_word]
    try:
        raw_pairs = model.get_nearest_neighbors(query_id, k=k, component=component)
        return [
            (id_to_word_map.get(nid, f"<UNK_{nid}>"), similarity)
            for nid, similarity in raw_pairs
        ]
    except Exception as err:
        print(f"Error finding neighbors: {err}")
        return []

# Preview the mixture parameters of the first (up to) five vocabulary IDs
sample_word_ids = list(range(min(5, len(words))))
analyze_word_mixtures(model, sample_word_ids, id_to_word)

# Candidate polysemous words; only those present in the vocabulary are analyzed
example_words = ['bank', 'rock', 'spring', 'light', 'star', 'plant', 'left', 'right']
existing_examples = [candidate for candidate in example_words if candidate in word_to_id]

if existing_examples:
    print("\nNearest Neighbor Analysis for Example Words:")
    print("=" * 50)

    for word in existing_examples[:3]:  # at most three words
        print(f"\nWord: '{word}'")
        print("-" * 20)

        # Neighbors of the full mixture
        overall = find_nearest_neighbors(model, word, word_to_id, id_to_word, k=10)
        if overall:
            print("Overall nearest neighbors:")
            for rank, (neighbor, score) in enumerate(overall, 1):
                print(f"  {rank:2d}. {neighbor} ({score:.4f})")

        # Per-component neighbors only make sense with several components
        if config.num_mixtures <= 1:
            continue
        for comp in range(config.num_mixtures):
            per_comp = find_nearest_neighbors(model, word, word_to_id, id_to_word, k=5, component=comp)
            if not per_comp:
                continue
            print(f"Component {comp} neighbors:")
            for rank, (neighbor, score) in enumerate(per_comp, 1):
                print(f"  {rank}. {neighbor} ({score:.4f})")

# Summary statistics of the learned parameters
print("\nModel Parameter Statistics:")
print("=" * 30)


def _print_summary(title, values, with_std=False):
    """Print min/max/mean (and optionally std) of a tensor, 4 decimals each."""
    print(title)
    print(f"  Min: {tf.reduce_min(values):.4f}")
    print(f"  Max: {tf.reduce_max(values):.4f}")
    print(f"  Mean: {tf.reduce_mean(values):.4f}")
    if with_std:
        print(f"  Std: {tf.math.reduce_std(values):.4f}")


# Norm of each component mean (presumably [vocab_size, num_mixtures])
mu_norms = tf.norm(model.mus, axis=-1)
_print_summary("Mean norms:", mu_norms, with_std=True)

# NOTE(review): exp(logsigmas) is labeled "Variances" here but "Mean σ" in the
# training loop; confirm which quantity the model stores.
sigmas = tf.exp(model.logsigmas)
_print_summary("Variances:", sigmas)

# Softmax over the last axis turns the mixture parameters into weights
mixture_probs = tf.nn.softmax(model.mixture, axis=-1)
_print_summary("Mixture weights:", mixture_probs)

# 2x2 dashboard: three parameter histograms plus the per-epoch training loss
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

histogram_panels = [
    (axes[0, 0], mu_norms, 'blue', 'Distribution of Mean Norms', 'Norm'),
    (axes[0, 1], sigmas, 'green', 'Distribution of Variances', 'Variance'),
    (axes[1, 0], mixture_probs, 'red', 'Distribution of Mixture Weights', 'Weight'),
]
for panel, values, color, title, xlabel in histogram_panels:
    panel.hist(values.numpy().flatten(), bins=50, alpha=0.7, color=color)
    panel.set_title(title)
    panel.set_xlabel(xlabel)
    panel.set_ylabel('Frequency')

loss_panel = axes[1, 1]
loss_panel.plot(training_losses, 'b-', linewidth=2)
loss_panel.set_title('Training Loss')
loss_panel.set_xlabel('Epoch')
loss_panel.set_ylabel('Loss')
loss_panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n" + "=" * 50)
print("Training and evaluation complete!")
print(f"Model saved to: {model_save_path}")
print("You can now use the trained Word2GM model for downstream tasks.")