In [8]:
# Enable autoreload for development
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Set project root and add src to path
import sys
from pathlib import Path
import os

# Default checkout location. Set the WORD2GM_PROJECT_ROOT environment variable
# to run the notebook against a different clone without editing this cell.
PROJECT_ROOT = os.environ.get('WORD2GM_PROJECT_ROOT', '/scratch/edk202/word2gm-fast')
project_root = Path(PROJECT_ROOT)
src_path = project_root / 'src'

# Prepend only once so that re-running the cell does not add duplicate entries.
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

In [10]:
# Print a summary of the resources available to this session. The saved output
# of this cell lists the hostname, the job allocation (CPUs, memory, partition),
# the physical GPU hardware, and the GPUs TensorFlow can access.
from word2gm_fast.utils.resource_summary import print_resource_summary

print_resource_summary()

<pre>SYSTEM RESOURCE SUMMARY
=============================================
Hostname: gr010.hpc.nyu.edu

Job Allocation:
   CPUs: 14
   Memory: 125.0 GB
   Partition: rtx8000
   Job ID: 63924068
   Node list: gr010

Physical GPU Hardware:
   Physical GPUs available: 1
   GPU 0: Quadro RTX 8000
      Memory: 44.0/45.0 GB (1.0 GB free)
      Temperature: 34°C
      Utilization: GPU 0%, Memory 0%

TensorFlow GPU Recognition:
   TensorFlow can access 1 GPU(s)
      /physical_device:GPU:0
   Built with CUDA support: True
=============================================</pre>

In [11]:
# Define paths for your corpus artifacts and output
dataset_artifacts_dir = (
    '/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/'
    '1850_artifacts'
)

# Derive the output directory from project_root instead of repeating the
# checkout path, so the two cannot drift apart. Kept as a plain string because
# it is interpolated into messages and passed on as a path argument below.
output_dir = str(project_root / 'output' / 'test_corpus')
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Set TensorBoard log directory
tensorboard_log_dir = output_dir + '/tensorboard'

In [12]:
from word2gm_fast.io.artifacts import load_pipeline_artifacts

# Load the pipeline artifacts once, then pull out the individual pieces the
# rest of the notebook uses (vocabulary lookup tables, triplets dataset and
# vocabulary size).
artifacts = load_pipeline_artifacts(dataset_artifacts_dir)
token_to_index_table, index_to_token_table, triplets_ds, vocab_size = (
    artifacts[artifact_key]
    for artifact_key in (
        'token_to_index_table',
        'index_to_token_table',
        'triplets_ds',
        'vocab_size',
    )
)

<span style='font-family: monospace; font-size: 120%; font-weight: normal;'>Loading pipeline artifacts from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts</span>

<span style='font-family: monospace; font-size: 120%; font-weight: normal;'>Loading token-to-index vocabulary TFRecord from:<br>&nbsp;&nbsp;/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts/vocab.tfrecord</span>

<span style='font-family: monospace; font-size: 120%; font-weight: normal;'>Token-to-index lookup table created successfully.<br>Table contains 33668 tokens. Processing time: 0.28s</span>

<span style='font-family: monospace; font-size: 120%; font-weight: normal;'>Loading index-to-token vocab TFRecord from:<br>&nbsp;&nbsp;/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/1850_artifacts/vocab.tfrecord</span>

<span style='font-family: monospace; font-size: 120%; font-weight: normal;'>Index-to-token lookup table created successfully.<br>Table contains 33668 tokens. Processing time: 0.28s</span>

<span style='font-family: monospace; font-size: 120%; font-weight: normal;'>All artifacts loaded successfully!</span>

In [13]:
from word2gm_fast.utils.tf_silence import import_tf_quietly

tf = import_tf_quietly()

