In [13]:
from google.colab import drive
# Mount Google Drive: the project directory, the utilities wheel and the
# training artefacts used below all live under this mountpoint.
DRIVE_MOUNTPOINT = "/content/drive"
drive.mount(DRIVE_MOUNTPOINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
import os

os.chdir("/content/drive/MyDrive/colab_data/projects/nlp")

!pip install --no-deps --force-reinstall /content/drive/MyDrive/colab_data/misc_dist/krk_ml_utils-0.0.1-py3-none-any.whl

Processing /content/drive/MyDrive/colab_data/misc_dist/krk_ml_utils-0.0.1-py3-none-any.whl
Installing collected packages: krk-ml-utils
  Attempting uninstall: krk-ml-utils
    Found existing installation: krk_ml_utils 0.0.1
    Uninstalling krk_ml_utils-0.0.1:
      Successfully uninstalled krk_ml_utils-0.0.1
Successfully installed krk-ml-utils-0.0.1


In [15]:
from tokenizers import Tokenizer

# Load the trained tokenizer (relative to the project directory chosen above)
# and report its vocabulary size, which sizes the model's embeddings.
TOKENIZER_PATH = "./rlm_tokenizer.json"
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
print(f"There are {tokenizer.get_vocab_size()} tokens in the vocab")

There are 50478 tokens in the vocab


In [16]:
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from functools import partial

#
from krk_ml_utils.transformers import Vanilla_Transformer_v2
from krk_ml_utils.datasets import JaxNLPDataLoader, NumpyDataset

# --- 1. Model & Training Hyperparameters ---
VOCAB_SIZE = tokenizer.get_vocab_size()

# Architecture (encoder-decoder Transformer).
D_MODEL = 512
# NOTE(review): presumably the positional-encoding length; must be >= the longest
# padded sequence the data loaders produce (max_len_features=1000 below).
MAX_SEQ_LENGTH = 1500
NUM_LAYERS_ENC = 6
NUM_LAYERS_DEC = 6
NUM_HEADS_ENC = 8
NUM_HEADS_DEC = 8
# NOTE(review): the paper's base model uses d_ff = 2048; 2024 looks like a typo.
# Changing it alters parameter shapes, so existing checkpoints would no longer restore.
D_FF_ENC = 2024
D_FF_DEC = 2024
DROPOUT_RATE = 0.1

# Training.
NUM_EPOCHS = 100
BATCH_SIZE = 16
# Sentinel pad id. -1 is NOT a token in the tokenizer's vocabulary; the loss
# function masks every label equal to it.
# NOTE(review): confirm Vanilla_Transformer_v2 also masks -1 before any embedding lookup.
PAD_TOKEN_ID = -1
SEED = 42

LABEL_SMOOTHING_ALPHA = 0.1

# Learning-rate schedule and Adam. The betas/eps follow "Attention Is All You Need".
# The paper uses 4000 warmup steps; 60 is a heavily scaled-down value for this small
# dataset. optax advances a schedule once per optimizer *update*, not per batch.
# There is no fixed LEARNING_RATE: the schedule derives it from D_MODEL and WARMUP_STEPS.
WARMUP_STEPS = 60
ADAM_B1 = 0.9
ADAM_B2 = 0.98
ADAM_EPS = 1e-9

# Fail fast on configurations that multi-head attention cannot split evenly.
assert D_MODEL % NUM_HEADS_ENC == 0, "D_MODEL must be divisible by NUM_HEADS_ENC"
assert D_MODEL % NUM_HEADS_DEC == 0, "D_MODEL must be divisible by NUM_HEADS_DEC"

# --- 2. Instantiate the model ---
# (The optimizer and the metrics are built in later cells.)
print("Initializing model components...")
model_hparams = dict(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    # Encoder stack
    num_layers_enc=NUM_LAYERS_ENC,
    num_heads_enc=NUM_HEADS_ENC,
    d_dff_enc=D_FF_ENC,
    # Decoder stack
    num_layers_dec=NUM_LAYERS_DEC,
    num_heads_dec=NUM_HEADS_DEC,
    d_dff_dec=D_FF_DEC,
    # Regularisation, RNG and padding
    dropout_rate=DROPOUT_RATE,
    seed=SEED,
    pad_token_id=PAD_TOKEN_ID,
)
model = Vanilla_Transformer_v2(**model_hparams)

print(f"JAX sees the following devices: {jax.devices()}")

Initializing model components...
JAX sees the following devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]


