In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

# Locate the project checkout and expose its `src` directory to the importer
PROJECT_ROOT = '/scratch/edk202/word2gm-fast'
project_root = Path(PROJECT_ROOT)
src_path = project_root.joinpath('src')

# Prepend only when absent, so re-running this cell never duplicates the entry
if not any(entry == str(src_path) for entry in sys.path):
    sys.path.insert(0, str(src_path))

import numpy as np
import tensorflow as tf

from word2gm_fast.models.config import Word2GMConfig
from word2gm_fast.models.word2gm_model import Word2GMModel
from word2gm_fast.training.notebook_training import run_notebook_training
from word2gm_fast.utils.tfrecord_io import load_pipeline_artifacts

2025-07-03 04:35:42.585720: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-03 04:35:43.624757: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751531743.848643 1820185 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751531743.895908 1820185 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1751531744.421144 1820185 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [15]:
# Data loading and pipeline setup

# Seed TensorFlow's global RNG and (below) the shuffle op so the batch order
# is repeatable after Restart Kernel -> Run All.
SEED = 42
tf.random.set_seed(SEED)

# Define paths for your corpus artifacts and output
dataset_artifacts_dir = (
    '/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/'
    '2019_artifacts'
)
# Derived from `project_root` (defined in the setup cell) so relocating the
# project only requires editing PROJECT_ROOT. Kept as a plain string, as
# before, so downstream string handling keeps working.
output_dir = str(project_root / 'output' / 'test_corpus')
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Load pipeline artifacts (vocab, triplets, etc.)
artifacts = load_pipeline_artifacts(dataset_artifacts_dir)
vocab_table = artifacts['vocab_table']
triplets_ds = artifacts['triplets_ds']
vocab_size = artifacts['vocab_size']

# Batch and shuffle-buffer sizes: the buffer spans ten batches
BATCH_SIZE = 1024 * 5
SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 10

# Build the dataset pipeline: cache -> shuffle -> batch -> prefetch
triplets_ds = (
    triplets_ds
    .cache()
    .shuffle(SHUFFLE_BUFFER_SIZE, seed=SEED)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

print(f'Loaded vocab_size: {vocab_size}')

<pre>Loading pipeline artifacts from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/2019_artifacts</pre>

<pre>Loading vocabulary TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/2019_artifacts/vocab.tfrecord</pre>

<pre>Loading triplet TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/2019_artifacts/triplets.tfrecord</pre>

<pre>Triplet TFRecord loaded and parsed</pre>

<pre>All artifacts loaded successfully!</pre>

Loaded vocab_size: 68663


In [None]:
# Launch Word2GM training using notebook_training.py with standard hyperparameters

# Optimisation settings, named so they can be tuned in one place
LEARNING_RATE = 0.05
EPOCHS = 10

# Output locations, all derived from `output_dir`. They are converted back to
# `str` so the values are identical to the previous string concatenations.
output_root = Path(output_dir)
tensorboard_log_dir = str(output_root / 'tensorboard')
checkpoint_path = str(output_root / 'model_checkpoint')
weights_path = str(output_root / 'model.weights.h5')

# Model configuration (only model-related fields)
# Enable asymmetrical embeddings by setting wout=True
config = Word2GMConfig(
    vocab_size=vocab_size,
    embedding_size=300,
    num_mixtures=2,
    spherical=True,
    norm_cap=5.0,
    lower_sig=0.05,
    upper_sig=1.0,
    var_scale=0.05,
    loss_epsilon=1e-8,
    wout=True,
    max_pe=False,  # Added for compatibility with model code
)

# Training call: model fields are read back from `config` so the trained
# model and the re-instantiated model below share one set of values;
# training-only parameters are passed directly.
run_notebook_training(
    training_dataset=triplets_ds,
    save_path=output_dir,
    vocab_size=config.vocab_size,
    embedding_size=config.embedding_size,
    num_mixtures=config.num_mixtures,
    spherical=config.spherical,
    learning_rate=LEARNING_RATE,
    epochs=EPOCHS,
    adagrad=True,
    normclip=True,
    norm_cap=config.norm_cap,
    lower_sig=config.lower_sig,
    upper_sig=config.upper_sig,
    var_scale=config.var_scale,
    loss_epsilon=config.loss_epsilon,
    wout=config.wout,
    tensorboard_log_path=tensorboard_log_dir,
    monitor_interval=0.5,
    profile=False,
)

# --- Save model weights in Keras 3-compatible format ---
# After training, re-instantiate the model and save weights with a supported extension.
# NOTE(review): this presumes run_notebook_training leaves a loadable checkpoint
# at `<save_path>/model_checkpoint` -- confirm. Keras 3's `load_weights`
# appears to reject paths without a `.weights.h5`/`.h5` extension, so verify
# that this extension-less path actually loads in the target environment.
model = Word2GMModel(config)
model.load_weights(checkpoint_path)  # Load from checkpoint if needed
model.save_weights(weights_path)  # Save in Keras 3-compatible format


**Word2GM Training Hyperparameters:**

| Parameter         | Value                |
|-------------------|----------------------|
| Vocab size        | `68663`  |
| Embedding size    | `300` |
| Mixtures          | `2` |
| Spherical         | `True`   |
| Learning rate     | `0.05` |
| Epochs            | `10` |
| Adagrad           | `True`     |
| Normclip          | `True`    |
| Norm cap          | `5.0`    |
| Lower sigma       | `0.05`   |
| Upper sigma       | `1.0`   |
| Wout              | `True`        |
| Var scale         | `0.05`   |
| Loss epsilon      | `1e-08`|


**Epoch 1 finished. Loss:** `0.233867`  | **Duration:** `430.89` seconds.

I0000 00:00:1751479348.514202 2280680 service.cc:152] XLA service 0x14b3c8010780 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1751479348.514220 2280680 service.cc:160]   StreamExecutor device (0): Quadro RTX 8000, Compute Capability 7.5
2025-07-02 14:02:28.548085: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1751479348.668063 2280680 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1751479349.229158 2280680 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
I0000 00:00:1751479349.229158 2280680 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


**Epoch 2 finished. Loss:** `0.142469`  | **Duration:** `120.67` seconds.

2025-07-02 14:09:49.767736: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:12: Filling up shuffle buffer (this may take a while): 1 of 51200
2025-07-02 14:09:49.811948: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


**Epoch 3 finished. Loss:** `0.126184`  | **Duration:** `118.99` seconds.

**Epoch 4 finished. Loss:** `0.118138`  | **Duration:** `118.54` seconds.

**Epoch 5 finished. Loss:** `0.113170`  | **Duration:** `118.73` seconds.



**Epoch 6 finished. Loss:** `0.109738`  | **Duration:** `118.98` seconds.



**Epoch 7 finished. Loss:** `0.107190`  | **Duration:** `119.38` seconds.

**Epoch 8 finished. Loss:** `0.105201`  | **Duration:** `118.90` seconds.

**Epoch 9 finished. Loss:** `0.103589`  | **Duration:** `118.93` seconds.

**Epoch 10 finished. Loss:** `0.102241`  | **Duration:** `119.67` seconds.

**Training complete. Total training time:** `1508.64` seconds.

In [6]:
# Launch TensorBoard in the notebook
%load_ext tensorboard
%tensorboard --logdir $tensorboard_log_dir --port 6006

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 2287089), started 0:00:04 ago. (Use '!kill 2287089' to kill it.)