# Word2GM Training Notebook (Clean)

This notebook provides a streamlined interface for training Word2GM models with pre-processed corpus data.

## Contents:
1. **Setup**: GPU configuration and imports
2. **Data Loading**: Load pre-processed artifacts and setup training data
3. **Training Configuration**: Multiple configuration options from conservative to aggressive
4. **Model Training**: Execute training with selected configuration
5. **Analysis**: TensorBoard visualization and nearest neighbors exploration

In [2]:
import os

# Opt into the asynchronous CUDA allocator (optional; may reduce GPU memory
# fragmentation). It is set before TensorFlow is imported so the setting is
# already present when TensorFlow initializes its GPU runtime.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

import tensorflow as tf

# Let TensorFlow grow GPU memory on demand rather than reserving it up front.
# With no GPUs the loop is a no-op; a failure is reported instead of raised.
gpus = tf.config.list_physical_devices('GPU')
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except Exception as e:
    print(f"Could not set memory growth: {e}")

2025-07-06 13:06:48.869708: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-06 13:06:50.080054: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751821610.349422  445730 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751821610.406418  445730 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1751821610.898982  445730 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [3]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

# Resolve the project root and put its `src` directory first on sys.path, so the
# `word2gm_fast` imports below resolve to this checkout.
# The default is the original cluster path; set the WORD2GM_PROJECT_ROOT
# environment variable to run from a different checkout without editing code.
PROJECT_ROOT = os.environ.get('WORD2GM_PROJECT_ROOT', '/scratch/edk202/word2gm-fast')
project_root = Path(PROJECT_ROOT)
src_path = project_root / 'src'

# Only insert once, so re-running this cell does not keep growing sys.path.
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

import numpy as np
import tensorflow as tf

from word2gm_fast.models.word2gm_model import Word2GMModel
from word2gm_fast.models.config import Word2GMConfig
from word2gm_fast.training.notebook_training import run_notebook_training
from word2gm_fast.io.artifacts import load_pipeline_artifacts
from word2gm_fast.utils.resource_summary import print_resource_summary

In [None]:
# Print a summary of the compute resources visible to this session, using the
# project helper word2gm_fast.utils.resource_summary.print_resource_summary.
print_resource_summary()

In [5]:
# Input: preprocessed pipeline artifacts for the 1850 slice of the corpus.
# Output: training artifacts and TensorBoard logs go under output_dir.
dataset_artifacts_dir = (
    '/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/'
    '1850_artifacts'
)
output_dir = '/scratch/edk202/word2gm-fast/output/test_corpus'
Path(output_dir).mkdir(parents=True, exist_ok=True)
tensorboard_log_dir = f'{output_dir}/tensorboard'

# Unpack the vocabulary lookup tables and vocabulary size from the artifacts.
artifacts = load_pipeline_artifacts(dataset_artifacts_dir)
token_to_index_table = artifacts['token_to_index_table']
index_to_token_table = artifacts['index_to_token_table']
vocab_size = artifacts['vocab_size']

# Input pipeline over the raw triplets: cache -> shuffle (window of 10
# batches) -> batch -> prefetch, so the next batch is prepared while the
# current one is consumed.
BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 10
triplets_ds = (
    artifacts['triplets_ds']
    .cache()
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

print(f'Loaded vocab_size: {vocab_size}')

<pre>Loading pipeline artifacts from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts</pre>

<pre>Loading token-to-index vocabulary TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts/vocab.tfrecord</pre>

2025-07-06 13:07:34.982362: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:387] The default buffer size is 262144, which is overridden by the user specified `buffer_size` of 134217728
2025-07-06 13:07:35.246804: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-07-06 13:07:35.246804: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


<pre>Loading index-to-token vocab TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts/vocab.tfrecord</pre>

2025-07-06 13:07:35.522844: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


<pre>Loading triplet TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts/triplets.tfrecord</pre>

<pre>Triplet TFRecord loaded and parsed</pre>

<pre>All artifacts loaded successfully!</pre>

Loaded vocab_size: 33668


In [None]:
# Example: exercise both vocabulary lookup tables.
# 'king' is looked up token -> index; 16702 is looked up index -> token.
test_token = 'king'
test_index = 16702

index_result = token_to_index_table.lookup(tf.constant([test_token])).numpy()[0]
print(f"Index for token '{test_token}':", index_result)

# The index -> token table yields bytes, so decode before printing.
token_bytes = index_to_token_table.lookup(tf.constant([test_index], dtype=tf.int64)).numpy()[0]
token_result = token_bytes.decode('utf-8')
print(f"Token for index {test_index}:", token_result)

In [None]:
# Print a random sample of up to 50 triplets from a single batch of the current
# corpus, showing both the raw indices and the tokens they map to.
import random


def idx_to_token(idx):
    """Return the token string for a single vocabulary index.

    Looks ``idx`` up in ``index_to_token_table`` (which yields bytes) and
    decodes the result as UTF-8.
    """
    idx_tensor = tf.constant([idx], dtype=tf.int64)
    return index_to_token_table.lookup(idx_tensor).numpy()[0].decode('utf-8')


# Pull exactly one batch; None means the dataset produced nothing at all.
batch = next(iter(triplets_ds.take(1)), None)
if batch is None:
    # Fail loudly here instead of leaving `triplets_batch` undefined, which
    # would otherwise surface below as a confusing NameError.
    raise ValueError('triplets_ds yielded no batches; cannot sample triplets.')

if isinstance(batch, tuple) and len(batch) == 3:
    # Tuple-structured dataset: three parallel tensors -> list of rows.
    anchor_ids, pos_ids, neg_ids = (t.numpy() for t in batch)
    triplets_batch = list(zip(anchor_ids, pos_ids, neg_ids))
else:
    # Single tensor, presumably shaped (batch_size, 3) — TODO confirm.
    triplets_batch = batch.numpy()

sample_size = min(50, len(triplets_batch))
sampled_indices = random.sample(range(len(triplets_batch)), sample_size)
sampled_triplets = [triplets_batch[i] for i in sampled_indices]

print(f"Random sample of {sample_size} triplets from a single batch:")
print("Idx: (anchor, pos, neg)\tTokens: (anchor, pos, neg)")
for row_num, (anchor, pos, neg) in enumerate(sampled_triplets, start=1):
    anchor_token = idx_to_token(anchor)
    pos_token = idx_to_token(pos)
    neg_token = idx_to_token(neg)
    print(f"{row_num:2d}: ({anchor}, {pos}, {neg})\t({anchor_token}, {pos_token}, {neg_token})")

In [None]:
# Quick single training run with a hand-picked configuration: embedding size
# 200 combined with the same norm_cap / sigma bounds / var_scale values used by
# the aggressive preset defined further down.
# NOTE(review): this run saves to the same output_dir as the preset-driven
# training cell below, which presumably overwrites these weights.
config = Word2GMConfig(
    vocab_size=vocab_size,
    embedding_size=200,
    num_mixtures=1,
    spherical=True,
    loss_epsilon=1e-8,
    wout=True,
    max_pe=False,
    # Looser bounds than the conservative/balanced presets.
    norm_cap=20.0,
    lower_sig=0.01,
    upper_sig=2.0,
    var_scale=0.1,
)

# NOTE(review): config.max_pe is not forwarded here; presumably
# run_notebook_training has no such parameter — confirm.
run_notebook_training(
    training_dataset=triplets_ds,
    save_path=output_dir,
    tensorboard_log_path=tensorboard_log_dir,
    # Model shape, taken from the config above.
    vocab_size=config.vocab_size,
    embedding_size=config.embedding_size,
    num_mixtures=config.num_mixtures,
    spherical=config.spherical,
    wout=config.wout,
    # Optimisation schedule (fixed for this quick run).
    learning_rate=1.0,
    epochs=30,
    adagrad=True,
    normclip=True,
    # Bounds and loss settings, also from the config.
    norm_cap=config.norm_cap,
    lower_sig=config.lower_sig,
    upper_sig=config.upper_sig,
    var_scale=config.var_scale,
    loss_epsilon=config.loss_epsilon,
    # Monitoring.
    monitor_interval=0.5,
    profile=False,
)

## Recommended Starting Training Configurations

Here are several good starting parameter configurations for Word2GM training, ordered from conservative to aggressive:

### **Configuration 1: Conservative/Stable**
- Good for initial experiments and ensuring convergence
- Lower learning rates and tighter regularization
- Suitable for small to medium datasets

### **Configuration 2: Balanced**  
- Good middle ground for most use cases
- Moderate regularization and learning rates
- Recommended starting point for most experiments

### **Configuration 3: Aggressive**
- Higher learning rates and looser constraints
- Good for large datasets with lots of training data
- May require more careful monitoring

### **Configuration 4: Large Scale**
- Intended for very large vocabularies (100K+ words)
- Higher embedding dimensions and adjusted regularization
- Suitable for full Google Books datasets

In [None]:
# Preset Word2GM configurations, ordered from conservative to aggressive.
# Every preset uses num_mixtures=1, spherical=True, wout=True, max_pe=False and
# loss_epsilon=1e-8. Only capacity and the clipping/variance knobs differ, so
# the shared settings live in a single helper and the presets cannot drift apart.

def _make_preset_config(embedding_size, norm_cap, lower_sig, upper_sig, var_scale):
    """Build a Word2GMConfig for the loaded vocabulary using the shared preset settings.

    Args:
        embedding_size: Embedding dimensionality.
        norm_cap: Clipping threshold forwarded to Word2GMConfig.
        lower_sig: Lower bound for the sigma values.
        upper_sig: Upper bound for the sigma values.
        var_scale: Variance-scale setting forwarded to Word2GMConfig.

    Returns:
        A Word2GMConfig sized to the global ``vocab_size``.

    Raises:
        ValueError: if the bounds do not satisfy 0 < lower_sig < upper_sig.
    """
    # Catch swapped or non-positive bounds before they reach training.
    if not 0 < lower_sig < upper_sig:
        raise ValueError(
            f"Expected 0 < lower_sig < upper_sig, got lower_sig={lower_sig}, upper_sig={upper_sig}"
        )
    return Word2GMConfig(
        vocab_size=vocab_size,
        embedding_size=embedding_size,
        num_mixtures=1,
        spherical=True,
        norm_cap=norm_cap,
        lower_sig=lower_sig,
        upper_sig=upper_sig,
        var_scale=var_scale,
        loss_epsilon=1e-8,
        wout=True,
        max_pe=False,
    )


# Configuration 1: Conservative/Stable — smaller embeddings, tight clipping,
# narrow sigma bounds, strong regularization.
conservative_config = _make_preset_config(
    embedding_size=128, norm_cap=5.0, lower_sig=0.1, upper_sig=1.0, var_scale=0.01,
)

# Configuration 2: Balanced (Recommended Starting Point) — moderate on every axis.
balanced_config = _make_preset_config(
    embedding_size=200, norm_cap=10.0, lower_sig=0.05, upper_sig=1.5, var_scale=0.05,
)

# Configuration 3: Aggressive — higher capacity, looser clipping, wider bounds,
# light regularization.
aggressive_config = _make_preset_config(
    embedding_size=300, norm_cap=20.0, lower_sig=0.01, upper_sig=2.0, var_scale=0.1,
)

# Configuration 4: Large Scale — high capacity, very loose clipping, very wide
# bounds, minimal regularization.
large_scale_config = _make_preset_config(
    embedding_size=512, norm_cap=50.0, lower_sig=0.001, upper_sig=5.0, var_scale=0.2,
)

# Presets by name, using the same names accepted by get_training_params.
CONFIG_PRESETS = {
    "conservative": conservative_config,
    "balanced": balanced_config,
    "aggressive": aggressive_config,
    "large_scale": large_scale_config,
}

# Choose which configuration to use. Keep this in sync with `config_name` in the
# training-parameters cell, which selects the matching learning rate / epochs.
config = CONFIG_PRESETS["balanced"]
print(f"Using configuration: embedding_size={config.embedding_size}, norm_cap={config.norm_cap}")
print(f"Variance bounds: [{config.lower_sig}, {config.upper_sig}], var_scale={config.var_scale}")

In [None]:
# Training Parameter Recommendations

def get_training_params(config_name, dataset_size="medium"):
    """
    Get recommended training parameters based on configuration and dataset size.

    Args:
        config_name: "conservative", "balanced", "aggressive", or "large_scale"
        dataset_size: "small", "medium", "large", or "very_large"

    Returns:
        A fresh dict with keys "learning_rate", "epochs", "adagrad", "normclip"
        and "monitor_interval". "medium" returns the base preset unchanged; the
        other sizes rescale epochs (within fixed floors/caps) and learning rate.

    Raises:
        KeyError: if config_name is not one of the presets above.
        ValueError: if dataset_size is not one of the recognised sizes.
    """

    base_params = {
        "conservative": {
            "learning_rate": 0.5,
            "epochs": 15,
            "adagrad": True,
            "normclip": True,
            "monitor_interval": 1.0,
        },
        "balanced": {
            "learning_rate": 1.0,
            "epochs": 30,
            "adagrad": True,
            "normclip": True,
            "monitor_interval": 0.5,
        },
        "aggressive": {
            "learning_rate": 1.5,
            "epochs": 50,
            "adagrad": True,
            "normclip": True,
            "monitor_interval": 0.25,
        },
        "large_scale": {
            "learning_rate": 2.0,
            "epochs": 100,
            "adagrad": True,
            "normclip": True,
            "monitor_interval": 0.1,
        }
    }
    valid_sizes = ("small", "medium", "large", "very_large")

    if config_name not in base_params:
        raise KeyError(
            f"Unknown config_name {config_name!r}; expected one of {sorted(base_params)}"
        )
    # Previously an unrecognised size (e.g. a typo like "Large") silently fell
    # through to the "medium" behaviour.
    if dataset_size not in valid_sizes:
        raise ValueError(
            f"Unknown dataset_size {dataset_size!r}; expected one of {list(valid_sizes)}"
        )

    # Copy so callers can modify the result without touching the presets.
    params = dict(base_params[config_name])

    # Adjust for dataset size ("medium" keeps the base preset as-is).
    if dataset_size == "small":
        # Halve the schedule, but never below 10 epochs.
        params["epochs"] = max(10, params["epochs"] // 2)
        params["learning_rate"] *= 0.8
    elif dataset_size == "large":
        params["epochs"] = min(100, params["epochs"] * 2)
        params["learning_rate"] *= 1.2
    elif dataset_size == "very_large":
        params["epochs"] = min(200, params["epochs"] * 3)
        params["learning_rate"] *= 1.5

    # Strip float noise from the multipliers (1.5 * 1.2 would otherwise be
    # reported as 1.7999999999999998).
    params["learning_rate"] = round(params["learning_rate"], 6)
    return params

# Pick the preset and dataset scale, then show the resulting training parameters.
config_name = "balanced"  # Change to match your chosen config
dataset_size = "medium"  # Change to "small", "medium", "large", or "very_large"

training_params = get_training_params(config_name, dataset_size)
print(f"Recommended training parameters for {config_name} config with {dataset_size} dataset:")
print("\n".join(f"  {name}: {setting}" for name, setting in training_params.items()))

## Parameter Explanations and Tips

### **Key Parameters to Tune:**

**Model Architecture:**
- `embedding_size`: Start with 128-200 for small datasets, 300-512 for large datasets
- `num_mixtures`: Keep at 1 initially (Gaussian mixtures add complexity)
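- `spherical`: Keep True for simplicity. A spherical (isotropic) covariance has one variance per mixture component, σ²I, which has fewer parameters than a diagonal covariance.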

**Regularization (Critical for Word2GM):**
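- `norm_cap`: Gradient clipping threshold (5.0 conservative, 10.0 balanced, 20.0+ aggressive). *Review note:* in the reference Word2GM implementation, `norm_cap` (applied when `normclip` is on) caps the norm of the mean vectors rather than the gradients; confirm how this codebase uses it.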
- `lower_sig`: Minimum variance (0.1 conservative, 0.05 balanced, 0.01 aggressive)
- `upper_sig`: Maximum variance (1.0 conservative, 1.5 balanced, 2.0+ aggressive)
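- `var_scale`: Regularization strength, where *smaller* values regularize more strongly (0.01 strong, 0.05 moderate, 0.1+ light). *Review note:* in the reference Word2GM implementation this parameter appears to set the scale of the initial variances rather than a penalty weight; confirm against `Word2GMConfig`.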

**Training:**
- `learning_rate`: Start with 1.0 (Adagrad will adapt)
- `epochs`: 15-30 for initial experiments, 50-100 for final training
- `adagrad`: Always use True (essential for Word2GM)
- `normclip`: Always use True (prevents exploding gradients)

### **What to Monitor:**
- **Loss**: Should decrease steadily (watch for plateaus)
- **Gradient norms**: Should stay below norm_cap
- **Variance values**: Should stay within [lower_sig, upper_sig] bounds
- **Training speed**: ~1-2 minutes per epoch for medium datasets

### **Common Issues:**
- **Loss not decreasing**: Increase learning_rate or loosen regularization (on the `var_scale` scale above, that means *increasing* var_scale)
- **Training unstable**: Decrease norm_cap or increase regularization
- **Underfitting**: Increase embedding_size or decrease regularization
- **Overfitting**: Strengthen regularization (on the `var_scale` scale above, that means *decreasing* var_scale) or decrease embedding_size

In [None]:
# Run Training with Selected Configuration
# NOTE(review): `config` comes from the presets cell, while `config_name`,
# `dataset_size` and `training_params` come from the training-parameters cell.
# The banner only reports what those cells currently hold, so keep them in sync.
rule = "=" * 60

print("\n".join([
    rule,
    "TRAINING WORD2GM MODEL",
    rule,
    f"Configuration: {config_name}",
    f"Dataset size: {dataset_size}",
    f"Vocab size: {vocab_size:,}",
    f"Embedding size: {config.embedding_size}",
    f"Training epochs: {training_params['epochs']}",
    f"Learning rate: {training_params['learning_rate']}",
    rule,
]))

# Model shape and bounds come from `config`; the optimisation schedule and
# monitoring cadence come from `training_params`.
run_notebook_training(
    training_dataset=triplets_ds,
    save_path=output_dir,
    tensorboard_log_path=tensorboard_log_dir,
    vocab_size=config.vocab_size,
    embedding_size=config.embedding_size,
    num_mixtures=config.num_mixtures,
    spherical=config.spherical,
    wout=config.wout,
    norm_cap=config.norm_cap,
    lower_sig=config.lower_sig,
    upper_sig=config.upper_sig,
    var_scale=config.var_scale,
    loss_epsilon=config.loss_epsilon,
    learning_rate=training_params['learning_rate'],
    epochs=training_params['epochs'],
    adagrad=training_params['adagrad'],
    normclip=training_params['normclip'],
    monitor_interval=training_params['monitor_interval'],
    profile=False,
)

final_weights_name = f"model_weights_epoch{training_params['epochs']}.weights.h5"
print("\n".join([
    rule,
    "TRAINING COMPLETED!",
    rule,
    f"Model weights saved to: {output_dir}",
    f"TensorBoard logs: {tensorboard_log_dir}",
    f"Final epoch weights: {final_weights_name}",
    rule,
]))

In [None]:
# Load the TensorBoard notebook extension and serve the training logs inline.
# `$tensorboard_log_dir` is expanded by IPython from the Python variable of that name.
%load_ext tensorboard
%tensorboard --logdir $tensorboard_log_dir --port 6006

In [None]:
# Find nearest neighbors for a given word using Word2GMModel.
# Rebuilds the model from `config`, loads the weights written at the end of the
# preset-driven training run, and prints the k closest vocabulary entries.
model = Word2GMModel(config)

# Build the model by calling it once on a dummy (word, pos, neg) input of
# batch size 1, so its variables exist before load_weights restores them.
dummy_input = (
    tf.zeros([1], dtype=tf.int32),  # word_ids
    tf.zeros([1], dtype=tf.int32),  # pos_ids
    tf.zeros([1], dtype=tf.int32),  # neg_ids
)
model(dummy_input)

# Load the final-epoch weights. The epoch number follows training_params rather
# than a hard-coded 30, so changing the preset or dataset size does not load a
# stale (or missing) file.
weights_path = os.path.join(
    output_dir, f"model_weights_epoch{training_params['epochs']}.weights.h5"
)
if not os.path.exists(weights_path):
    raise FileNotFoundError(
        f"Expected trained weights at {weights_path}; run the training cell first."
    )
model.load_weights(weights_path)

# Materialise the vocabulary as a Python list (position == vocabulary index).
vocab_indices = tf.range(vocab_size, dtype=tf.int64)
vocab_tokens = index_to_token_table.lookup(vocab_indices).numpy()
vocab_list = [token.decode('utf-8') if isinstance(token, bytes) else str(token) for token in vocab_tokens]

# Choose a query word and get its index
query_word = 'good'  # Change this to any word in your vocab
try:
    query_idx = vocab_list.index(query_word)
except ValueError:
    raise ValueError(f'Word "{query_word}" not found in vocab_list.') from None

# get_nearest_neighbors may return either an (indices, distances) pair or a
# sequence of (index, distance) pairs; normalise both shapes to a list of pairs.
# Only the unpacking is guarded, so genuine lookup errors are not swallowed.
result = model.get_nearest_neighbors(query_idx, k=10)
try:
    neighbor_indices, neighbor_distances = result
except (TypeError, ValueError):
    # Not unpackable into exactly two parts -> already a sequence of pairs.
    neighbor_pairs = list(result)
else:
    neighbor_pairs = list(zip(neighbor_indices, neighbor_distances))
neighbors = [(vocab_list[int(i)], float(d)) for i, d in neighbor_pairs]

print(f'Nearest neighbors for "{query_word}":')
for word, dist in neighbors:
    print(f'{word}\t{dist:.4f}')