In [5]:
import os
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # Optional, may help with fragmentation

import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as e:
        print(f"Could not set memory growth: {e}")

In [6]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

# Set project root directory and add `src` to path
PROJECT_ROOT = '/scratch/edk202/word2gm-fast'
project_root = Path(PROJECT_ROOT)
src_path = project_root / 'src'

if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

import numpy as np
import tensorflow as tf

from word2gm_fast.models.word2gm_model import Word2GMModel
from word2gm_fast.models.config import Word2GMConfig
from word2gm_fast.training.notebook_training import run_notebook_training
from word2gm_fast.utils.tfrecord_io import load_pipeline_artifacts
from word2gm_fast.utils.resource_summary import print_resource_summary

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
print_resource_summary()

<pre>SYSTEM RESOURCE SUMMARY
============================================================
Hostname: cm020.hpc.nyu.edu

Job Allocation:
   CPUs: 14
   Memory: 125.0 GB
   Requested partitions: short,cs,cm,cpu_a100_2,cpu_a100_1,cpu_gpu
   Running on: SSH failed: Host key verification failed.
   Job ID: 63372384
   Node list: cm020

GPU Information:
   Error: NVML Shared Library Not Found

TensorFlow GPU Detection:
   TensorFlow detects 0 GPU(s)
   Built with CUDA: True
============================================================</pre>

In [10]:
# Data loading and pipeline setup
from word2gm_fast.utils.tfrecord_io import load_pipeline_artifacts

# Define paths for your corpus artifacts and output
dataset_artifacts_dir = (
    '/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/'
    '1940_artifacts'
)
output_dir = '/scratch/edk202/word2gm-fast/output/test_corpus'
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Set TensorBoard log directory
tensorboard_log_dir = output_dir + '/tensorboard'

# Load pipeline artifacts (vocab, triplets, etc.)
artifacts = load_pipeline_artifacts(dataset_artifacts_dir)
token_to_index_table = artifacts['token_to_index_table']
index_to_token_table = artifacts['index_to_token_table']
triplets_ds = artifacts['triplets_ds']
vocab_size = artifacts['vocab_size']

# Build the dataset pipeline: cache -> shuffle -> batch -> prefetch
triplets_ds = triplets_ds.cache()
BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 10
triplets_ds = triplets_ds.shuffle(SHUFFLE_BUFFER_SIZE)
triplets_ds = triplets_ds.batch(BATCH_SIZE)
triplets_ds = triplets_ds.prefetch(tf.data.AUTOTUNE)

print(f'Loaded vocab_size: {vocab_size}')

<pre>Loading pipeline artifacts from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1940_artifacts</pre>

<pre>Loading token-to-index vocabulary TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1940_artifacts/vocab.tfrecord</pre>

2025-07-04 04:32:51.185089: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


<pre>Loading index-to-token vocab TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1940_artifacts/vocab.tfrecord</pre>

<pre>Loading triplet TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1940_artifacts/triplets.tfrecord</pre>

<pre>Triplet TFRecord loaded and parsed</pre>

<pre>All artifacts loaded successfully!</pre>

Loaded vocab_size: 42401


In [16]:
# Example: Query the token_to_index_table and index_to_token_table
test_token = 'king'
test_index = 20794

# Query token to index
token_tensor = tf.constant([test_token])
index_result = token_to_index_table.lookup(token_tensor).numpy()[0]
print(f"Index for token '{test_token}':", index_result)

# Query index to token
index_tensor = tf.constant([test_index], dtype=tf.int64)
token_result = index_to_token_table.lookup(index_tensor).numpy()[0].decode('utf-8')
print(f"Token for index {test_index}:", token_result)

Index for token 'king': 20794
Token for index 20794: king


In [None]:
# Print a random sample of 50 triplets from a single batch of the current corpus, showing both indices and tokens
import random

# Take a single batch from the dataset
for batch in triplets_ds.take(1):
    # If batch is a tuple of tensors (anchor, pos, neg), stack and transpose to shape (batch_size, 3)
    if isinstance(batch, tuple) and len(batch) == 3:
        anchor, pos, neg = [t.numpy() for t in batch]
        triplets_batch = list(zip(anchor, pos, neg))
    else:
        # If batch is a single tensor of shape (batch_size, 3)
        triplets_batch = batch.numpy()
    break

sample_size = min(50, len(triplets_batch))
sampled_indices = random.sample(range(len(triplets_batch)), sample_size)
sampled_triplets = [triplets_batch[i] for i in sampled_indices]

def idx_to_token(idx):
    idx_tensor = tf.constant([idx], dtype=tf.int64)
    token = index_to_token_table.lookup(idx_tensor).numpy()[0].decode('utf-8')
    return token

print(f"Random sample of {sample_size} triplets from a single batch:")
print("Idx: (anchor, pos, neg)\tTokens: (anchor, pos, neg)")
for i, triplet in enumerate(sampled_triplets):
    anchor, pos, neg = triplet
    anchor_token = idx_to_token(anchor)
    pos_token = idx_to_token(pos)
    neg_token = idx_to_token(neg)
    print(f"{i+1:2d}: ({anchor}, {pos}, {neg})\t({anchor_token}, {pos_token}, {neg_token})")