In [17]:
def transformer_loss_fn_from_xy(
    model: "nnx.Module",
    batch: tuple,
    pad_token_id: int = -1,
    label_smoothing: "float | None" = None,
    training: bool = True,
):
    """Label-smoothed, padding-masked cross-entropy for teacher-forced seq2seq training.

    The target sequence is shifted by one position: the decoder receives
    ``target[:, :-1]`` and is scored against ``target[:, 1:]``.

    Args:
        model: Callable taking ``source_tokens``, ``target_tokens`` and
            ``training`` keyword arguments and returning logits whose last axis
            is the vocabulary (assumed ``(batch, target_len - 1, vocab)``).
        batch: ``(source_tokens, full_target_sequence)``; the target is indexed
            as ``(batch, time)``.
        pad_token_id: Label id treated as padding; those positions contribute
            nothing to the loss.
        label_smoothing: Smoothing weight alpha. ``None`` falls back to the
            notebook-level ``LABEL_SMOOTHING_ALPHA``.
        training: Forwarded to the model (e.g. to enable dropout).
            NOTE(review): pass ``False`` for evaluation if the training loop allows it.

    Returns:
        ``(mean_loss, logits)``: the scalar loss averaged over non-pad target
        tokens (``0.0`` when the batch has none) and the raw model logits.
    """
    source_tokens, full_target_sequence = batch
    # Teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
    decoder_input_tokens = full_target_sequence[:, :-1]
    labels = full_target_sequence[:, 1:]

    logits = model(
        source_tokens=source_tokens,
        target_tokens=decoder_input_tokens,
        training=training,
    )

    alpha = LABEL_SMOOTHING_ALPHA if label_smoothing is None else label_smoothing

    padding_mask = labels != pad_token_id
    # Pad ids (e.g. -1) are not valid vocabulary indices: substitute 0 so the
    # gather below is well-defined. These positions are masked out afterwards.
    safe_labels = jnp.where(padding_mask, labels, 0)

    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Negative log-likelihood of the true token.
    nll = -jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Cross-entropy against the uniform distribution over the vocabulary.
    uniform_ce = -jnp.mean(log_probs, axis=-1)
    # Identical to softmax CE against (1 - alpha) * one_hot + alpha / V (what
    # optax.smooth_labels produces), without materialising a (B, T, V) one-hot.
    token_loss = (1.0 - alpha) * nll + alpha * uniform_ce

    # `where` (not multiply) so a non-finite value at a pad position cannot leak in.
    masked_loss = jnp.where(padding_mask, token_loss, 0.0)
    # Guard against an all-padding batch, which would otherwise give 0/0 = NaN.
    num_tokens = jnp.maximum(jnp.sum(padding_mask), 1)
    mean_loss = jnp.sum(masked_loss) / num_tokens

    return mean_loss, logits

In [18]:




# Perplexity is exp(mean cross-entropy), so the averaged loss is the one metric
# worth tracking here. NOTE: with label smoothing the logged value is the
# *smoothed* cross-entropy, so exp(loss) only approximates true perplexity.
metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))

# --- 3. Create the Custom Loss Function ---
# Fix pad_token_id ahead of time so the loss keeps the (model, batch) call
# shape that the training loop (train_flax_lm) expects.
loss_fn_with_padding = partial(
    transformer_loss_fn_from_xy,
    pad_token_id=PAD_TOKEN_ID,
)

In [19]:
from krk_ml_utils import datasets

# Build a loader class configured with our pad value and length limits
# (presumably used to pad/truncate each batch — see krk_ml_utils.datasets).
# NOTE: this rebinds the name JaxNLPDataLoader imported earlier from
# krk_ml_utils.datasets; the test loader below uses this configured version.
MAX_LEN_TARGETS = 30
MAX_LEN_FEATURES = 1000
JaxNLPDataLoader = datasets.create_jax_nlp_dataloader(
    pad_value=PAD_TOKEN_ID,
    max_len_targets=MAX_LEN_TARGETS,
    max_len_features=MAX_LEN_FEATURES,
)

# --- 4. Load the Dataset ---
TRAIN_NPZ_PATH = "./house-prices-advanced-regression-techniques/train_rlm.npz"
train_ds = datasets.NumpyDataset(
    file_path=TRAIN_NPZ_PATH,
    features_key="x",
    labels_key="y",
    rngs=None,
    allow_pickle=True,  # pickle-enabled .npz load: only use with trusted files
    preload=True,
)

train_loader = JaxNLPDataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

In [20]:
# Held-out split: loaded exactly like the training data, but not shuffled,
# so evaluation batches come in the same order every run.
TEST_NPZ_PATH = "./house-prices-advanced-regression-techniques/test_rlm.npz"
test_ds = datasets.NumpyDataset(
    file_path=TEST_NPZ_PATH,
    features_key="x",
    labels_key="y",
    rngs=None,
    allow_pickle=True,  # pickle-enabled .npz load: only use with trusted files
    preload=True,
)
test_loader = JaxNLPDataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [21]:
# Total number of training batches over the whole run (epochs x batches/epoch).
# NOTE(review): this counts batches, not optimizer updates; with gradient
# accumulation in the training cell the optimizer steps fewer times.
total_train_batches = NUM_EPOCHS * len(train_loader)
print(total_train_batches)

7300


In [22]:
### Learning-rate schedules

# Alternative schedule, currently UNUSED (the optimizer below uses
# paper_lr_schedule): linear warmup to D_MODEL**-0.5, then cosine decay to 0.
# This is not the inverse-square-root schedule from "Attention Is All You Need".
# NOTE(review): decay_steps is counted in batches, while optax advances a
# schedule once per optimizer update — with gradient accumulation these differ.
cosine_decay_steps = NUM_EPOCHS * len(train_loader)
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=D_MODEL**-0.5,
    warmup_steps=WARMUP_STEPS,
    decay_steps=cosine_decay_steps,
    end_value=0.0,
)