# Build the dataset pipeline: cache -> shuffle -> batch -> prefetch.
#
# The pipeline is always rebuilt from the raw dataset held in `artifacts`
# rather than from `triplets_ds` itself. Rebinding `triplets_ds` from its own
# previous value made the cell non-idempotent: running it a second time would
# stack another cache/shuffle/batch on top of the already-batched dataset.
BATCH_SIZE = 1024
SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 10
triplets_ds = (
    artifacts['triplets_ds']
    .cache()                        # cache before shuffling so each epoch is reshuffled
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [7]:
%load_ext tensorboard
%tensorboard --logdir $tensorboard_log_dir --port 6006

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 602507), started 0:00:24 ago. (Use '!kill 602507' to kill it.)

In [None]:
from word2gm_fast.training.notebook_training import run_notebook_training

# Run General-Purpose Word2GM Training
# The hyperparameters are hard-coded in this cell; edit the call below to
# change them. Results are written under `output_dir` and TensorBoard logs
# under `tensorboard_log_dir`.

print("Starting Word2GM training...")
print(f"Vocab size: {vocab_size}")
print(f"Output: {output_dir}")

# Run training with hardcoded stable parameters
training_results = run_notebook_training(
    training_dataset=triplets_ds,
    save_path=output_dir,
    vocab_size=vocab_size,
    embedding_size=200,           # Good balance of capacity and speed
    num_mixtures=1,               # Single Gaussian for simplicity
    spherical=True,               # Spherical (isotropic) covariance, not diagonal
    learning_rate=0.1,            # Proven stable rate for Word2GM
    epochs=100,                   # Reasonable training duration
    adagrad=True,                 # Essential for Word2GM
    normclip=True,                # Enable norm clipping
    norm_cap=10.0,                # Cap used by norm clipping
                                  # NOTE(review): in the reference Word2GM code this
                                  # bounds the mean-vector norm, not the gradients;
                                  # confirm which one this implementation clips.
    lower_sig=0.05,               # Balanced variance bounds
    upper_sig=1.5,
    wout=True,                    # Use output embeddings
    tensorboard_log_path=tensorboard_log_dir,
    monitor_interval=0.5,         # Regular monitoring
    var_scale=0.05,               # Moderate regularization
    loss_epsilon=1e-8             # Numerical stability
)

print("✅ Training completed!")

Starting Word2GM training...
Vocab size: 33668
Output: /scratch/edk202/word2gm-fast/output/test_corpus



**Word2GM Training Hyperparameters:**

| Parameter         | Value                |
|-------------------|----------------------|
| Vocab size        | `33668`  |
| Embedding size    | `200` |
| Mixtures          | `1` |
| Spherical         | `True`   |
| Learning rate     | `0.1` |
| Epochs            | `50` |
| Adagrad           | `True`     |
| Normclip          | `True`    |
| Norm cap          | `10.0`    |
| Lower sigma       | `0.05`   |
| Upper sigma       | `1.5`   |
| Wout              | `True`        |
| Var scale         | `0.05`   |
| Loss epsilon      | `1e-08`|


**Epoch 1 finished. Loss:** `0.549704`  | **Duration:** `18.38` seconds.

**Epoch 2 finished. Loss:** `0.343929`  | **Duration:** `10.26` seconds.

**Epoch 3 finished. Loss:** `0.300609`  | **Duration:** `10.40` seconds.

**Epoch 4 finished. Loss:** `0.274875`  | **Duration:** `10.25` seconds.

**Epoch 5 finished. Loss:** `0.255282`  | **Duration:** `10.20` seconds.

**Epoch 6 finished. Loss:** `0.238554`  | **Duration:** `10.20` seconds.

**Epoch 7 finished. Loss:** `0.223422`  | **Duration:** `10.21` seconds.

**Epoch 8 finished. Loss:** `0.209211`  | **Duration:** `10.28` seconds.

**Epoch 9 finished. Loss:** `0.195621`  | **Duration:** `10.25` seconds.

**Epoch 10 finished. Loss:** `0.182468`  | **Duration:** `10.43` seconds.

**Epoch 11 finished. Loss:** `0.169654`  | **Duration:** `10.28` seconds.

