
Feature/skip gram#33

Merged
modiase merged 18 commits into main from feature/skip-gram
Sep 16, 2025
Conversation

Owner

@modiase modiase commented Jul 31, 2025

This change is Reviewable

Summary by CodeRabbit

  • New Features

    • Full Word2Vec sample: CBOW/Skip‑gram/predict workflows, train/sample CLI, vocabulary utilities, ZIP-based save/load, and executable launchers.
  • Improvements

    • Output-layer backward pass accepts negative samples; layer error messages include layer context; CLI options and messaging standardized to "optimiser".
  • Bug Fixes

    • Corrected optimiser import/usage inconsistencies across training paths.
  • Chores

    • Track .envrc, add typecheck script, add mnist/word2vec launchers, and update development environment configuration.


coderabbitai bot commented Jul 31, 2025

Warning

Rate limit exceeded

@modiase has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 11 minutes and 16 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between b82d25e and 370475c.

📒 Files selected for processing (7)
  • bin/mnist (1 hunks)
  • bin/word2vec (1 hunks)
  • flake.nix (2 hunks)
  • mo_net/model/layer/output.py (1 hunks)
  • mo_net/samples/word2vec/__main__.py (1 hunks)
  • mo_net/samples/word2vec/vocab.py (1 hunks)
  • mo_net/train/trainer/trainer.py (15 hunks)

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Renamed "optimizer" to British "optimiser" across CLI, trainer, optimiser modules and samples; added Word2Vec samples and vocab utilities; introduced EmbeddingWeightDecayRegulariser; changed Model.backward_prop to return D[Activations]; extended output layer backward to accept negative labels; minor layer error message and tooling updates.

Changes

Cohort / File(s) Change Summary
Optimiser naming
mo_net/cli.py, mo_net/scripts/validation.py, mo_net/samples/mnist/cnn.py, mo_net/optimiser/adam.py, mo_net/optimiser/rmsprop.py, mo_net/train/trainer/trainer.py, mo_net/train/trainer/parallel.py
British spelling applied: optimizer → optimiser across imports, function/parameter names, attributes, CLI flags, help text, and internal references; import paths adjusted to mo_net.optimiser.
Word2Vec — models & CLI
mo_net/samples/word2vec/__main__.py, bin/word2vec
Added CBOWModel, SkipGramModel, PredictModel, new CLI commands (train, sample), training pipeline (TrainingRun/SqliteBackend/BasicTrainer), ZIP artifact I/O, and a CLI launcher script.
Vocabulary module
mo_net/samples/word2vec/vocab.py
New Vocab dataclass and utilities: tokenization, serialization/deserialization, English sentence retrieval, stop words, and get_training_set producing JAX arrays for contexts/targets.
Regulariser: embedding weight decay
mo_net/regulariser/weight_decay.py
Added HasEmbeddingLayer protocol and EmbeddingWeightDecayRegulariser implementing TrainingStepHandler; attaches to optimiser and applies embedding-specific weight decay; updated WeightDecayRegulariser.attach parameter name to optimiser.
Model / layer API updates
mo_net/model/base.py, mo_net/model/layer/output.py, mo_net/model/layer/base.py
ModelBase.backward_prop signature changed to return D[Activations]; SparseCategoricalSoftmaxOutputLayer gained backward_prop_with_negative and optional Y_negative handling with shape validation/broadcasting; base layer forward error now includes self in message.
Trainer internals
mo_net/train/trainer/trainer.py, mo_net/train/trainer/parallel.py
Trainer updated to handle SIGINT with interactive prompt; internal and public API/parameters renamed to use optimiser; internal attribute _optimizer renamed to _optimiser and callers updated.
Samples / scripts / tooling
mo_net/samples/mnist/cnn.py, mo_net/samples/word2vec/*, .envrc, .gitignore, bin/typecheck, bin/mnist, flake.nix
Sample code updated to use get_optimiser; new sample tooling and launcher scripts added; .envrc tracked (removed from .gitignore); flake.nix adjusted for NVIDIA X11 and LD_LIBRARY_PATH; new bin/typecheck, bin/mnist.

Sequence Diagram(s)

High-level Word2Vec training (CBOW / SkipGram)

sequenceDiagram
    participant CLI
    participant Vocab
    participant Model as "Model (CBOW / SkipGram)"
    participant Trainer
    participant Optimiser as "Optimiser"

    CLI->>Vocab: build or load vocab & tokenized data
    CLI->>Model: create or load model (cbow/skipgram)
    CLI->>Optimiser: get_optimiser(...)
    CLI->>Trainer: start trainer with model + optimiser
    loop each epoch
        Trainer->>Model: forward_pass(X)
        Model-->>Trainer: outputs / loss
        Trainer->>Model: backward_prop(Y_true[, Y_negative])
        Model-->>Trainer: gradients (D[Activations])
        Trainer->>Optimiser: compute_update(gradients)
        Optimiser->>Model: apply_update
    end

Embedding weight-decay handler attachment

sequenceDiagram
    participant Model
    participant EmbeddingReg as "EmbeddingWeightDecayRegulariser"
    participant Optimiser as "Optimiser"

    Model->>EmbeddingReg: instantiate (embedding layer)
    EmbeddingReg->>Optimiser: attach(after_compute_update handler)
    loop training steps
        Optimiser->>EmbeddingReg: after_compute_update(learning_rate)
        EmbeddingReg->>Model: modify embedding gradients / compute reg loss
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Poem

🐇
I hopped through lines where spellings changed their tune,
"Optimiser" now hums beneath the coding moon.
CBOW and SkipGram sprout with vocab tied so neat,
Embeddings sigh and shrink with each soft decay beat.
A rabbit claps — the update lands light on quick feet.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 18.87%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title "Feature/skip gram" relates directly to the primary change in the changeset (adding skip-gram/Word2Vec functionality, e.g. SkipGramModel and CLI/sample code), so it conveys the main intent of the PR succinctly, though it includes a branch-style prefix and minor formatting inconsistencies.


@modiase modiase force-pushed the feature/skip-gram branch from 009fe24 to 7d329e6 Compare July 31, 2025 06:20
modiase added 2 commits July 31, 2025 07:20
Seems like a silly idea. It's always just going to produce nonsense
cursor[bot]

This comment was marked as outdated.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
mo_net/samples/word2vec/__main__.py (1)

543-543: Verify model type compatibility in sample command.

The sample command explicitly loads a CBOWModel, which won't work for Skip-gram models. This should either support both model types or clearly indicate it only works with CBOW.

-    model = CBOWModel.load(model_path, training=False)
+    # Try to determine model type from the saved model
+    model = Model.load(model_path, training=False)
+    if not isinstance(model, (CBOWModel, SkipGramModel)):
+        raise click.ClickException(f"Model must be either CBOW or Skip-gram, got {type(model).__name__}")

Would you like me to implement a more robust model type detection mechanism that could store the model type in metadata alongside the vocabulary?

🧹 Nitpick comments (5)
.gitignore (1)

40-40: Scope the DB ignore to avoid surprises.

As written, this ignores any train.db at any depth. Prefer anchoring to the repo root, or ignore all .db files if multiple DBs may appear.

Root-anchored (conservative):

-train.db
+/train.db

Alternative (broader):

-train.db
+*.db
.envrc (1)

1-2: Optionally load a local .env for parity with app configs.

This helps devs who keep secrets/config in a .env (already gitignored).

 PATH_add bin
 source_env_if_exists .envrc.local
+dotenv_if_exists .env
mo_net/samples/word2vec/__main__.py (3)

125-125: Remove unused noqa directive.

The # noqa: F821 comment is unnecessary since HiddenLayer is properly imported on line 20.

-        hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+        hidden: Sequence[Hidden | HiddenLayer],

465-468: Consider more descriptive variable naming for reshaping operation.

The reshaping logic for Skip-gram is correct, but the variable names could be clearer about what's being reshaped and why.

 if model_type == "skipgram":
+    # Reshape to (batch_size, 1) for single word input in Skip-gram
     X_train_split = X_train_split.reshape(-1, 1)
     X_val = X_val.reshape(-1, 1)

507-508: Clean up unnecessary variable renaming.

The variable never_2 appears to be an artifact of automatic refactoring. It should be simplified to just never for consistency.

-        case never_2:
-            assert_never(never_2)
+        case never:
+            assert_never(never)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3a5d868 and 3068dab.

📒 Files selected for processing (4)
  • .envrc (1 hunks)
  • .gitignore (1 hunks)
  • bin/typecheck (1 hunks)
  • mo_net/samples/word2vec/__main__.py (11 hunks)
✅ Files skipped from review due to trivial changes (1)
  • bin/typecheck
🧰 Additional context used
🧬 Code graph analysis (1)
mo_net/samples/word2vec/__main__.py (9)
mo_net/__init__.py (1)
  • print_device_info (139-145)
mo_net/functions.py (1)
  • sparse_cross_entropy (20-23)
mo_net/model/layer/embedding.py (2)
  • Embedding (117-307)
  • vocab_size (306-307)
mo_net/model/model.py (6)
  • output (435-436)
  • Model (53-436)
  • loss_contributors (269-270)
  • output_module (290-291)
  • hidden_modules (276-277)
  • input_layer (286-287)
mo_net/model/layer/output.py (4)
  • OutputLayer (22-52)
  • SparseCategoricalSoftmaxOutputLayer (125-181)
  • backward_prop (41-42)
  • backward_prop_with_negative (151-157)
mo_net/regulariser/weight_decay.py (4)
  • EmbeddingWeightDecayRegulariser (71-112)
  • embedding_layer (68-68)
  • attach (47-63)
  • attach (101-112)
mo_net/samples/word2vec/vocab.py (3)
  • Vocab (26-121)
  • get_training_set (124-143)
  • english_sentences (115-121)
mo_net/train/trainer/trainer.py (3)
  • train (209-266)
  • BasicTrainer (98-401)
  • get_optimiser (51-89)
mo_net/config.py (1)
  • TrainingParameters (8-51)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py

125-125: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


383-383: Unused function argument: device

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Cursor Bugbot
  • GitHub Check: typecheck
🔇 Additional comments (6)
.envrc (2)

1-2: LGTM: minimal, safe direnv setup.

PATH_add bin and source_env_if_exists are standard and keep local overrides out of VCS.


1-2: Confirmed: bin/typecheck is present and executable.
Found bin/typecheck with permissions -rwxr-xr-x — no change required.

mo_net/samples/word2vec/__main__.py (4)

121-195: LGTM! Well-structured Skip-gram model implementation.

The SkipGramModel class is properly implemented with appropriate initialization, factory method, and configuration. The negative sampling approach is correctly integrated.


204-225: Solid negative sampling implementation in backward propagation.

The backward propagation correctly implements negative sampling using the specialized backward_prop_with_negative method. The random key management and gradient flow through layers are properly handled.


471-476: Good use of updated optimiser naming convention.

The refactoring from "optimizer" to "optimiser" is consistently applied throughout the code, maintaining British English spelling consistency.


401-406: Skip-gram swap verified — no change required.
get_training_set(...) returns (context, target) as jnp arrays (context has shape (N, 2*context_size), target has shape (N,)). Initially X_train=context, Y_train=target; the swap makes X_train=target and Y_train=context, which matches Skip‑gram (input=target → predict context).
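The swap the comment verifies can be sketched with toy data (shapes and values here are assumptions for illustration, not the project's actual arrays):

```python
import jax.numpy as jnp

# Assumed toy shapes: context (N, 2*context_size), target (N,)
context = jnp.array([[1, 2], [3, 4]])
target = jnp.array([0, 5])

# CBOW: predict the target word from its context
X_train, Y_train = context, target

# Skip-gram: swap, so the input is the target word and the labels are its context
X_train, Y_train = Y_train, X_train
```

After the swap, each training example feeds a single centre word in and scores its context words as labels, which is exactly the inversion the comment describes.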


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
mo_net/model/layer/output.py (1)

8-8: Import bug: abstractmethod imported from pyparsing

abstractmethod must come from abc, not pyparsing. This will raise at import time.

-from pyparsing import abstractmethod
+from abc import abstractmethod
♻️ Duplicate comments (3)
mo_net/model/layer/output.py (1)

171-176: Fix for prior negative-sample indexing bug looks correct

Using repeat(arange(batch), num_neg) for rows and a flattened Y_negative for columns addresses the earlier out‑of‑bounds issue when negative_samples > 1. Good job.

Also applies to: 188-191
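The indexing pattern praised here can be sketched in isolation (toy shapes assumed). Note that JAX's `.at[...].add` accumulates contributions for duplicate indices, which is what makes the flattened row/column form safe:

```python
import jax.numpy as jnp

# Assumed toy dimensions for illustration
batch_size, num_neg, vocab = 2, 3, 5
grad = jnp.zeros((batch_size, vocab))
Y_negative = jnp.array([[1, 2, 3], [0, 4, 4]])  # (batch_size, num_neg)

# One row index per negative sample, columns from the flattened negatives
row_idx = jnp.repeat(jnp.arange(batch_size), num_neg)  # [0, 0, 0, 1, 1, 1]
col_idx = Y_negative.reshape(-1)                       # [1, 2, 3, 0, 4, 4]

# Scatter-add: duplicate (row, col) pairs accumulate, e.g. (1, 4) appears twice
grad = grad.at[row_idx, col_idx].add(1.0)
```

With `repeat` rather than a plain `arange`, every negative sample lands in its own example's row, which is the out-of-bounds fix the comment confirms.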

mo_net/samples/word2vec/__main__.py (2)

383-383: Unused device parameter still needs addressing.

This is a duplicate of a previous review comment. The device parameter is still unused in the function implementation.


469-469: Hardcoded run name still doesn't reflect model type.

This is a duplicate of a previous review comment. The run name is still hardcoded to "cbow_run_" regardless of the actual model type being trained.

🧹 Nitpick comments (5)
mo_net/model/layer/output.py (4)

165-167: Shorten error message to satisfy TRY003 (Ruff)

Message is flagged as “long message outside exception class.” A concise message is sufficient here.

-        if (output_activations := self._cache["output_activations"]) is None:
-            raise ValueError("Output activations not set during forward pass.")
+        if (output_activations := self._cache["output_activations"]) is None:
+            raise ValueError("Forward pass required.")

168-170: Unnecessary copy; .at[...] is functional already

jnp.ndarray.at returns a new array; copying first just adds an allocation.

-        result = output_activations.copy()
-        result = result.at[jnp.arange(Y_true.shape[0]), Y_true].add(-1.0)
+        batch_size = output_activations.shape[0]
+        result = output_activations.at[jnp.arange(batch_size), Y_true].add(-1.0)

171-193: Minor shape/dtype validations and TRY003 nits

  • Consider validating Y_true.ndim == 1 (sparse targets) to fail fast.
  • Add a brief dtype check to ensure Y_negative is integer to avoid cryptic JAX errors.
  • Shorten two error messages to satisfy Ruff TRY003.
         if Y_negative is not None:
+            # Validate dtypes/shapes early
+            if Y_true.ndim != 1:
+                raise ValueError("Y_true must be 1D indices.")
+            if not jnp.issubdtype(Y_negative.dtype, jnp.integer):
+                raise ValueError("Y_negative dtype must be integer.")
             if Y_negative.ndim == 2:
                 batch_size, num_neg = Y_negative.shape
                 row_idx = jnp.repeat(jnp.arange(batch_size), num_neg)
                 col_idx = Y_negative.reshape(-1)
                 result = result.at[row_idx, col_idx].add(1.0)
             elif Y_negative.ndim == 1:
                 # - (batch_size,) for a single negative per example
-                batch_size = result.shape[0]
+                batch_size = result.shape[0]
                 total = Y_negative.shape[0]
                 if total == batch_size:
                     result = result.at[jnp.arange(batch_size), Y_negative].add(1.0)
                 else:
                     if total % batch_size != 0:
-                        raise ValueError(
-                            "Y_negative length must be a multiple of batch size when flattened."
-                        )
+                        raise ValueError("Invalid Y_negative length (must be multiple of batch size).")
                     num_neg = total // batch_size
                     row_idx = jnp.repeat(jnp.arange(batch_size), num_neg)
                     result = result.at[row_idx, Y_negative].add(1.0)
             else:
-                raise ValueError("Y_negative must be 1D or 2D array of indices.")
+                raise ValueError("Y_negative must be 1D or 2D.")

194-194: Cast is fine; consider dropping atleast_1d

result is already 2D; jnp.atleast_1d is a no‑op. You can return result directly and keep the cast.

-        return cast(D[Activations], jnp.atleast_1d(result))
+        return cast(D[Activations], result)
mo_net/samples/word2vec/__main__.py (1)

125-125: Remove unused noqa directive.

The # noqa: F821 comment is no longer needed since the type annotation issue has been resolved.

-        hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+        hidden: Sequence[Hidden | HiddenLayer],
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3068dab and 5943931.

📒 Files selected for processing (2)
  • mo_net/model/layer/output.py (1 hunks)
  • mo_net/samples/word2vec/__main__.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
mo_net/model/layer/output.py (1)
mo_net/protos.py (1)
  • D (18-26)
mo_net/samples/word2vec/__main__.py (7)
mo_net/__init__.py (1)
  • print_device_info (139-145)
mo_net/functions.py (1)
  • sparse_cross_entropy (20-23)
mo_net/model/model.py (4)
  • output (435-436)
  • Model (53-436)
  • loss_contributors (269-270)
  • output_module (290-291)
mo_net/model/layer/output.py (4)
  • OutputLayer (22-52)
  • SparseCategoricalSoftmaxOutputLayer (125-201)
  • backward_prop (41-42)
  • backward_prop_with_negative (151-157)
mo_net/regulariser/weight_decay.py (4)
  • EmbeddingWeightDecayRegulariser (71-112)
  • embedding_layer (68-68)
  • attach (47-63)
  • attach (101-112)
mo_net/samples/word2vec/vocab.py (3)
  • Vocab (26-121)
  • get_training_set (124-143)
  • english_sentences (115-121)
mo_net/train/trainer/trainer.py (2)
  • train (209-266)
  • get_optimiser (51-89)
🪛 Ruff (0.12.2)
mo_net/model/layer/output.py

166-166: Avoid specifying long messages outside the exception class

(TRY003)


185-187: Avoid specifying long messages outside the exception class

(TRY003)


192-192: Avoid specifying long messages outside the exception class

(TRY003)

mo_net/samples/word2vec/__main__.py

125-125: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


383-383: Unused function argument: device

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (8)
mo_net/model/layer/output.py (1)

151-158: API addition looks good; keep it discoverable

backward_prop_with_negative simply forwards to _backward_prop and is exercised at mo_net/samples/word2vec/__main__.py:210 — no further changes required.

mo_net/samples/word2vec/__main__.py (7)

7-7: LGTM! Good addition of type imports for SkipGram implementation.

The additional imports (Literal, assert_never, cast) are appropriately used throughout the new SkipGram implementation.


15-15: LGTM! Import updates align with the vocabulary refactoring and new features.

The import changes correctly reflect:

  • Moving vocabulary utilities to external modules
  • Adding support for output layers with negative sampling
  • Updating to use the modernized "optimiser" spelling
  • Adding necessary type imports for the new SkipGram model

Also applies to: 17-17, 20-20, 23-23, 26-28


121-233: LGTM! Well-implemented SkipGram model with proper negative sampling.

The SkipGramModel implementation correctly:

  • Uses single-word input (input_dimensions=(1,)) vs CBOW's context window
  • Implements negative sampling in backward_prop with proper tensor reshaping
  • Overrides compute_loss to handle flattened targets appropriately
  • Maintains the same embedding layer interface for consistency

The negative sampling logic properly generates random negative samples and passes them to the output layer's backward_prop_with_negative method.


279-369: LGTM! Training options correctly expanded for SkipGram support.

The CLI options have been appropriately updated to support both CBOW and SkipGram models:

  • Added --model-type with choices for "cbow" and "skipgram"
  • Added --negative-samples for SkipGram configuration
  • Added --model-path for loading existing models
  • Updated parameter names and defaults appropriately

377-509: LGTM! Train function properly handles both model types.

The training function correctly:

  • Uses external vocabulary utilities from vocab.py
  • Swaps X_train/Y_train for SkipGram models (lines 404-405)
  • Creates appropriate model instances based on model_type
  • Reshapes input data for SkipGram (lines 465-467)
  • Uses modernized "optimiser" spelling consistently

The model loading and creation logic is sound, and the data preparation handles the different input/output requirements of CBOW vs SkipGram appropriately.


470-476: LGTM! Proper optimiser modernization and type casting.

The code correctly:

  • Uses the modernized get_optimiser function name
  • Casts the model to the union type for the regulariser attachment
  • Maintains backward compatibility with existing CBOW models

401-401: Vocab.english_sentences integration verified — no changes required. The classmethod exists at mo_net/samples/word2vec/vocab.py, accepts max_vocab_size as a keyword, and delegates to get_english_sentences(limit) → from_sentences(...), which builds the most-common-token vocab and returns tokenized sentences; get_training_set returns (jnp.ndarray, jnp.ndarray).

- Update import path for check_device and DeviceType in word2vec
@modiase modiase force-pushed the feature/skip-gram branch 2 times, most recently from 9cbb730 to 1e515cc Compare September 16, 2025 15:05
Owner Author

@modiase modiase left a comment


@modiase reviewed 12 of 18 files at r1, 4 of 4 files at r2, 5 of 11 files at r3, 11 of 11 files at r5, 12 of 12 files at r7, all commit messages.
Reviewable status: :shipit: complete! all files reviewed, all discussions resolved (waiting on @modiase)

cursor[bot]

This comment was marked as outdated.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
mo_net/model/layer/output.py (1)

8-8: Blocking: wrong abstractmethod import.

This module imports abstractmethod from pyparsing, which will fail. It should come from abc.

Use:

from abc import abstractmethod
♻️ Duplicate comments (3)
flake.nix (1)

100-103: Remove hard‑coded Nix store driver path from LD_LIBRARY_PATH; use derivation path instead.

The pinned store path is non‑portable and will rot on driver/kernel updates. This was already flagged earlier. Use nvidia_x11’s derivation path and gate to Linux.

-            # Set up library paths with NVIDIA driver libraries prioritized for JAX
-            # Use the specific NVIDIA driver version that matches the kernel (570.153.02)
-            export LD_LIBRARY_PATH="/nix/store/6rf8qnkp4m5h5x9byf5srphcdjdp5r5j-nvidia-x11-570.153.02-6.12.35/lib:${pkgs.lib.makeLibraryPath (systemLibs ++ cudaLibs)}:''${LD_LIBRARY_PATH:-}"
+            # Set up library paths with NVIDIA driver libraries prioritized for JAX (Linux only)
+            export LD_LIBRARY_PATH="${
+              pkgs.lib.optionalString pkgs.stdenv.isLinux "${pkgs.lib.makeLibraryPath [ pkgs.linuxPackages.nvidia_x11 ]}:"
+            }${pkgs.lib.makeLibraryPath (systemLibs ++ cudaLibs)}:''${LD_LIBRARY_PATH:-}"
mo_net/model/layer/output.py (1)

171-176: Previous negative-sample indexing bug is fixed.

Using jnp.repeat(jnp.arange(batch_size), num_neg) resolves the earlier out‑of‑bounds issue when negatives > 1. Nice.

mo_net/samples/word2vec/__main__.py (1)

462-471: LGTM!

Good updates to use the model type in the training run name and proper casting for the regularizer attachment.

Note: This addresses the previous review comment about the hardcoded "cbow_run_" name by using the dynamic model_type variable.

🧹 Nitpick comments (2)
mo_net/model/layer/output.py (2)

159-165: Minor: shorten error strings (Ruff TRY003) or silence per-line.

If you don’t want custom exception types, keep messages terse and add “# noqa: TRY003” where needed (as shown in the diff above).


159-165: Fix signature mismatch in OutputLayer._backward_prop

Abstract OutputLayer._backward_prop is declared without Y_negative (mo_net/model/layer/output.py:45) but the concrete override adds Y_negative (mo_net/model/layer/output.py:159) — static type checkers (mypy/pyright) will flag this as an incompatible override.

Options:

  • Preferred: keep the abstract signature unchanged and add a private helper (e.g., _backward_prop_with_negative) used by the method that accepts Y_negative.
  • Alternative: widen the abstract signature to accept Y_negative: jnp.ndarray | None = None and update all overrides.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5943931 and 05b8835.

📒 Files selected for processing (3)
  • flake.nix (2 hunks)
  • mo_net/model/layer/output.py (1 hunks)
  • mo_net/samples/word2vec/__main__.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
mo_net/samples/word2vec/__main__.py (6)
mo_net/model/layer/embedding.py (2)
  • Embedding (117-307)
  • vocab_size (306-307)
mo_net/model/model.py (5)
  • output (435-436)
  • Model (53-436)
  • loss_contributors (269-270)
  • output_module (290-291)
  • hidden_modules (276-277)
mo_net/model/layer/output.py (4)
  • OutputLayer (22-52)
  • SparseCategoricalSoftmaxOutputLayer (125-201)
  • backward_prop (41-42)
  • backward_prop_with_negative (151-157)
mo_net/regulariser/weight_decay.py (4)
  • EmbeddingWeightDecayRegulariser (71-112)
  • embedding_layer (68-68)
  • attach (47-63)
  • attach (101-112)
mo_net/samples/word2vec/vocab.py (3)
  • Vocab (26-121)
  • get_training_set (124-143)
  • english_sentences (115-121)
mo_net/train/trainer/trainer.py (2)
  • train (209-266)
  • get_optimiser (51-89)
mo_net/model/layer/output.py (2)
mo_net/protos.py (1)
  • D (18-26)
mo_net/model/layer/base.py (1)
  • _backward_prop (107-111)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py

125-125: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)

mo_net/model/layer/output.py

166-166: Avoid specifying long messages outside the exception class

(TRY003)


185-187: Avoid specifying long messages outside the exception class

(TRY003)


192-192: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: typecheck
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (16)
mo_net/model/layer/output.py (1)

151-158: Document shapes for backward_prop_with_negative

Add a concise docstring on mo_net/model/layer/output.py::backward_prop_with_negative describing expected shapes:

  • Y_true: jnp.ndarray[int], shape (N,) — N = number of targets (e.g. batch or batch*context when flattened).
  • Y_negative: jnp.ndarray[int], shape (N, K) — K negative samples per target (or documented as flattened (N*K,) if callers use that).

Caller evidence: mo_net/samples/word2vec/__main__.py:210 calls backward_prop_with_negative(Y_true=Y_true.flatten(), Y_negative=jax.random.choice(...)). Keep the original backward_prop for existing callers.
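A hedged sketch of the docstring the comment asks for; the standalone function signature here is illustrative only (the real method lives on SparseCategoricalSoftmaxOutputLayer):

```python
import jax.numpy as jnp


def backward_prop_with_negative(
    Y_true: jnp.ndarray, Y_negative: jnp.ndarray
) -> jnp.ndarray:
    """Backward pass with negative sampling.

    Args:
        Y_true: integer indices, shape (N,), where N is the number of
            targets (batch size, or batch * context when flattened).
        Y_negative: integer indices, shape (N, K) for K negative samples
            per target, or flattened to (N * K,) by the caller.

    Returns:
        Gradient of the loss w.r.t. the pre-activation, shape (N, vocab).
    """
    raise NotImplementedError  # placeholder body for the sketch
```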

mo_net/samples/word2vec/__main__.py (15)

7-7: LGTM!

Good addition of the assert_never import to handle exhaustive pattern matching for the new model type selection.


17-17: LGTM!

Correct import of LossFn type for the new compute_loss method signature.


20-20: LGTM!

Smart aliasing of Hidden as HiddenLayer for compatibility with the new SkipGramModel constructor.


23-23: LGTM!

Good import addition of OutputLayer for the new backward propagation functionality.


26-28: LGTM!

Excellent modularization by extracting vocabulary utilities to a separate module and importing the regularizer for weight decay functionality.


36-36: LGTM!

Good update to use the standardized British spelling of "optimiser" for consistency across the codebase.


279-363: LGTM!

The CLI option additions are well-structured and comprehensive. The new options for --model-type, --negative-samples, and other parameters properly support both CBOW and Skip-gram training workflows.


371-388: LGTM!

The function signature properly incorporates all the new parameters needed for both model types, including the model_type discriminator and negative_samples for skip-gram training.


394-398: LGTM!

Smart data preparation logic that correctly swaps X and Y for skip-gram training. This aligns with the skip-gram model where you predict context words from a center word, opposite to CBOW.


409-430: LGTM!

Excellent model creation logic using pattern matching to handle both model types. The parameters are correctly passed to each model's factory method.


458-461: LGTM!

Correct data reshaping for skip-gram input format, ensuring the input has the expected shape of (-1, 1) for single-word inputs.


479-479: LGTM!

Consistent use of British spelling "optimiser" throughout the training setup.


486-488: LGTM!

Good logging message that reflects the actual model type being trained.


502-503: LGTM!

Proper use of assert_never with a unique variable name to handle exhaustive pattern matching.


121-234: Approve — confirm removal of noqa

  • Remove the unused # noqa: F821 on Line 125 only after confirming Hidden / HiddenLayer are defined or imported (quick repo search didn’t locate their definitions).
  • Negative sampling: Verified — negatives are sampled as batch_size * context_size * negative_samples (k negatives per positive) and passed to the output layer; aligns with expected skip‑gram negative sampling.
  • Model architecture: Verified — input_dimensions=(1,) is correct for skip‑gram; training flips X/Y for skipgram and reshapes inputs to (-1, 1).
-        hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+        hidden: Sequence[Hidden | HiddenLayer],
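The negative-sampling count verified in the second bullet (`batch_size * context_size * negative_samples`) can be checked with a quick numpy stand-in for the `jax.random.choice` call (all sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, context_size, negative_samples, vocab_size = 4, 2, 5, 100

# k negatives per positive: one positive per (example, context position) pair.
negatives = rng.integers(0, vocab_size, size=batch_size * context_size * negative_samples)

assert negatives.shape == (batch_size * context_size * negative_samples,)  # (40,)
```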

@modiase modiase force-pushed the feature/skip-gram branch 3 times, most recently from 2577878 to 0315e1f on September 16, 2025 at 16:13
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (6)
mo_net/samples/word2vec/__main__.py (6)

130-139: Remove unnecessary noqa and tighten type hint.

# noqa: F821 is unnecessary because HiddenLayer is imported. Clean it up.

-        hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+        hidden: Sequence[Hidden | HiddenLayer],

417-435: Unify tokenized_sentences type across branches to satisfy the type checker.

CI error at Line 417 shows mismatch (Collection[Sequence[int]] vs list[list[int]]). The annotation in the previous comment resolves it. If you prefer a concrete type, convert Vocab.english_sentences(...)[1] to list[list[int]].

-        vocab, tokenized_sentences = Vocab.english_sentences(max_vocab_size=vocab_size)
+        vocab, tokenized = Vocab.english_sentences(max_vocab_size=vocab_size)
+        tokenized_sentences: Collection[Sequence[int]] = tokenized

480-486: Regulariser attach relies on subclass typing.

This will fail when loading a base Model. After applying the subclass‑aware load fix above, this cast is correct. Consider annotating model: CBOWModel | SkipGramModel when created/loaded to help type checkers.

-    optimiser = get_optimiser("adam", model, training_parameters)
+    model = cast(CBOWModel | SkipGramModel, model)
+    optimiser = get_optimiser("adam", model, training_parameters)

564-586: Exception chaining and logging fix in sample; also satisfy type checker.

Chain exceptions to preserve cause (B904) and declare the union type for model to avoid the mypy error at Line 576. Also avoid passing e to logger.exception unnecessarily.

-    try:
+    try:
         with zipfile.ZipFile(model_path, "r") as zf:
             vocab_bytes = zf.read(VOCAB_ZIP_INTERNAL_PATH)
             with zf.open(METADATA_ZIP_INTERNAL_PATH) as md:
                 metadata = json.loads(md.read().decode("utf-8"))
                 model_type = metadata.get("type", "cbow")
 
-            with zf.open(MODEL_ZIP_INTERNAL_PATH) as mf:
+            model: CBOWModel | SkipGramModel
+            with zf.open(MODEL_ZIP_INTERNAL_PATH) as mf:
                 match model_type:
                     case "skipgram":
                         model = SkipGramModel.load(mf, training=False)
                     case "cbow":
                         model = CBOWModel.load(mf, training=False)
                     case never:
                         assert_never(never)
     except KeyError as e:
-        raise click.ClickException(
+        raise click.ClickException(
             f"Missing file in zip: {e.args[0]}. Expected {MODEL_ZIP_INTERNAL_PATH}, {VOCAB_ZIP_INTERNAL_PATH} and optionally {METADATA_ZIP_INTERNAL_PATH}"
-        )
+        ) from e
     except zipfile.BadZipFile as e:
-        logger.exception(f"Invalid zip file: {model_path}", e)
-        raise click.ClickException(f"Invalid zip file: {model_path}")
+        logger.exception(f"Invalid zip file: {model_path}")
+        raise click.ClickException(f"Invalid zip file: {model_path}") from e

589-597: Minor efficiency nit in sampling.

list(vocab.vocab) is created repeatedly; cache once for readability (micro‑opt).

-    random_words = [list(vocab.vocab)[int(i)] for i in word_indices]
+    vocab_list = list(vocab.vocab)
+    random_words = [vocab_list[int(i)] for i in word_indices]

217-234: Optional: avoid sampling positives as negatives.

Uniform negatives can include true labels; usually OK, but you can resample to exclude positives to marginally stabilise training.

Would you like a small helper to draw negatives excluding Y_true per row?
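One possible shape for such a helper, sketched in numpy for brevity (the PR uses JAX, where a rejection loop would need a different, trace-friendly formulation):

```python
import numpy as np

def sample_negatives_excluding(
    rng: np.random.Generator, vocab_size: int, positives: np.ndarray, k: int
) -> np.ndarray:
    """Draw k negatives per row, rejecting any draw equal to that row's positive."""
    negatives = rng.integers(0, vocab_size, size=(positives.shape[0], k))
    collisions = negatives == positives[:, None]
    while collisions.any():
        # Redraw only the colliding entries until none matches its positive.
        negatives[collisions] = rng.integers(0, vocab_size, size=int(collisions.sum()))
        collisions = negatives == positives[:, None]
    return negatives

rng = np.random.default_rng(42)
positives = np.array([3, 7, 7, 1])
neg = sample_negatives_excluding(rng, vocab_size=10, positives=positives, k=5)
```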

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2577878 and 63365ac.

📒 Files selected for processing (2)
  • mo_net/samples/word2vec/__main__.py (10 hunks)
  • mo_net/samples/word2vec/vocab.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • mo_net/samples/word2vec/vocab.py
🧰 Additional context used
🧬 Code graph analysis (1)
mo_net/samples/word2vec/__main__.py (4)
mo_net/model/model.py (4)
  • output (435-436)
  • output_module (290-291)
  • hidden_modules (276-277)
  • input_layer (286-287)
mo_net/model/layer/output.py (7)
  • OutputLayer (22-52)
  • SparseCategoricalSoftmaxOutputLayer (125-201)
  • backward_prop (41-42)
  • backward_prop_with_negative (151-157)
  • serialize (52-52)
  • serialize (94-95)
  • serialize (200-201)
mo_net/regulariser/weight_decay.py (4)
  • EmbeddingWeightDecayRegulariser (71-112)
  • embedding_layer (68-68)
  • attach (47-63)
  • attach (101-112)
mo_net/samples/word2vec/vocab.py (6)
  • Vocab (21-129)
  • get_english_sentences (296-305)
  • get_training_set (132-151)
  • from_bytes (119-129)
  • english_sentences (110-116)
  • serialize (27-34)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py

134-134: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


580-582: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


580-582: Avoid specifying long messages outside the exception class

(TRY003)


585-585: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


585-585: Avoid specifying long messages outside the exception class

(TRY003)

🪛 GitHub Actions: CI
mo_net/samples/word2vec/__main__.py

[error] 417-417: uv run task typecheck -v: Incompatible types in assignment (expression has type "Collection[Sequence[int]]", variable has type "list[list[int]]") [assignment]


[error] 576-576: uv run task typecheck -v: Incompatible types in assignment (expression has type "CBOWModel", variable has type "SkipGramModel") [assignment]

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (4)
mo_net/samples/word2vec/__main__.py (4)

473-476: Reshape centers only for Skip‑gram: LGTM.
This ensures X has shape (batch, 1) for the embedding layer. No action needed.


501-503: Nice training log message.
Clear and parameterised by model_type.


509-526: ZIP save flow: LGTM.
Buffered write, deterministic internal paths, and ensuring parent dirs exist are good practices.
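A minimal sketch of that save flow, assuming internal path constants like those in the PR (the model and vocab payloads are placeholder bytes here):

```python
import io
import json
import zipfile
from pathlib import Path

MODEL_ZIP_INTERNAL_PATH = "model.pkl"         # assumed names, mirroring
VOCAB_ZIP_INTERNAL_PATH = "vocab.msgpack"     # the PR's zip-internal
METADATA_ZIP_INTERNAL_PATH = "metadata.json"  # path constants

def save_artifact(
    output_path: Path, model_bytes: bytes, vocab_bytes: bytes, model_type: str
) -> None:
    output_path.parent.mkdir(parents=True, exist_ok=True)  # ensure parent dirs exist
    buf = io.BytesIO()  # buffer the whole archive before touching disk
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr(MODEL_ZIP_INTERNAL_PATH, model_bytes)
        zf.writestr(VOCAB_ZIP_INTERNAL_PATH, vocab_bytes)
        zf.writestr(METADATA_ZIP_INTERNAL_PATH, json.dumps({"type": model_type}))
    output_path.write_bytes(buf.getvalue())
```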


321-349: CLI options: good coverage.
Options for the optimiser, regularisation, negative sampling, and logging are clear. No changes required.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (8)
mo_net/samples/word2vec/vocab.py (3)

63-93: Token normalisation is missing when building the vocab and during tokenisation.

from_sentences counts tokens and maps them without calling clean_token. As a result, punctuation/case variants inflate the vocabulary and “The” won’t be filtered by lower‑case stopwords.

Apply:

-        most_common_tokens = [
-            token
-            for token, _ in Counter(
-                token for sentence in sentences if sentence for token in sentence
-            ).most_common(max_size)
-        ]
+        most_common_tokens = [
+            token
+            for token, _ in Counter(
+                t
+                for sentence in sentences
+                if sentence
+                for raw in sentence
+                if (t := cls.clean_token(raw))
+            ).most_common(max_size)
+        ]
@@
-            [
-                [vocab[token] for token in sentence]
-                for sentence in sentences
-                if sentence
-            ],
+            [
+                [vocab[cls.clean_token(raw)] for raw in sentence if cls.clean_token(raw)]
+                for sentence in sentences
+                if sentence
+            ],

95-101: Clarify length semantics.

__len__ returns len(self.vocab) + 1 (to include the unknown token). This is fine, but it’s surprising. Add a short docstring to avoid misuse where callers iterate range(len(vocab)) expecting only known tokens.

-    def __len__(self) -> int:
-        return len(self.vocab) + 1
+    def __len__(self) -> int:
+        """Vocabulary size including <unknown>."""
+        return len(self.vocab) + 1

132-152: Memory pressure: build context/target lazily or preallocate.

The list comprehension materialises all pairs before zipping; on large corpora this is heavy.

Example generator-based approach:

-    context, target = zip(
-        *[
-            (
-                tuple(
-                    chain(
-                        sentence[i - context_size : i],
-                        sentence[i + 1 : i + context_size + 1],
-                    )
-                ),
-                sentence[i],
-            )
-            for sentence in tokenized_sentences
-            for i in range(context_size, len(sentence) - context_size)
-        ],
-        strict=True,
-    )
+    pairs = (
+        (
+            tuple(
+                chain(
+                    sentence[i - context_size : i],
+                    sentence[i + 1 : i + context_size + 1],
+                )
+            ),
+            sentence[i],
+        )
+        for sentence in tokenized_sentences
+        for i in range(context_size, len(sentence) - context_size)
+    )
+    context, target = zip(*pairs, strict=True)
mo_net/samples/word2vec/__main__.py (5)

134-134: Remove unused # noqa: F821.

HiddenLayer is imported (Line 24), so the noqa is unnecessary and flagged by Ruff.

-        hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+        hidden: Sequence[Hidden | HiddenLayer],

205-212: Ensure scalar loss type.

Depending on loss_fn, compute_loss may return a JAX array; downstream code often expects a Python float.

-        return loss_fn(Y_pred, Y_true.flatten()) + sum(
+        return float(loss_fn(Y_pred, Y_true.flatten())) + sum(
             contributor() for contributor in self.loss_contributors
         )

439-441: Double-check CBOW/Skip-gram train set swap.

The swap is correct for Skip‑gram. Add a brief comment so future edits don’t regress this.

-    if model_type == "skipgram":
+    if model_type == "skipgram":
+        # For Skip-gram, inputs are centers and labels are contexts.
         Y_train, X_train = X_train, Y_train

473-476: Reshape only when needed and keep dtypes.

The reshape for Skip‑gram inputs is required (centers as (B,1)). Confirm X_train.dtype is integer; cast if needed to avoid JAX implicit promotion.

-    if model_type == "skipgram":
-        X_train_split = X_train_split.reshape(-1, 1)
-        X_val = X_val.reshape(-1, 1)
+    if model_type == "skipgram":
+        X_train_split = jnp.asarray(X_train_split, dtype=jnp.int32).reshape(-1, 1)
+        X_val = jnp.asarray(X_val, dtype=jnp.int32).reshape(-1, 1)

580-585: Chain exceptions and fix logger.exception usage.

Use raise ... from e (B904) and don’t pass e as a positional arg to logger.exception (it’s implicit).

-    except KeyError as e:
+    except KeyError as e:
         raise click.ClickException(
             f"Missing file in zip: {e.args[0]}. Expected {MODEL_ZIP_INTERNAL_PATH}, {VOCAB_ZIP_INTERNAL_PATH} and optionally {METADATA_ZIP_INTERNAL_PATH}"
-        )
-    except zipfile.BadZipFile as e:
-        logger.exception(f"Invalid zip file: {model_path}", e)
-        raise click.ClickException(f"Invalid zip file: {model_path}")
+        ) from e
+    except zipfile.BadZipFile as e:
+        logger.exception("Invalid zip file: {}", model_path)
+        raise click.ClickException(f"Invalid zip file: {model_path}") from e
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2577878 and 63365ac.

📒 Files selected for processing (2)
  • mo_net/samples/word2vec/__main__.py (10 hunks)
  • mo_net/samples/word2vec/vocab.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
mo_net/samples/word2vec/__main__.py (4)
mo_net/model/model.py (4)
  • output (435-436)
  • output_module (290-291)
  • hidden_modules (276-277)
  • input_layer (286-287)
mo_net/model/layer/output.py (7)
  • OutputLayer (22-52)
  • SparseCategoricalSoftmaxOutputLayer (125-201)
  • backward_prop (41-42)
  • backward_prop_with_negative (151-157)
  • serialize (52-52)
  • serialize (94-95)
  • serialize (200-201)
mo_net/regulariser/weight_decay.py (4)
  • EmbeddingWeightDecayRegulariser (71-112)
  • embedding_layer (68-68)
  • attach (47-63)
  • attach (101-112)
mo_net/samples/word2vec/vocab.py (6)
  • Vocab (21-129)
  • get_english_sentences (296-305)
  • get_training_set (132-151)
  • from_bytes (119-129)
  • english_sentences (110-116)
  • serialize (27-34)
mo_net/samples/word2vec/vocab.py (1)
mo_net/resources.py (1)
  • get_resource (16-47)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py

134-134: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


580-582: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


580-582: Avoid specifying long messages outside the exception class

(TRY003)


585-585: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


585-585: Avoid specifying long messages outside the exception class

(TRY003)

🪛 GitHub Actions: CI
mo_net/samples/word2vec/__main__.py

[error] 417-417: Mypy: Incompatible types in assignment. Expression has type 'Collection[Sequence[int]]', variable has type 'list[list[int]]' (assignment). Step: 'uv run task typecheck -v'.


[error] 576-576: Mypy: Incompatible types in assignment. Expression has type 'CBOWModel', variable has type 'SkipGramModel' (assignment). Step: 'uv run task typecheck -v'.

🔇 Additional comments (3)
mo_net/samples/word2vec/__main__.py (3)

477-486: OK: run naming and regulariser wiring.

Run names reflect model type; weight‑decay regulariser is correctly attached to the embedding layer.

Consider persisting negative_samples in metadata for reproducibility.


501-531: Good: self-contained zip artefact.

Writing model.pkl, vocab.msgpack and metadata.json into a single zip is clean and reproducible; defaulting output path to DATA_DIR is sensible.


213-226: Incorrect — no shape mismatch; labels and outputs are flattened consistently.

forward_prop returns predictions flattened across context positions (compute_loss calls Y_pred = self.forward_prop(X) and compares to Y_true.flatten()); backward_prop passes Y_true.flatten() and flattened negatives, and SparseCategoricalSoftmaxOutputLayer._backward_prop indexes rows by Y_true.shape[0] against the cached output_activations, so the shapes align.

Likely an incorrect or invalid review comment.

@modiase modiase force-pushed the feature/skip-gram branch 2 times, most recently from 5f6bca3 to 4cdfe81 on September 16, 2025 at 17:14
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (4)
mo_net/samples/word2vec/vocab.py (2)

16-17: Avoid Python 3.12-only type aliases; use TypeAlias.

Apply:

-type Sentence = Sequence[str]
-type TokenizedSentence = Sequence[int]
+Sentence: TypeAlias = Sequence[str]
+TokenizedSentence: TypeAlias = Sequence[int]

296-305: Normalize tokens before stopword filtering.

Apply:

-def get_english_sentences(limit: int = 100000) -> Collection[Sentence]:
+def get_english_sentences(limit: int = 100000) -> Collection[Sentence]:
     stop_words = get_stop_words()
     return [
-        [word for word in sentence.split() if word not in stop_words]
+        [
+            tok
+            for tok in (Vocab.clean_token(w) for w in sentence.split())
+            if tok and tok not in stop_words
+        ]
         for sentence in (
             get_resource("s3://mo-net-resources/english-sentences.txt")
             .read_text()
             .split("\n")[:limit]
         )
     ]
mo_net/samples/word2vec/__main__.py (2)

275-281: Skip‑gram loss: shape mismatch between logits and flattened labels.

Y_true.flatten() has length batch_size*num_pos but Y_pred is (batch_size, vocab). Repeat logits per positive target before loss.

Apply:

     def compute_loss(
         self, X: jnp.ndarray, Y_true: jnp.ndarray, loss_fn: LossFn
     ) -> float:
-        Y_pred = self.forward_prop(X)
-        return loss_fn(Y_pred, Y_true.flatten()) + sum(
+        Y_pred = self.forward_prop(X)
+        if Y_true.ndim == 2:
+            num_pos = Y_true.shape[1]
+            Y_pred = jnp.repeat(Y_pred, repeats=num_pos, axis=0)
+            y = Y_true.reshape(-1)
+        else:
+            y = Y_true
+        return loss_fn(Y_pred, y) + sum(
             contributor() for contributor in self.loss_contributors
         )

283-303: Skip‑gram backprop: row indexing overruns due to flattened labels.

Output-layer gradient uses jnp.arange(Y_true.shape[0]). Flattening makes rows exceed logits rows. Aggregate per positive after repeating.

Apply:

-    def backward_prop(self, Y_true: jnp.ndarray) -> D[Activations]:
-        batch_size, context_size = Y_true.shape
-        self._key, subkey = jax.random.split(self._key)
-
-        dZ = cast(
-            SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
-        ).backward_prop_with_negative(
-            Y_true=Y_true.flatten(),
-            Y_negative=jax.random.choice(
-                subkey,
-                self.embedding_layer.vocab_size,
-                shape=(batch_size * context_size * self._negative_samples,),
-            ),
-        )
+    def backward_prop(self, Y_true: jnp.ndarray) -> D[Activations]:
+        if Y_true.ndim == 2:
+            batch_size, num_pos = Y_true.shape
+            self._key, subkey = jax.random.split(self._key)
+            dZ_rep = cast(
+                SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
+            ).backward_prop_with_negative(
+                Y_true=Y_true.reshape(-1),
+                Y_negative=jax.random.choice(
+                    subkey,
+                    self.embedding_layer.vocab_size,
+                    shape=(batch_size * num_pos * self._negative_samples,),
+                ),
+            )
+            # Sum per example across its positive contexts → shape (batch_size, vocab)
+            dZ = dZ_rep.reshape(batch_size, num_pos, -1).sum(axis=1)
+        else:
+            batch_size = Y_true.shape[0]
+            self._key, subkey = jax.random.split(self._key)
+            dZ = cast(
+                SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
+            ).backward_prop_with_negative(
+                Y_true=Y_true,
+                Y_negative=jax.random.choice(
+                    subkey,
+                    self.embedding_layer.vocab_size,
+                    shape=(batch_size * self._negative_samples,),
+                ),
+            )
🧹 Nitpick comments (7)
bin/mnist (1)

1-2: Use exec, quote paths, and prefer module invocation (-m).

  • Replace path execution with python -m mo_net.samples.mnist to avoid relying on git and path spaces. Also use exec so the wrapper forwards signals/exit code.

Apply:

-uv run python $(git rev-parse --show-toplevel)/mo_net/samples/mnist "$@"
+exec uv run python -m mo_net.samples.mnist "$@"
mo_net/samples/word2vec/vocab.py (3)

38-40: Msgpack decoding: ensure string keys/values are decoded.

Be explicit with raw=False to avoid bytes when unpacking across environments.

Apply:

-            data = msgpack.unpackb(f.read())
+            data = msgpack.unpackb(f.read(), raw=False)

119-121: Msgpack decoding: be explicit here as well.

Apply:

-        obj = msgpack.unpackb(data)
+        obj = msgpack.unpackb(data, raw=False)

132-151: get_training_set: empty-corpus safety (zip strict=True raises).

For short corpora or large context_size, the list may be empty and zip(..., strict=True) raises ValueError. Guard for empty.

Apply:

-def get_training_set(
+def get_training_set(
     tokenized_sentences: Collection[TokenizedSentence], context_size: int
 ) -> tuple[jnp.ndarray, jnp.ndarray]:
-    context, target = zip(
-        *[
-            (
-                tuple(
-                    chain(
-                        sentence[i - context_size : i],
-                        sentence[i + 1 : i + context_size + 1],
-                    )
-                ),
-                sentence[i],
-            )
-            for sentence in tokenized_sentences
-            for i in range(context_size, len(sentence) - context_size)
-        ],
-        strict=True,
-    )
-    return jnp.array(context), jnp.array(list(target))
+    pairs = [
+        (
+            tuple(
+                chain(
+                    sentence[i - context_size : i],
+                    sentence[i + 1 : i + context_size + 1],
+                )
+            ),
+            sentence[i],
+        )
+        for sentence in tokenized_sentences
+        for i in range(context_size, len(sentence) - context_size)
+    ]
+    if not pairs:
+        return jnp.empty((0, context_size * 2), dtype=jnp.int32), jnp.empty((0,), dtype=jnp.int32)
+    context, target = zip(*pairs)
+    return jnp.array(context), jnp.array(target)
mo_net/samples/word2vec/__main__.py (3)

204-204: Remove unused noqa.

# noqa: F821 is unnecessary here; imports satisfy the name.

Apply:

-        hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+        hidden: Sequence[Hidden | HiddenLayer],

713-719: Exception chaining and logging.

Use raise ... from e and avoid passing exception objects to logger.exception (it logs the active exception automatically).

Apply:

-    except KeyError as e:
-        raise click.ClickException(
-            f"Missing file in zip: {e.args[0]}. Expected {MODEL_ZIP_INTERNAL_PATH}, {VOCAB_ZIP_INTERNAL_PATH} and optionally {METADATA_ZIP_INTERNAL_PATH}"
-        )
-    except zipfile.BadZipFile as e:
-        logger.exception(f"Invalid zip file: {model_path}", e)
-        raise click.ClickException(f"Invalid zip file: {model_path}")
+    except KeyError as e:
+        raise click.ClickException(
+            f"Missing file in zip: {e.args[0]}. Expected {MODEL_ZIP_INTERNAL_PATH}, {VOCAB_ZIP_INTERNAL_PATH} and optionally {METADATA_ZIP_INTERNAL_PATH}"
+        ) from e
+    except zipfile.BadZipFile as e:
+        logger.exception(f"Invalid zip file: {model_path}")
+        raise click.ClickException(f"Invalid zip file: {model_path}") from e

739-751: Cast JAX scalars to float for sorting/formatting.

jnp.array scalars don’t always support Python float formatting and comparisons for sorting; convert to float.

Apply:

-        similarities = []
+        similarities: list[tuple[str, float]] = []
         for other_word in vocab.vocab:
             if other_word != word:
                 other_id = vocab[other_word]
                 other_embedding = model.embeddings[other_id]
-                similarity = jnp.dot(word_embedding, other_embedding) / (
+                similarity = jnp.dot(word_embedding, other_embedding) / (
                     jnp.linalg.norm(word_embedding) * jnp.linalg.norm(other_embedding)
                 )
-                similarities.append((other_word, similarity))
+                similarities.append((other_word, float(similarity)))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 63365ac and 86b4689.

📒 Files selected for processing (4)
  • bin/mnist (1 hunks)
  • bin/word2vec (1 hunks)
  • mo_net/samples/word2vec/__main__.py (1 hunks)
  • mo_net/samples/word2vec/vocab.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • bin/word2vec
🧰 Additional context used
🧬 Code graph analysis (2)
mo_net/samples/word2vec/vocab.py (1)
mo_net/resources.py (1)
  • get_resource (16-47)
mo_net/samples/word2vec/__main__.py (10)
mo_net/__init__.py (1)
  • print_device_info (139-145)
mo_net/functions.py (1)
  • sparse_cross_entropy (20-23)
mo_net/log.py (2)
  • LogLevel (12-18)
  • setup_logging (21-23)
mo_net/model/layer/embedding.py (2)
  • Embedding (117-307)
  • vocab_size (306-307)
mo_net/model/model.py (6)
  • output (435-436)
  • Model (53-436)
  • hidden_modules (276-277)
  • output_module (290-291)
  • input_layer (286-287)
  • loss_contributors (269-270)
mo_net/model/layer/output.py (15)
  • OutputLayer (22-52)
  • SparseCategoricalSoftmaxOutputLayer (125-201)
  • Serialized (57-69)
  • Serialized (104-114)
  • Serialized (127-139)
  • output_dimensions (91-92)
  • output_dimensions (197-198)
  • serialize (52-52)
  • serialize (94-95)
  • serialize (200-201)
  • deserialize (60-69)
  • deserialize (107-114)
  • deserialize (130-139)
  • backward_prop (41-42)
  • backward_prop_with_negative (151-157)
mo_net/samples/word2vec/vocab.py (4)
  • get_training_set (132-151)
  • serialize (27-34)
  • deserialize (37-48)
  • from_bytes (119-129)
mo_net/train/trainer/trainer.py (4)
  • train (209-266)
  • BasicTrainer (98-401)
  • TrainingFailed (38-42)
  • TrainingSuccessful (33-34)
mo_net/config.py (1)
  • TrainingParameters (8-51)
mo_net/train/run.py (2)
  • TrainingRun (6-70)
  • seed (30-31)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py

174-174: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


176-176: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


178-178: Prefer TypeError exception for invalid type

(TRY004)


178-178: Avoid specifying long messages outside the exception class

(TRY003)


204-204: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


344-344: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


346-346: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


348-348: Prefer TypeError exception for invalid type

(TRY004)


348-348: Avoid specifying long messages outside the exception class

(TRY003)


690-690: Avoid specifying long messages outside the exception class

(TRY003)


713-715: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


713-715: Avoid specifying long messages outside the exception class

(TRY003)


718-718: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


718-718: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Cursor Bugbot
  • GitHub Check: typecheck

cursor[bot]

This comment was marked as outdated.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (7)
mo_net/samples/word2vec/vocab.py (3)

306-315: Normalize and clean tokens before stopword filtering.

Otherwise capitalized tokens and punctuation slip through.

-    return [
-        [word for word in sentence.split() if word not in stop_words]
+    return [
+        [
+            tok
+            for tok in (Vocab.clean_token(w) for w in sentence.split())
+            if tok and tok not in stop_words
+        ]
         for sentence in (
             get_resource("s3://mo-net-resources/english-sentences.txt")
             .read_text()
             .split("\n")[:limit]
         )
     ]

16-17: Avoid Python 3.12-only type alias syntax; use TypeAlias.

Keeps <=3.11 compatibility.

-type Sentence = Sequence[str]
-type TokenizedSentence = Sequence[int]
+Sentence: TypeAlias = Sequence[str]
+TokenizedSentence: TypeAlias = Sequence[int]

9-9: typing.Counter used at runtime; import Counter from collections.

typing.Counter does instantiate at runtime, but it is a deprecated alias (PEP 585); importing from collections is the idiomatic, future-proof choice and pairs with the TypeAlias import below.

Apply:

-from typing import Counter, Self
+from collections import Counter
+from typing import Self, TypeAlias
mo_net/train/trainer/trainer.py (1)

7-8: Fix Python 3.12-only type aliases and include missing “rmsprop”.

Use TypeAlias and add "rmsprop" to the union to match the match cases.

-from typing import Any, ContextManager, Final, Literal, assert_never
+from typing import Any, ContextManager, Final, Literal, TypeAlias, assert_never
@@
-type TrainingResult = TrainingSuccessful | TrainingFailed
+TrainingResult: TypeAlias = TrainingSuccessful | TrainingFailed
@@
-type OptimizerType = Literal["adam", "none"]
+OptimizerType: TypeAlias = Literal["adam", "none", "rmsprop"]

Also applies to: 47-52

mo_net/samples/word2vec/__main__.py (3)

545-556: Keep model_type consistent with loaded ZIP metadata.

After loading, align CLI model_type to loaded_model_type so data prep and logging match the artifact.

                 if loaded_model_type == "skipgram":
                     model: CBOWModel | SkipGramModel = SkipGramModel.load(
                         mf, training=True, key=key
                     )
                 else:
                     model = CBOWModel.load(mf, training=True)
                 vocab = Vocab.from_bytes(zf.read(VOCAB_ZIP_INTERNAL_PATH))
         sentences = get_english_sentences()
+        model_type = loaded_model_type

275-281: Skip‑gram loss: logits/label shape mismatch.

Y_true.flatten() has length batch_size * num_pos but Y_pred has batch_size rows. Repeat logits per positive or collapse per column before loss.

     def compute_loss(
         self, X: jnp.ndarray, Y_true: jnp.ndarray, loss_fn: LossFn
     ) -> float:
-        Y_pred = self.forward_prop(X)
-        return loss_fn(Y_pred, Y_true.flatten()) + sum(
+        Y_pred = self.forward_prop(X)
+        if Y_true.ndim == 2:
+            # repeat each row once per positive context
+            Y_pred = jnp.repeat(Y_pred, repeats=Y_true.shape[1], axis=0)
+            y = Y_true.reshape(-1)
+        else:
+            y = Y_true
+        return loss_fn(Y_pred, y) + sum(
             contributor() for contributor in self.loss_contributors
         )

283-303: Skip‑gram backward: positive indices overrun row count.

Flattening positives yields batch_size * num_pos rows while the output layer has only batch_size rows. Aggregate per positive column.

-    def backward_prop(self, Y_true: jnp.ndarray) -> D[Activations]:
-        batch_size, context_size = Y_true.shape
-        self._key, subkey = jax.random.split(self._key)
-
-        dZ = cast(
-            SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
-        ).backward_prop_with_negative(
-            Y_true=Y_true.flatten(),
-            Y_negative=jax.random.choice(
-                subkey,
-                self.embedding_layer.vocab_size,
-                shape=(batch_size * context_size * self._negative_samples,),
-            ),
-        )
+    def backward_prop(self, Y_true: jnp.ndarray) -> D[Activations]:
+        batch_size, num_pos = Y_true.shape
+        # Split RNG: one for state, one per positive column
+        splits = jax.random.split(self._key, num_pos + 1)
+        self._key = splits[0]
+        dZ = 0
+        for i in range(num_pos):
+            negatives = jax.random.choice(
+                splits[i + 1],
+                self.embedding_layer.vocab_size,
+                shape=(batch_size * self._negative_samples,),
+            )
+            dZ_i = cast(
+                SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
+            ).backward_prop_with_negative(
+                Y_true=Y_true[:, i],
+                Y_negative=negatives,
+            )
+            dZ = dZ + dZ_i
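The per-column idea can be sketched outside the model as well. This is a hypothetical illustration (NumPy's RNG stands in for `jax.random`; `per_column_negatives` is not a codebase function): each positive column gets its own negative batch sized to the `batch_size` rows the output layer actually has.

```python
import numpy as np

rng = np.random.default_rng(0)

def per_column_negatives(y_true: np.ndarray, vocab_size: int, k: int):
    """Draw an independent negative-sample batch per positive column.

    y_true: (batch, num_pos) positive context ids. Returns one
    (positives, negatives) pair per column; negatives has batch * k
    entries, matching the output layer's batch_size rows rather than
    the batch_size * num_pos rows a flatten would imply.
    """
    batch, num_pos = y_true.shape
    return [
        (y_true[:, i], rng.integers(0, vocab_size, size=batch * k))
        for i in range(num_pos)
    ]

y = np.array([[3, 7], [1, 4]])  # batch=2, num_pos=2
pairs = per_column_negatives(y, vocab_size=10, k=5)
```

The backward pass then accumulates the gradient contribution from each column, exactly as the suggested loop over `num_pos` does.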
🧹 Nitpick comments (4)
mo_net/samples/word2vec/vocab.py (1)

70-81: Make vocab ordering deterministic (don’t go through a set).

Using a set loses frequency order and makes IDs unstable between runs. Keep most_common order and prepend forced_words stably.

-        most_common_tokens = {
-            token
-            for token, _ in Counter(
-                token for sentence in sentences if sentence for token in sentence
-            ).most_common(max_size)
-        }
-        for word in forced_words:
-            most_common_tokens.add(word)
-
-        vocab_tuple = tuple(most_common_tokens)
+        counts = Counter(
+            token for sentence in sentences if sentence for token in sentence
+        ).most_common(max_size)
+        most_common = [token for token, _ in counts]
+        # forced_words first; dict.fromkeys preserves order and de-dupes
+        vocab_list = list(dict.fromkeys([*forced_words, *most_common]))
+        vocab_tuple = tuple(vocab_list)
         unknown_token_id = len(vocab_tuple)

Also applies to: 79-81
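A self-contained sketch of the deterministic ordering (`build_vocab` is a hypothetical helper, not the module's API): `dict.fromkeys` de-duplicates while preserving insertion order, so token ids stay stable across runs.

```python
from collections import Counter

def build_vocab(sentences, max_size, forced_words=()):
    """Deterministic vocab: forced words first, then tokens by frequency.

    Building through a set would discard frequency order and make ids
    unstable; dict.fromkeys keeps order while dropping duplicates.
    """
    counts = Counter(tok for s in sentences if s for tok in s)
    most_common = [tok for tok, _ in counts.most_common(max_size)]
    return tuple(dict.fromkeys([*forced_words, *most_common]))

vocab = build_vocab(
    [["a", "b", "a"], ["b", "a", "c"]], max_size=2, forced_words=("<unk>",)
)
# counts: a=3, b=2, c=1 -> vocab = ("<unk>", "a", "b")
```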

mo_net/train/trainer/trainer.py (2)

160-166: Silence unused-arg warnings in SIGINT handler.

Prevents ARG002 from Ruff.

 def _sigint_handler(self, signum: int, frame: Any) -> None:
     """Handle SIGINT (Ctrl+C) by setting interrupt flag."""
+    del signum, frame
     self._interrupt_requested = True
     self._logger.info(
         "SIGINT received. Training will be interrupted at the next safe point."
     )

178-206: Don’t block on the interactive prompt in non‑TTY/CI environments; auto‑continue instead.

InquirerPy will fail or hang without a TTY. Gate on stdin.isatty().

     def _handle_interrupt(self) -> TrainingResult | None:
         """Handle interrupt request by prompting user for action."""
         if not self._interrupt_requested:
             return None
 
         self._interrupt_requested = False  # Reset flag
 
         self._logger.info("Training interrupted by user. Prompting for action...")
 
-        try:
+        try:
+            import sys
+            if not sys.stdin or not sys.stdin.isatty():
+                self._logger.warning("No interactive TTY detected; continuing training.")
+                return None
             choice = inquirer.select(
                 message="Training has been interrupted. What would you like to do?",

If preferred, add a constructor flag (e.g., interactive_prompts: bool = True) and disable prompts when False.

mo_net/samples/word2vec/__main__.py (1)

204-209: Remove unused # noqa: F821.

Hidden | HiddenLayer are imported; the directive is unnecessary.

-        hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+        hidden: Sequence[Hidden | HiddenLayer],
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 86b4689 and b82d25e.

📒 Files selected for processing (3)
  • mo_net/samples/word2vec/__main__.py (1 hunks)
  • mo_net/samples/word2vec/vocab.py (1 hunks)
  • mo_net/train/trainer/trainer.py (15 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
mo_net/train/trainer/trainer.py (5)
mo_net/optimiser/base.py (12)
  • config (96-97)
  • Base (18-104)
  • set_model (84-85)
  • restore (90-93)
  • snapshot (87-88)
  • learning_rate (82-82)
  • learning_rate (123-124)
  • report (78-78)
  • report (119-120)
  • training_step (27-32)
  • training_step (35-42)
  • training_step (44-72)
mo_net/optimiser/adam.py (5)
  • AdaM (29-115)
  • restore (104-115)
  • snapshot (95-102)
  • learning_rate (89-90)
  • report (92-93)
mo_net/optimiser/rmsprop.py (5)
  • RMSProp (27-96)
  • restore (87-96)
  • snapshot (81-85)
  • learning_rate (75-76)
  • report (78-79)
mo_net/optimiser/scheduler.py (2)
  • CosineScheduler (72-120)
  • WarmupScheduler (123-179)
mo_net/train/run.py (1)
  • log_iteration (33-49)
mo_net/samples/word2vec/vocab.py (1)
mo_net/resources.py (1)
  • get_resource (16-47)
mo_net/samples/word2vec/__main__.py (12)
mo_net/__init__.py (1)
  • print_device_info (139-145)
mo_net/functions.py (1)
  • sparse_cross_entropy (20-23)
mo_net/log.py (2)
  • LogLevel (12-18)
  • setup_logging (21-23)
mo_net/model/layer/average.py (2)
  • Average (12-92)
  • axis (84-85)
mo_net/model/layer/embedding.py (2)
  • Embedding (117-307)
  • vocab_size (306-307)
mo_net/model/layer/linear.py (1)
  • Linear (182-395)
mo_net/model/model.py (6)
  • output (435-436)
  • Model (53-436)
  • hidden_modules (276-277)
  • output_module (290-291)
  • input_layer (286-287)
  • loss_contributors (269-270)
mo_net/model/layer/output.py (15)
  • OutputLayer (22-52)
  • SparseCategoricalSoftmaxOutputLayer (125-201)
  • Serialized (57-69)
  • Serialized (104-114)
  • Serialized (127-139)
  • output_dimensions (91-92)
  • output_dimensions (197-198)
  • serialize (52-52)
  • serialize (94-95)
  • serialize (200-201)
  • deserialize (60-69)
  • deserialize (107-114)
  • deserialize (130-139)
  • backward_prop (41-42)
  • backward_prop_with_negative (151-157)
mo_net/regulariser/weight_decay.py (4)
  • EmbeddingWeightDecayRegulariser (71-112)
  • embedding_layer (68-68)
  • attach (47-63)
  • attach (101-112)
mo_net/samples/word2vec/vocab.py (6)
  • get_english_sentences (306-315)
  • get_training_set (142-161)
  • serialize (27-34)
  • deserialize (37-48)
  • from_bytes (129-139)
  • english_sentences (115-126)
mo_net/train/trainer/trainer.py (4)
  • train (291-348)
  • BasicTrainer (100-488)
  • TrainingFailed (40-44)
  • TrainingSuccessful (35-36)
mo_net/config.py (1)
  • TrainingParameters (8-51)
🪛 Ruff (0.12.2)
mo_net/train/trainer/trainer.py

160-160: Unused method argument: signum

(ARG002)


160-160: Unused method argument: frame

(ARG002)

mo_net/samples/word2vec/__main__.py

174-174: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


176-176: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


178-178: Prefer TypeError exception for invalid type

(TRY004)


178-178: Avoid specifying long messages outside the exception class

(TRY003)


204-204: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


344-344: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


346-346: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


348-348: Prefer TypeError exception for invalid type

(TRY004)


348-348: Avoid specifying long messages outside the exception class

(TRY003)


701-701: Avoid specifying long messages outside the exception class

(TRY003)


724-726: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


724-726: Avoid specifying long messages outside the exception class

(TRY003)


729-729: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


729-729: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Cursor Bugbot
  • GitHub Check: typecheck

cursor[bot]

This comment was marked as outdated.

@modiase modiase force-pushed the feature/skip-gram branch 8 times, most recently from 54183e9 to f8d0af8 on September 16, 2025 18:39
@modiase modiase merged commit 183d248 into main Sep 16, 2025
4 of 5 checks passed
@modiase modiase deleted the feature/skip-gram branch September 16, 2025 18:44
@coderabbitai coderabbitai bot mentioned this pull request Feb 7, 2026