In [14]:
# Enable IPython's autoreload extension. Mode 2 is documented by IPython as
# reloading all imported modules before each cell executes — presumably used
# here so edits to the project sources are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

# Set project root directory and add `src` to path.
# The checkout location defaults to the original cluster path, but can be
# overridden per machine/user through the WORD2GM_PROJECT_ROOT environment
# variable so the notebook does not have to be edited to run elsewhere.
PROJECT_ROOT = os.environ.get('WORD2GM_PROJECT_ROOT', '/scratch/edk202/word2gm-fast')
project_root = Path(PROJECT_ROOT)
src_path = project_root / 'src'

# Guard against duplicate entries when this cell is re-run in a live kernel.
# Inserting at index 0 gives this checkout precedence during import lookup.
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

import numpy as np
import tensorflow as tf

from word2gm_fast.models.config import Word2GMConfig
from word2gm_fast.training.notebook_training import run_notebook_training
from word2gm_fast.utils.tfrecord_io import load_pipeline_artifacts

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
# Load the preprocessed corpus artifacts and assemble the training input pipeline

# Batching parameters; the shuffle buffer holds ten batches' worth of triplets
BATCH_SIZE = 5 * 1024
SHUFFLE_BUFFER_SIZE = 10 * BATCH_SIZE

# Where the corpus artifacts are read from, and where this run writes its output
dataset_artifacts_dir = (
    '/vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/'
    '2019_artifacts'
)
output_dir = '/scratch/edk202/word2gm-fast/output/test_corpus'
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Read the pipeline artifacts and pull out the entries used by this notebook
artifacts = load_pipeline_artifacts(dataset_artifacts_dir)
vocab_table, triplets_ds, vocab_size = (
    artifacts['vocab_table'],
    artifacts['triplets_ds'],
    artifacts['vocab_size'],
)

# Input pipeline stages, applied in order: cache -> shuffle -> batch -> prefetch
triplets_ds = (
    triplets_ds
    .cache()
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

print(f'Loaded vocab_size: {vocab_size}')

<pre>Loading pipeline artifacts from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/2019_artifacts</pre>

<pre>Loading vocabulary TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/2019_artifacts/vocab.tfrecord</pre>

<pre>Loading triplet TFRecord from: /vast/edk202/NLP_corpora/Google_Books/20200217/eng-fiction/5gram_files/6corpus/yearly_files/data/2019_artifacts/triplets.tfrecord</pre>

<pre>Triplet TFRecord loaded and parsed</pre>

<pre>All artifacts loaded successfully!</pre>

Loaded vocab_size: 68663


In [None]:
# Train Word2GM through notebook_training.py using the standard hyperparameters

# TensorBoard logs go into a `tensorboard` subdirectory of the run's output directory
tensorboard_log_dir = f'{output_dir}/tensorboard'

# Model-only settings; wout=True is what enables asymmetrical embeddings
config = Word2GMConfig(
    vocab_size=vocab_size,
    embedding_size=300,
    num_mixtures=2,
    spherical=True,
    norm_cap=5.0,
    lower_sig=0.05,
    upper_sig=1.0,
    var_scale=0.05,
    loss_epsilon=1e-8,
    wout=True,
)

# Arguments for the training call: model fields are read back from `config`,
# while the training-specific settings are given here directly
training_kwargs = dict(
    training_dataset=triplets_ds,
    save_path=output_dir,
    vocab_size=config.vocab_size,
    embedding_size=config.embedding_size,
    num_mixtures=config.num_mixtures,
    spherical=config.spherical,
    learning_rate=0.05,
    epochs=5,
    adagrad=True,
    normclip=True,
    norm_cap=config.norm_cap,
    lower_sig=config.lower_sig,
    upper_sig=config.upper_sig,
    var_scale=config.var_scale,
    loss_epsilon=config.loss_epsilon,
    wout=config.wout,
    tensorboard_log_path=tensorboard_log_dir,
    monitor_interval=0.5,
    profile=False,
)

run_notebook_training(**training_kwargs)


**Word2GM Training Hyperparameters:**

| Parameter         | Value                |
|-------------------|----------------------|
| Vocab size        | `68663`  |
| Embedding size    | `300` |
| Mixtures          | `2` |
| Spherical         | `True`   |
| Learning rate     | `0.05` |
| Epochs            | `5` |
| Adagrad           | `True`     |
| Normclip          | `True`    |
| Norm cap          | `5.0`    |
| Lower sigma       | `0.05`   |
| Upper sigma       | `1.0`   |
| Wout              | `True`        |
| Var scale         | `0.05`   |
| Loss epsilon      | `1e-08`|


**Starting epoch 1/5...**

In [13]:
# Launch TensorBoard in the notebook
%load_ext tensorboard
%tensorboard --logdir $tensorboard_log_dir --port 6006

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 1056655), started 0:11:44 ago. (Use '!kill 1056655' to kill it.)