# Word2GM Training Notebook (Clean)

This notebook provides a streamlined interface for training Word2GM models with pre-processed corpus data.

## Contents:
1. **Setup**: GPU configuration and imports
2. **Data Loading**: Load pre-processed artifacts and setup training data
3. **Training Configuration**: Multiple configuration options from conservative to aggressive
4. **Model Training**: Execute training with selected configuration
5. **Analysis**: TensorBoard visualization and nearest neighbors exploration

In [2]:
import os
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # Optional; may reduce GPU memory fragmentation. Set before TensorFlow initializes the GPU.

import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as e:
        print(f"Could not set memory growth: {e}")

2025-07-06 19:28:37.006808: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-06 19:28:37.021832: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751844517.038945  252001 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751844517.044076  252001 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1751844517.057416  252001 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [3]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

# Set project root directory and add `src` to path
PROJECT_ROOT = '/scratch/edk202/word2gm-fast'
project_root = Path(PROJECT_ROOT)
src_path = project_root / 'src'

if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

import numpy as np

from word2gm_fast.models.word2gm_model import Word2GMModel
from word2gm_fast.models.config import Word2GMConfig
from word2gm_fast.training.notebook_training import run_notebook_training
from word2gm_fast.io.artifacts import load_pipeline_artifacts
from word2gm_fast.utils.resource_summary import print_resource_summary

In [4]:
print_resource_summary()

<pre>SYSTEM RESOURCE SUMMARY
============================================================
Hostname: gv008.hpc.nyu.edu

Job Allocation:
   CPUs: 14
   Memory: 125.0 GB
   Requested partitions: v100,rtx8000,a100_2,a100_1,h100_1
   Running on: SSH failed: Host key verification failed.
   Job ID: 63438478
   Node list: gv008

GPU Information:
   CUDA GPUs detected: 1
   GPU 0: Tesla V100-PCIE-32GB
      Memory: 0.3/32.0 GB (31.7 GB free)
      Temperature: 28°C
      Utilization: GPU 0%, Memory 0%

TensorFlow GPU Detection:
   TensorFlow detects 1 GPU(s)
      /physical_device:GPU:0, Memory growth: True
   Built with CUDA: True
============================================================</pre>

In [5]:
# Define paths for your corpus artifacts and output
dataset_artifacts_dir = (
    '/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/'
    '1850_artifacts'
)
output_dir = '/scratch/edk202/word2gm-fast/output/test_corpus'
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Set TensorBoard log directory
tensorboard_log_dir = output_dir + '/tensorboard'

# Load pipeline artifacts (vocab, triplets, etc.)
artifacts = load_pipeline_artifacts(dataset_artifacts_dir)
token_to_index_table = artifacts['token_to_index_table']
index_to_token_table = artifacts['index_to_token_table']
triplets_ds = artifacts['triplets_ds']
vocab_size = artifacts['vocab_size']

# Build the dataset pipeline: cache -> shuffle -> batch -> prefetch
triplets_ds = triplets_ds.cache()
BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 10
triplets_ds = triplets_ds.shuffle(SHUFFLE_BUFFER_SIZE)
triplets_ds = triplets_ds.batch(BATCH_SIZE)
triplets_ds = triplets_ds.prefetch(tf.data.AUTOTUNE)

print(f'Loaded vocab_size: {vocab_size}')

<pre>Loading pipeline artifacts from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts</pre>

<pre>Loading token-to-index vocabulary TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts/vocab.tfrecord</pre>

I0000 00:00:1751844521.569451  252001 gpu_process_state.cc:208] Using CUDA malloc Async allocator for GPU: 0
I0000 00:00:1751844521.569777  252001 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31141 MB memory:  -> device: 0, name: Tesla V100-PCIE-32GB, pci bus id: 0000:2f:00.0, compute capability: 7.0
2025-07-06 19:28:41.695491: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:387] The default buffer size is 262144, which is overridden by the user specified `buffer_size` of 134217728
2025-07-06 19:28:41.954156: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


<pre>Loading index-to-token vocab TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts/vocab.tfrecord</pre>

2025-07-06 19:28:42.259889: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


<pre>Loading triplet TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts/triplets.tfrecord</pre>

<pre>Triplet TFRecord loaded and parsed</pre>

<pre>All artifacts loaded successfully!</pre>

Loaded vocab_size: 33668
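
The pipeline above shuffles with a buffer of `BATCH_SIZE * 10` = 1280 elements, which is small next to a full triplet file. The sketch below is a pure-Python model of a fixed-size shuffle buffer, not TensorFlow's implementation; `buffered_shuffle` and the toy input of 10,000 integers are hypothetical names and data chosen for illustration. It shows that an element can never be emitted more than `buffer_size - 1` positions ahead of where it sat in the input, so if the triplets happen to be stored in a correlated order, a small buffer only mixes them locally.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in the order a fixed-size shuffle buffer would emit them.

    Simplified model (an assumption, not TensorFlow code): fill the buffer,
    then for each new item emit a random buffered element and replace it.
    """
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        # Buffer is full: emit a random slot, then refill it with the new item.
        slot = rng.randrange(buffer_size)
        yield buffer[slot]
        buffer[slot] = item
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 10

# Toy input: the integers 0..9999 stand in for triplets in file order.
order = list(buffered_shuffle(range(10_000), SHUFFLE_BUFFER_SIZE))

# Largest number of positions by which any element was emitted early.
max_early = max(src - out for out, src in enumerate(order))
```

If the stored order matters for a given corpus, a larger buffer (up to the dataset size, memory permitting) gives a more uniform shuffle at the cost of a longer fill time at the start of each epoch.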


In [6]:
# Example: Query the token_to_index_table and index_to_token_table
test_token = 'king'
test_index = 2549

# Query token to index
token_tensor = tf.constant([test_token])
index_result = token_to_index_table.lookup(token_tensor).numpy()[0]
print(f"Index for token '{test_token}':", index_result)

# Query index to token
index_tensor = tf.constant([test_index], dtype=tf.int64)
token_result = index_to_token_table.lookup(index_tensor).numpy()[0].decode('utf-8')
print(f"Token for index {test_index}:", token_result)

Index for token 'king': 16702
Token for index 2549: beacon


In [7]:
# Print a random sample of 50 triplets from a single batch of the current corpus, showing both indices and tokens
import random

# Take a single batch from the dataset. Reading only part of a cached dataset
# makes TensorFlow discard the partial cache and log a warning; this is harmless
# here, since the cache is filled on the first complete pass during training.
for batch in triplets_ds.take(1):
    if isinstance(batch, tuple) and len(batch) == 3:
        # Batch is a tuple of tensors (anchor, pos, neg): zip them into per-example triplets
        anchor, pos, neg = [t.numpy() for t in batch]
        triplets_batch = list(zip(anchor, pos, neg))
    else:
        # Batch is a single tensor of shape (batch_size, 3)
        triplets_batch = batch.numpy()

sample_size = min(50, len(triplets_batch))
sampled_indices = random.sample(range(len(triplets_batch)), sample_size)
sampled_triplets = [triplets_batch[i] for i in sampled_indices]

def idx_to_token(idx):
    idx_tensor = tf.constant([idx], dtype=tf.int64)
    token = index_to_token_table.lookup(idx_tensor).numpy()[0].decode('utf-8')
    return token

print(f"Random sample of {sample_size} triplets from a single batch:")
print("Idx: (anchor, pos, neg)\tTokens: (anchor, pos, neg)")
for i, triplet in enumerate(sampled_triplets):
    anchor, pos, neg = triplet
    anchor_token = idx_to_token(anchor)
    pos_token = idx_to_token(pos)
    neg_token = idx_to_token(neg)
    print(f"{i+1:2d}: ({anchor}, {pos}, {neg})\t({anchor_token}, {pos_token}, {neg_token})")

Random sample of 50 triplets from a single batch:
Idx: (anchor, pos, neg)	Tokens: (anchor, pos, neg)
 1: (7563, 31187, 24782)	(day, unexpectedly, renew)
 2: (26464, 9, 3374)	(sensation, abasement, bolivia)
 3: (21666, 8558, 17166)	(part, disdain, learnt)
 4: (2907, 27804, 3076)	(best, spare, bisson)
 5: (17622, 9020, 7021)	(little, dose, cribbage)
 6: (19129, 4279, 13933)	(might, call, hectic)
 7: (17852, 6419, 26345)	(loud, contemptuous, sediment)
 8: (2610, 10367, 15524)	(become, estrange, infusion)
 9: (14670, 23067, 23158)	(hundred, pound, precipitancy)
10: (26068, 20956, 3451)	(scene, one, bootee)
11: (26549, 20956, 20375)	(servant, one, nightingale)
12: (1186, 30188, 16174)	(another, tossed, jaded)
13: (14670, 33529, 25683)	(hundred, yard, sackville)
14: (6776, 6536, 22481)	(could, convince, pinion)
15: (14622, 3800, 21945)	(human, bring, pelt)
16: (14670, 7563, 29769)	(hundred, day, thousandfold)
17: (16784, 23164, 13450)	(know, precise, habitually)
18: (26970, 12735, 4404)	(sid

2025-07-06 19:28:46.941527: W tensorflow/core/kernels/data/cache_dataset_ops.cc:916] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [8]:
%load_ext tensorboard
%tensorboard --logdir $tensorboard_log_dir --port 6006

In [None]:
# Run General-Purpose Word2GM Training
# Hardcoded stable configuration for reliable training

print("🚀 Starting Word2GM training...")
print(f"Vocab size: {vocab_size}")
print(f"Output: {output_dir}")

# Run training with hardcoded stable parameters
training_results = run_notebook_training(
    training_dataset=triplets_ds,
    save_path=output_dir,
    vocab_size=vocab_size,
    embedding_size=200,           # Good balance of capacity and speed
    num_mixtures=1,               # Single Gaussian for simplicity
    spherical=True,               # Spherical covariance: one variance per component, not a per-dimension diagonal
    learning_rate=0.1,            # Proven stable rate for Word2GM
    epochs=30,                    # Reasonable training duration
    adagrad=True,                 # Use the Adagrad optimizer
    normclip=True,                # Prevents exploding gradients
    norm_cap=10.0,                # Moderate gradient clipping
    lower_sig=0.05,               # Balanced variance bounds
    upper_sig=1.5,
    wout=True,                    # Use output embeddings
    tensorboard_log_path=tensorboard_log_dir,
    monitor_interval=0.5,         # Regular monitoring
    var_scale=0.05,               # Moderate regularization
    loss_epsilon=1e-8             # Numerical stability
)

print("✅ Training completed!")

🚀 Starting Word2GM training...
Vocab size: 33668
Output: /scratch/edk202/word2gm-fast/output/test_corpus



**Word2GM Training Hyperparameters:**

| Parameter      | Value   |
|----------------|---------|
| Vocab size     | `33668` |
| Embedding size | `200`   |
| Mixtures       | `1`     |
| Spherical      | `True`  |
| Learning rate  | `0.1`   |
| Epochs         | `30`    |
| Adagrad        | `True`  |
| Normclip       | `True`  |
| Norm cap       | `10.0`  |
| Lower sigma    | `0.05`  |
| Upper sigma    | `1.5`   |
| Wout           | `True`  |
| Var scale      | `0.05`  |
| Loss epsilon   | `1e-08` |
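
With `Mixtures = 1` and `Spherical = True`, each word is represented by a single Gaussian with one mean vector and one scalar variance, and training consumes (anchor, pos, neg) triplets. As a rough illustration of what such an objective can look like, the sketch below follows the max-margin, expected-likelihood-kernel formulation described in the Word2GM paper; it is not taken from the `word2gm_fast` code. The function names, the `margin` value, and the toy two-dimensional vectors are all assumptions made for this example.

```python
import math


def log_energy_spherical(mu_f, var_f, mu_g, var_g):
    """Log expected-likelihood kernel between two spherical Gaussians.

    Each Gaussian is N(mu, var * I); the kernel is the log of the integral
    of the product of the two densities.
    """
    d = len(mu_f)
    s = var_f + var_g  # variances add under the kernel
    sq_dist = sum((a - b) ** 2 for a, b in zip(mu_f, mu_g))
    return -0.5 * d * math.log(2.0 * math.pi * s) - 0.5 * sq_dist / s


def max_margin_loss(anchor, pos, neg, margin=1.0):
    """Hinge loss: zero once the positive pair beats the negative by `margin`.

    Each argument is a (mean_vector, variance) pair; `margin` is hypothetical.
    """
    e_pos = log_energy_spherical(*anchor, *pos)
    e_neg = log_energy_spherical(*anchor, *neg)
    return max(0.0, margin - e_pos + e_neg)


# Toy 2-D embeddings (hypothetical values, unit variance).
anchor = ([0.0, 0.0], 1.0)
near = ([0.1, 0.0], 1.0)
far = ([5.0, 5.0], 1.0)

loss_good = max_margin_loss(anchor, pos=near, neg=far)  # context closer than noise
loss_bad = max_margin_loss(anchor, pos=far, neg=near)   # context farther than noise
```

In this toy setup the loss vanishes when the positive context is much closer to the anchor than the negative sample, and grows when the roles are swapped, which is the pressure that pulls co-occurring words together during training.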


**Starting epoch 1/30...**

I0000 00:00:1751844533.612207  252419 service.cc:152] XLA service 0x148268012a20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1751844533.612230  252419 service.cc:160]   StreamExecutor device (0): Tesla V100-PCIE-32GB, Compute Capability 7.0
2025-07-06 19:28:53.618951: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1751844533.639578  252419 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1751844533.783121  252419 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