Random sample of 50 triplets (anchor, pos, neg) from a single batch:
 1: (anchor=319, pos=16745, neg=1102)
 2: (anchor=1458, pos=22267, neg=14304)
 3: (anchor=1566, pos=26393, neg=32382)
 4: (anchor=1932, pos=25375, neg=16498)
 5: (anchor=3709, pos=26393, neg=35175)
 6: (anchor=1463, pos=26393, neg=11551)
 7: (anchor=745, pos=26393, neg=3275)
 8: (anchor=1458, pos=33724, neg=19295)
 9: (anchor=1458, pos=26393, neg=32624)
10: (anchor=3709, pos=26393, neg=27943)
11: (anchor=1884, pos=21884, neg=4281)
12: (anchor=3816, pos=41263, neg=39186)
13: (anchor=1884, pos=26393, neg=2283)
14: (anchor=1954, pos=26393, neg=20039)
15: (anchor=3372, pos=26393, neg=26668)
16: (anchor=1884, pos=26393, neg=10744)
17: (anchor=4051, pos=26393, neg=18059)
18: (anchor=2638, pos=26393, neg=28529)
19: (anchor=756, pos=26393, neg=26104)
20: (anchor=1458, pos=26393, neg=39008)
21: (anchor=1458, pos=26393, neg=15866)
22: (anchor=1884, pos=33760, neg=25919)
23: (anchor=1932, pos=42000, neg=30216)
24: (anchor=1458, 

2025-07-04 04:39:55.621276: W tensorflow/core/kernels/data/cache_dataset_ops.cc:916] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [13]:
config = Word2GMConfig(
    vocab_size=vocab_size,
    embedding_size=200,
    num_mixtures=1,
    spherical=True,
    norm_cap=20.0,         # Increased
    lower_sig=0.01,        # Lowered
    upper_sig=2.0,         # Raised
    var_scale=0.1,         # Increased
    loss_epsilon=1e-8,
    wout=True,
    max_pe=False,
)

run_notebook_training(
    training_dataset=triplets_ds,
    save_path=output_dir,
    vocab_size=config.vocab_size,
    embedding_size=config.embedding_size,
    num_mixtures=config.num_mixtures,
    spherical=config.spherical,
    learning_rate=1.0,
    epochs=30,
    adagrad=True,
    normclip=True,
    norm_cap=config.norm_cap,
    lower_sig=config.lower_sig,
    upper_sig=config.upper_sig,
    var_scale=config.var_scale,
    loss_epsilon=config.loss_epsilon,
    wout=config.wout,
    tensorboard_log_path=tensorboard_log_dir,
    monitor_interval=0.5,
    profile=False,
)

In [8]:
# Train on the small dataset to test for overfitting
overfit_config = Word2GMConfig(
    vocab_size=vocab_size,
    embedding_size=100,  # smaller for faster overfit
    num_mixtures=1,
    spherical=True,
    norm_cap=100.0,      # very loose
    lower_sig=1e-4,      # very loose
    upper_sig=10.0,      # very loose
    var_scale=0.0001,       # no regularization
    loss_epsilon=1e-8,
    wout=True,
    max_pe=False,
    # other params as needed
 )
run_notebook_training(
    training_dataset=small_triplets_ds,
    save_path=output_dir + '/overfit',
    vocab_size=overfit_config.vocab_size,
    embedding_size=overfit_config.embedding_size,
    num_mixtures=overfit_config.num_mixtures,
    spherical=overfit_config.spherical,
    learning_rate=1.0,
    epochs=100,
    adagrad=True,
    normclip=False,
    norm_cap=overfit_config.norm_cap,
    lower_sig=overfit_config.lower_sig,
    upper_sig=overfit_config.upper_sig,
    var_scale=overfit_config.var_scale,
    loss_epsilon=overfit_config.loss_epsilon,
    wout=overfit_config.wout,
    tensorboard_log_path=output_dir + '/overfit/tensorboard',
    monitor_interval=0.2,
    profile=False,
 )

In [9]:
# Launch TensorBoard in the notebook
%load_ext tensorboard
%tensorboard --logdir $tensorboard_log_dir --port 6006

In [17]:
# Find nearest neighbors for a given word using Word2GMModel and vocab_list
model = Word2GMModel(config)

# Build the model by calling it on a dummy input (tuple of three tensors)
dummy_input = (
    tf.zeros([1], dtype=tf.int32),  # word_ids
    tf.zeros([1], dtype=tf.int32),  # pos_ids
    tf.zeros([1], dtype=tf.int32),  # neg_ids
)
model(dummy_input)

model.load_weights(output_dir + '/model_weights_epoch30.weights.h5')

# Choose a query word and get its index
query_word = 'good'  # Change this to any word in your vocab
try:
    query_idx = vocab_list.index(query_word)
except ValueError:
    raise ValueError(f'Word "{query_word}" not found in vocab_list.')

# Get nearest neighbor indices (returns indices, distances or a list of (index, distance) pairs)
result = model.get_nearest_neighbors(query_idx, k=10)
print("Result type:", type(result))
print("Result:", result)

# Try to unpack if possible, else treat as list of pairs
try:
    neighbor_indices, neighbor_distances = result
    neighbors = [(vocab_list[i], float(d)) for i, d in zip(neighbor_indices, neighbor_distances)]
except Exception:
    neighbors = [(vocab_list[i], float(d)) for i, d in result]

print(f'Nearest neighbors for "{query_word}":')
for word, dist in neighbors:
    print(f'{word}\t{dist:.4f}')