# Direct implementation of the schedule from "Attention Is All You Need":
#   lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
def paper_lr_schedule(step: int, d_model: "int | None" = None, warmup_steps: "int | None" = None):
    """Inverse-square-root learning-rate schedule with linear warmup (Vaswani et al., 2017).

    The rate grows linearly for ``warmup_steps`` steps, peaks at
    ``(d_model * warmup_steps) ** -0.5``, then decays as ``step ** -0.5``.

    Args:
        step: Schedule step, either a Python int or a scalar JAX array (optax
            passes its update count). Values below 1 are clamped to 1.
        d_model: Model width used for the ``d_model ** -0.5`` scale. Defaults to
            the notebook-level ``D_MODEL``.
        warmup_steps: Length of the linear warmup. Defaults to the
            notebook-level ``WARMUP_STEPS``.

    Returns:
        Scalar float JAX array with the learning rate for ``step``.
    """
    if d_model is None:
        d_model = D_MODEL
    if warmup_steps is None:
        warmup_steps = WARMUP_STEPS
    # Work in floating point so the negative fractional powers are well-defined,
    # and clamp step 0 (optax's first update) to 1 to avoid 0 ** -0.5.
    step = jnp.maximum(jnp.asarray(step, dtype=jnp.float32), 1.0)
    decay = step ** -0.5
    warmup = step * (warmup_steps ** -1.5)
    return (d_model ** -0.5) * jnp.minimum(decay, warmup)

# Adam with the paper's beta1/beta2/epsilon, driven by the inverse-square-root
# warmup schedule paper_lr_schedule defined above.
adam_tx = optax.adam(
    learning_rate=paper_lr_schedule,
    b1=ADAM_B1,
    b2=ADAM_B2,
    eps=ADAM_EPS,
)
optimizer = nnx.Optimizer(model, adam_tx)

In [23]:
%%javascript
// Browser-side keep-alive: every 30 s, log a counter and click Colab's
// connect button, presumably so the session is not dropped while idle.
// NOTE(review): relies on Colab's internal <colab-connect-button> element and
// stacks an extra interval each time the cell is re-run.
let counter = 0;
setInterval(() => {
  console.log("Background JS running:", counter++);
  // querySelector returns null when the element is absent or renamed; skip the
  // click instead of throwing a TypeError on every tick.
  const connectButton = document.querySelector("colab-connect-button");
  if (connectButton) {
    connectButton.click();
  }
}, 30000);  // Clicks every 30 seconds

<IPython.core.display.Javascript object>

In [24]:
from krk_ml_utils.training_v4 import train_flax_lm
checkpoint_dir = "./rlm_checkpoints_housing_data"

# Logging, evaluation and checkpointing share one interval, measured in training
# steps (the checkpoint names in the log, e.g. epoch_0001_step_00000073, suggest
# one step per batch). Gradients are accumulated over GRAD_ACCUMULATION_STEPS.
EVAL_LOG_SAVE_EVERY = 100
GRAD_ACCUMULATION_STEPS = 30

# --- 5. Start the Training Run ---
# Continues from an existing checkpoint in checkpoint_dir if the trainer finds one.
print("Starting training...")
trained_model, history = train_flax_lm(
    model=model,
    optimizer=optimizer,
    metrics=metrics,
    loss_fn=loss_fn_with_padding,
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    num_epochs=NUM_EPOCHS,
    accumulation_steps=GRAD_ACCUMULATION_STEPS,
    checkpoint_dir=checkpoint_dir,
    resume_from_checkpoint=True,
    save_every_steps=EVAL_LOG_SAVE_EVERY,
    log_train_metrics_every_steps=EVAL_LOG_SAVE_EVERY,
    eval_every_steps=EVAL_LOG_SAVE_EVERY,
)

print("Training finished!")

Starting training...
Using device mesh with 1 devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
Gradient accumulator will be initialized on first batch.
No checkpoint found, starting training from scratch.
Starting training from epoch 1 to 100
Gradient accumulator initialized with first batch gradients.
Applying leftover gradients before end-of-epoch save...
Checkpoint bundle saved to ./rlm_checkpoints_housing_data/epoch_0001_step_00000073
2025-07-10 03:09:44 | Step 100     | Epoch 2    | Train Loss: 5.1987
2025-07-10 03:11:03 | ** EVAL at Step 100     | Epoch 2    | Test Loss: 5.0440 **
Checkpoint bundle saved to ./rlm_checkpoints_housing_data/epoch_0001_step_00000100
Applying leftover gradients before end-of-epoch save...
Checkpoint bundle saved to ./rlm_checkpoints_housing_data/epoch_0002_step_00000146
2025-07-10 03:11:36 | Step 200     | Epoch 3    | Train Loss: 4.1281
2025-07-10 03:11:37 | ** EVAL at Step 200     | Epoch 3    | Test Loss: 3.8915 **
Checkp

KeyboardInterrupt: 