**Epoch 12 finished. Loss:** `0.157153`  | **Duration:** `10.27` seconds.

**Epoch 13 finished. Loss:** `0.145028`  | **Duration:** `10.09` seconds.

**Epoch 14 finished. Loss:** `0.133308`  | **Duration:** `10.08` seconds.

**Epoch 15 finished. Loss:** `0.122018`  | **Duration:** `10.23` seconds.

**Epoch 16 finished. Loss:** `0.111265`  | **Duration:** `10.16` seconds.

**Epoch 17 finished. Loss:** `0.101018`  | **Duration:** `10.05` seconds.

**Epoch 18 finished. Loss:** `0.091325`  | **Duration:** `10.12` seconds.

**Epoch 19 finished. Loss:** `0.082271`  | **Duration:** `10.27` seconds.

**Epoch 20 finished. Loss:** `0.073814`  | **Duration:** `10.08` seconds.

**Epoch 21 finished. Loss:** `0.065981`  | **Duration:** `10.04` seconds.

**Epoch 22 finished. Loss:** `0.058749`  | **Duration:** `10.00` seconds.

**Epoch 23 finished. Loss:** `0.052151`  | **Duration:** `10.09` seconds.

**Epoch 24 finished. Loss:** `0.046128`  | **Duration:** `10.10` seconds.

**Epoch 25 finished. Loss:** `0.040679`  | **Duration:** `10.08` seconds.

**Epoch 26 finished. Loss:** `0.035758`  | **Duration:** `10.21` seconds.

**Epoch 27 finished. Loss:** `0.031329`  | **Duration:** `10.17` seconds.

**Epoch 28 finished. Loss:** `0.027372`  | **Duration:** `10.10` seconds.

**Epoch 29 finished. Loss:** `0.023841`  | **Duration:** `10.06` seconds.

**Epoch 30 finished. Loss:** `0.020698`  | **Duration:** `10.22` seconds.

**Epoch 31 finished. Loss:** `0.017942`  | **Duration:** `10.43` seconds.

**Epoch 32 finished. Loss:** `0.015495`  | **Duration:** `10.04` seconds.

**Epoch 33 finished. Loss:** `0.013347`  | **Duration:** `10.18` seconds.

**Epoch 34 finished. Loss:** `0.011464`  | **Duration:** `10.00` seconds.

**Epoch 35 finished. Loss:** `0.009821`  | **Duration:** `10.14` seconds.

**Epoch 36 finished. Loss:** `0.008389`  | **Duration:** `10.13` seconds.

**Epoch 37 finished. Loss:** `0.007157`  | **Duration:** `10.08` seconds.

**Epoch 38 finished. Loss:** `0.006089`  | **Duration:** `10.08` seconds.

**Epoch 39 finished. Loss:** `0.005166`  | **Duration:** `10.07` seconds.

**Epoch 40 finished. Loss:** `0.004371`  | **Duration:** `10.02` seconds.

**Epoch 41 finished. Loss:** `0.003692`  | **Duration:** `10.06` seconds.

**Epoch 42 finished. Loss:** `0.003116`  | **Duration:** `10.14` seconds.

**Epoch 43 finished. Loss:** `0.002624`  | **Duration:** `10.04` seconds.

**Epoch 44 finished. Loss:** `0.002202`  | **Duration:** `10.04` seconds.

**Epoch 45 finished. Loss:** `0.001850`  | **Duration:** `10.55` seconds.

**Epoch 46 finished. Loss:** `0.001547`  | **Duration:** `10.05` seconds.

**Epoch 47 finished. Loss:** `0.001294`  | **Duration:** `10.11` seconds.

**Epoch 48 finished. Loss:** `0.001082`  | **Duration:** `10.14` seconds.

**Epoch 49 finished. Loss:** `0.000902`  | **Duration:** `10.12` seconds.

**Epoch 50 finished. Loss:** `0.000749`  | **Duration:** `9.99` seconds.

**Training complete. Total training time:** `524.73` seconds.

✅ Training completed!
