Conversation
**Warning: Rate limit exceeded.** @modiase has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 11 minutes and 16 seconds before requesting another review.

⌛ **How to resolve this issue?** After the wait time has elapsed, a review can be triggered. We recommend spacing out commits to avoid hitting the rate limit.

🚦 **How do rate limits work?** CodeRabbit enforces hourly rate limits for each developer per organization. Paid plans have higher rate limits than the trial, open-source and free plans. In all cases, further reviews are re-allowed after a brief timeout. Please see the FAQ for further information.

📒 Files selected for processing (7)
**Note:** Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

**Walkthrough:** Renamed "optimizer" to the British "optimiser" across the CLI, trainer, optimiser modules and samples; added Word2Vec samples and vocab utilities; introduced `EmbeddingWeightDecayRegulariser`; changed `Model.backward_prop` to return `D[Activations]`; extended the output layer's backward pass to accept negative labels; minor layer error-message and tooling updates.
**Sequence Diagram(s)**

High-level Word2Vec training (CBOW / SkipGram):

```mermaid
sequenceDiagram
    participant CLI
    participant Vocab
    participant Model as "Model (CBOW / SkipGram)"
    participant Trainer
    participant Optimiser as "Optimiser"
    CLI->>Vocab: build or load vocab & tokenized data
    CLI->>Model: create or load model (cbow/skipgram)
    CLI->>Optimiser: get_optimiser(...)
    CLI->>Trainer: start trainer with model + optimiser
    loop each epoch
        Trainer->>Model: forward_pass(X)
        Model-->>Trainer: outputs / loss
        Trainer->>Model: backward_prop(Y_true[, Y_negative])
        Model-->>Trainer: gradients (D[Activations])
        Trainer->>Optimiser: compute_update(gradients)
        Optimiser->>Model: apply_update
    end
```
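The epoch loop shown in the diagram can be sketched end to end with toy stand-ins (a minimal illustration only; `ToyModel` and `ToyOptimiser` are hypothetical, not the real mo_net classes):

```python
# Toy stand-ins for the components in the diagram above; not the mo_net API.

class ToyModel:
    def __init__(self, w: float) -> None:
        self.w = w

    def forward_pass(self, x: float) -> float:
        return self.w * x

    def backward_prop(self, x: float, y_true: float) -> float:
        # Gradient of 0.5 * (w*x - y)^2 with respect to w
        return (self.w * x - y_true) * x


class ToyOptimiser:
    def __init__(self, model: ToyModel, lr: float) -> None:
        self.model = model
        self.lr = lr

    def compute_update(self, gradient: float) -> float:
        return -self.lr * gradient

    def apply_update(self, update: float) -> None:
        self.model.w += update


model = ToyModel(w=0.0)
optimiser = ToyOptimiser(model, lr=0.1)
for _ in range(100):  # "loop each epoch"
    grad = model.backward_prop(x=1.0, y_true=2.0)
    optimiser.apply_update(optimiser.compute_update(grad))
# model.w converges towards 2.0
```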
Embedding weight-decay handler attachment:

```mermaid
sequenceDiagram
    participant Model
    participant EmbeddingReg as "EmbeddingWeightDecayRegulariser"
    participant Optimiser as "Optimiser"
    Model->>EmbeddingReg: instantiate (embedding layer)
    EmbeddingReg->>Optimiser: attach(after_compute_update handler)
    loop training steps
        Optimiser->>EmbeddingReg: after_compute_update(learning_rate)
        EmbeddingReg->>Model: modify embedding gradients / compute reg loss
    end
```
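The attachment pattern amounts to the regulariser registering a callback that the optimiser invokes after every update; a minimal sketch, where the handler list and class names are hypothetical stand-ins for whatever hook mechanism mo_net actually uses:

```python
# Illustrative sketch of an after-update weight-decay hook; names are hypothetical.

class Optimiser:
    def __init__(self, lr: float) -> None:
        self.learning_rate = lr
        self.after_compute_update_handlers = []

    def step(self, weights: list[float], grads: list[float]) -> None:
        for i, g in enumerate(grads):
            weights[i] -= self.learning_rate * g
        for handler in self.after_compute_update_handlers:
            handler(weights, self.learning_rate)


class WeightDecayRegulariser:
    def __init__(self, lambda_: float) -> None:
        self.lambda_ = lambda_

    def attach(self, optimiser: Optimiser) -> None:
        optimiser.after_compute_update_handlers.append(self.after_compute_update)

    def after_compute_update(self, weights: list[float], lr: float) -> None:
        # L2 decay: shrink each weight by lr * lambda * w
        for i, w in enumerate(weights):
            weights[i] = w - lr * self.lambda_ * w


opt = Optimiser(lr=0.1)
WeightDecayRegulariser(lambda_=0.01).attach(opt)
weights = [1.0]
opt.step(weights, grads=[0.0])
# With a zero gradient the weight only decays: 1.0 * (1 - 0.1 * 0.01) = 0.999
```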
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
Pre-merge checks and finishing touches:
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Force-pushed 009fe24 to 7d329e6.
Seems like a silly idea. It's always just going to produce nonsense
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
mo_net/samples/word2vec/__main__.py (1)
**543-543: Verify model type compatibility in `sample` command.** The `sample` command explicitly loads a `CBOWModel`, which won't work for Skip-gram models. This should either support both model types or clearly indicate it only works with CBOW.

```diff
- model = CBOWModel.load(model_path, training=False)
+ # Try to determine model type from the saved model
+ model = Model.load(model_path, training=False)
+ if not isinstance(model, (CBOWModel, SkipGramModel)):
+     raise click.ClickException(f"Model must be either CBOW or Skip-gram, got {type(model).__name__}")
```

Would you like me to implement a more robust model type detection mechanism that could store the model type in metadata alongside the vocabulary?
🧹 Nitpick comments (5)
.gitignore (1)

**40-40: Scope the DB ignore to avoid surprises.** As written, this ignores any `train.db` at any depth. Prefer anchoring to the repo root, or ignore all `.db` files if multiple DBs may appear.

Root-anchored (conservative):

```diff
-train.db
+/train.db
```

Alternative (broader):

```diff
-train.db
+*.db
```

.envrc (1)

**1-2: Optionally load a local .env for parity with app configs.** This helps devs who keep secrets/config in a .env (already gitignored).

```diff
 PATH_add bin
 source_env_if_exists .envrc.local
+dotenv_if_exists .env
```

mo_net/samples/word2vec/__main__.py (3)
**125-125: Remove unused `noqa` directive.** The `# noqa: F821` comment is unnecessary since `HiddenLayer` is properly imported on line 20.

```diff
- hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+ hidden: Sequence[Hidden | HiddenLayer],
```
**465-468: Consider more descriptive variable naming for reshaping operation.** The reshaping logic for Skip-gram is correct, but the variable names could be clearer about what's being reshaped and why.

```diff
 if model_type == "skipgram":
+    # Reshape to (batch_size, 1) for single word input in Skip-gram
     X_train_split = X_train_split.reshape(-1, 1)
     X_val = X_val.reshape(-1, 1)
```
**507-508: Clean up unnecessary variable renaming.** The variable `never_2` appears to be an artifact of automatic refactoring. It should be simplified to just `never` for consistency.

```diff
-            case never_2:
-                assert_never(never_2)
+            case never:
+                assert_never(never)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- `.envrc` (1 hunks)
- `.gitignore` (1 hunks)
- `bin/typecheck` (1 hunks)
- `mo_net/samples/word2vec/__main__.py` (11 hunks)
✅ Files skipped from review due to trivial changes (1)
- bin/typecheck
🧰 Additional context used
🧬 Code graph analysis (1)
mo_net/samples/word2vec/__main__.py (9)
- mo_net/__init__.py (1): `print_device_info` (139-145)
- mo_net/functions.py (1): `sparse_cross_entropy` (20-23)
- mo_net/model/layer/embedding.py (2): `Embedding` (117-307), `vocab_size` (306-307)
- mo_net/model/model.py (6): `output` (435-436), `Model` (53-436), `loss_contributors` (269-270), `output_module` (290-291), `hidden_modules` (276-277), `input_layer` (286-287)
- mo_net/model/layer/output.py (4): `OutputLayer` (22-52), `SparseCategoricalSoftmaxOutputLayer` (125-181), `backward_prop` (41-42), `backward_prop_with_negative` (151-157)
- mo_net/regulariser/weight_decay.py (4): `EmbeddingWeightDecayRegulariser` (71-112), `embedding_layer` (68-68), `attach` (47-63), `attach` (101-112)
- mo_net/samples/word2vec/vocab.py (3): `Vocab` (26-121), `get_training_set` (124-143), `english_sentences` (115-121)
- mo_net/train/trainer/trainer.py (3): `train` (209-266), `BasicTrainer` (98-401), `get_optimiser` (51-89)
- mo_net/config.py (1): `TrainingParameters` (8-51)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py
125-125: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
383-383: Unused function argument: device
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Cursor Bugbot
- GitHub Check: typecheck
🔇 Additional comments (6)
.envrc (2)
**1-2: LGTM: minimal, safe direnv setup.** `PATH_add bin` and `source_env_if_exists` are standard and keep local overrides out of VCS.

**1-2: Confirmed: `bin/typecheck` is present and executable.** Found `bin/typecheck` with permissions `-rwxr-xr-x`; no change required.

mo_net/samples/word2vec/__main__.py (4)
**121-195: LGTM! Well-structured Skip-gram model implementation.** The `SkipGramModel` class is properly implemented with appropriate initialization, factory method, and configuration. The negative sampling approach is correctly integrated.

**204-225: Solid negative sampling implementation in backward propagation.** The backward propagation correctly implements negative sampling using the specialized `backward_prop_with_negative` method. The random key management and gradient flow through layers are properly handled.

**471-476: Good use of updated optimiser naming convention.** The refactoring from "optimizer" to "optimiser" is consistently applied throughout the code, maintaining British English spelling consistency.

**401-406: Skip-gram swap verified; no change required.** `get_training_set(...)` returns `(context, target)` as jnp arrays (context has shape `(N, 2*context_size)`, target has shape `(N,)`). Initially `X_train=context` and `Y_train=target`; the swap makes `X_train=target` and `Y_train=context`, which matches Skip-gram (input = target, predict context).
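A minimal sketch of how `(context, target)` pairs with those shapes arise from a tokenised sentence (illustrative only, not the actual `get_training_set` implementation):

```python
# Build (context, target) pairs from one tokenised sentence.
# For CBOW: X = context, Y = target; for Skip-gram the roles are swapped.

def training_pairs(sentence: list[int], context_size: int):
    pairs = []
    for i in range(context_size, len(sentence) - context_size):
        context = sentence[i - context_size : i] + sentence[i + 1 : i + context_size + 1]
        pairs.append((tuple(context), sentence[i]))
    return pairs


pairs = training_pairs([10, 11, 12, 13, 14], context_size=1)
# Each context has 2 * context_size entries:
# pairs == [((10, 12), 11), ((11, 13), 12), ((12, 14), 13)]
```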
Force-pushed 3218d0e to 5943931.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
mo_net/model/layer/output.py (1)
**8-8: Import bug: `abstractmethod` imported from pyparsing.** `abstractmethod` must come from `abc`, not `pyparsing`. This will raise at import time.

```diff
-from pyparsing import abstractmethod
+from abc import abstractmethod
```
♻️ Duplicate comments (3)
mo_net/model/layer/output.py (1)
**171-176: Fix for prior negative-sample indexing bug looks correct.** Using `repeat(arange(batch), num_neg)` for rows and a flattened `Y_negative` for columns addresses the earlier out-of-bounds issue when `negative_samples > 1`. Good job.

Also applies to: 188-191
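The row/column construction can be sanity-checked in isolation; here is the same indexing written in NumPy, with `np.add.at` standing in for JAX's functional `.at[...].add`:

```python
import numpy as np

batch_size, num_neg, vocab = 2, 3, 5
Y_negative = np.array([[0, 1, 2], [3, 4, 0]])  # shape (batch_size, num_neg)

row_idx = np.repeat(np.arange(batch_size), num_neg)  # [0 0 0 1 1 1]
col_idx = Y_negative.reshape(-1)                     # [0 1 2 3 4 0]

result = np.zeros((batch_size, vocab))
np.add.at(result, (row_idx, col_idx), 1.0)

# Every negative lands in its own row; all batch_size * num_neg index
# pairs stay within bounds because rows are repeated, not extended.
assert result[0].tolist() == [1.0, 1.0, 1.0, 0.0, 0.0]
assert result[1].tolist() == [1.0, 0.0, 0.0, 1.0, 1.0]
```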
mo_net/samples/word2vec/__main__.py (2)
**383-383: Unused `device` parameter still needs addressing.** This is a duplicate of a previous review comment. The `device` parameter is still unused in the function implementation.

**469-469: Hardcoded run name still doesn't reflect model type.** This is a duplicate of a previous review comment. The run name is still hardcoded to "cbow_run_" regardless of the actual model type being trained.
🧹 Nitpick comments (5)
mo_net/model/layer/output.py (4)
**165-167: Shorten error message to satisfy TRY003 (Ruff).** The message is flagged as a "long message outside exception class." A concise message is sufficient here.

```diff
-        if (output_activations := self._cache["output_activations"]) is None:
-            raise ValueError("Output activations not set during forward pass.")
+        if (output_activations := self._cache["output_activations"]) is None:
+            raise ValueError("Forward pass required.")
```
**168-170: Unnecessary copy; `.at[...]` is functional already.** `jnp.ndarray.at` returns a new array; copying first just adds an allocation.

```diff
-        result = output_activations.copy()
-        result = result.at[jnp.arange(Y_true.shape[0]), Y_true].add(-1.0)
+        batch_size = output_activations.shape[0]
+        result = output_activations.at[jnp.arange(batch_size), Y_true].add(-1.0)
```
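To see why the copy is redundant, note that JAX's `.at[...]` operators are purely functional and leave the original array untouched:

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3))
b = a.at[0, 1].add(-1.0)  # returns a NEW array; `a` is immutable

assert float(a[0, 1]) == 0.0   # original untouched, so no .copy() is needed
assert float(b[0, 1]) == -1.0  # update visible only in the returned array
```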
**171-193: Minor shape/dtype validations and TRY003 nits.**

- Consider validating `Y_true.ndim == 1` (sparse targets) to fail fast.
- Add a brief dtype check to ensure `Y_negative` is integer to avoid cryptic JAX errors.
- Shorten two error messages to satisfy Ruff TRY003.

```diff
         if Y_negative is not None:
+            # Validate dtypes/shapes early
+            if Y_true.ndim != 1:
+                raise ValueError("Y_true must be 1D indices.")
+            if not jnp.issubdtype(Y_negative.dtype, jnp.integer):
+                raise ValueError("Y_negative dtype must be integer.")
             if Y_negative.ndim == 2:
                 batch_size, num_neg = Y_negative.shape
                 row_idx = jnp.repeat(jnp.arange(batch_size), num_neg)
                 col_idx = Y_negative.reshape(-1)
                 result = result.at[row_idx, col_idx].add(1.0)
             elif Y_negative.ndim == 1:
                 # (batch_size,) for a single negative per example
                 batch_size = result.shape[0]
                 total = Y_negative.shape[0]
                 if total == batch_size:
                     result = result.at[jnp.arange(batch_size), Y_negative].add(1.0)
                 else:
                     if total % batch_size != 0:
-                        raise ValueError(
-                            "Y_negative length must be a multiple of batch size when flattened."
-                        )
+                        raise ValueError("Invalid Y_negative length (must be multiple of batch size).")
                     num_neg = total // batch_size
                     row_idx = jnp.repeat(jnp.arange(batch_size), num_neg)
                     result = result.at[row_idx, Y_negative].add(1.0)
             else:
-                raise ValueError("Y_negative must be 1D or 2D array of indices.")
+                raise ValueError("Y_negative must be 1D or 2D.")
```
**194-194: Cast is fine; consider dropping `atleast_1d`.** `result` is already 2D, so `jnp.atleast_1d` is a no-op. You can return `result` directly and keep the `cast`.

```diff
-        return cast(D[Activations], jnp.atleast_1d(result))
+        return cast(D[Activations], result)
```

mo_net/samples/word2vec/__main__.py (1)
**125-125: Remove unused noqa directive.** The `# noqa: F821` comment is no longer needed since the type annotation issue has been resolved.

```diff
- hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+ hidden: Sequence[Hidden | HiddenLayer],
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `mo_net/model/layer/output.py` (1 hunks)
- `mo_net/samples/word2vec/__main__.py` (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
mo_net/model/layer/output.py (1)
- mo_net/protos.py (1): `D` (18-26)
mo_net/samples/word2vec/__main__.py (7)
- mo_net/__init__.py (1): `print_device_info` (139-145)
- mo_net/functions.py (1): `sparse_cross_entropy` (20-23)
- mo_net/model/model.py (4): `output` (435-436), `Model` (53-436), `loss_contributors` (269-270), `output_module` (290-291)
- mo_net/model/layer/output.py (4): `OutputLayer` (22-52), `SparseCategoricalSoftmaxOutputLayer` (125-201), `backward_prop` (41-42), `backward_prop_with_negative` (151-157)
- mo_net/regulariser/weight_decay.py (4): `EmbeddingWeightDecayRegulariser` (71-112), `embedding_layer` (68), `attach` (47-63), `attach` (101-112)
- mo_net/samples/word2vec/vocab.py (3): `Vocab` (26-121), `get_training_set` (124-143), `english_sentences` (115-121)
- mo_net/train/trainer/trainer.py (2): `train` (209-266), `get_optimiser` (51-89)
🪛 Ruff (0.12.2)
mo_net/model/layer/output.py
166-166: Avoid specifying long messages outside the exception class
(TRY003)
185-187: Avoid specifying long messages outside the exception class
(TRY003)
192-192: Avoid specifying long messages outside the exception class
(TRY003)
mo_net/samples/word2vec/__main__.py
125-125: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
383-383: Unused function argument: device
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Cursor Bugbot
🔇 Additional comments (8)
mo_net/model/layer/output.py (1)
151-158: API addition looks good; keep it discoverablebackward_prop_with_negative simply forwards to _backward_prop and is exercised at mo_net/samples/word2vec/main.py:210 — no further changes required.
mo_net/samples/word2vec/__main__.py (7)
**7-7: LGTM! Good addition of type imports for SkipGram implementation.** The additional imports (`Literal`, `assert_never`, `cast`) are appropriately used throughout the new SkipGram implementation.
**15-15: LGTM! Import updates align with the vocabulary refactoring and new features.** The import changes correctly reflect:

- Moving vocabulary utilities to external modules
- Adding support for output layers with negative sampling
- Updating to use the modernized "optimiser" spelling
- Adding necessary type imports for the new SkipGram model

Also applies to: 17-17, 20-20, 23-23, 26-28
**121-233: LGTM! Well-implemented SkipGram model with proper negative sampling.** The SkipGramModel implementation correctly:

- Uses single-word input (`input_dimensions=(1,)`) vs CBOW's context window
- Implements negative sampling in `backward_prop` with proper tensor reshaping
- Overrides `compute_loss` to handle flattened targets appropriately
- Maintains the same embedding layer interface for consistency

The negative sampling logic properly generates random negative samples and passes them to the output layer's `backward_prop_with_negative` method.
**279-369: LGTM! Training options correctly expanded for SkipGram support.** The CLI options have been appropriately updated to support both CBOW and SkipGram models:

- Added `--model-type` with choices for "cbow" and "skipgram"
- Added `--negative-samples` for SkipGram configuration
- Added `--model-path` for loading existing models
- Updated parameter names and defaults appropriately
**377-509: LGTM! Train function properly handles both model types.** The training function correctly:

- Uses external vocabulary utilities from `vocab.py`
- Swaps X_train/Y_train for SkipGram models (lines 404-405)
- Creates appropriate model instances based on `model_type`
- Reshapes input data for SkipGram (lines 465-467)
- Uses the modernized "optimiser" spelling consistently

The model loading and creation logic is sound, and the data preparation handles the different input/output requirements of CBOW vs SkipGram appropriately.
**470-476: LGTM! Proper optimiser modernization and type casting.** The code correctly:

- Uses the modernized `get_optimiser` function name
- Casts the model to the union type for the regulariser attachment
- Maintains backward compatibility with existing CBOW models

**401-401: `Vocab.english_sentences` integration verified; no changes required.** The classmethod exists at mo_net/samples/word2vec/vocab.py, accepts `max_vocab_size` as a keyword, and delegates to `get_english_sentences(limit)` → `from_sentences(...)`, which builds the most-common-token vocab and returns tokenized sentences; `get_training_set` returns `(jnp.ndarray, jnp.ndarray)`.
Force-pushed 5943931 to 2c9926c.
- Update import path for check_device and DeviceType in word2vec
Force-pushed 9cbb730 to 1e515cc.
**modiase** left a comment:

@modiase reviewed 12 of 18 files at r1, 4 of 4 files at r2, 5 of 11 files at r3, 11 of 11 files at r5, 12 of 12 files at r7, all commit messages.

Reviewable status: complete! All files reviewed, all discussions resolved (waiting on @modiase).
Force-pushed 1e515cc to 05b8835.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
mo_net/model/layer/output.py (1)
**8-8: Blocking: wrong `abstractmethod` import.** This module imports `abstractmethod` from pyparsing, which will fail. It should come from `abc`. Use:

```python
from abc import abstractmethod
```
♻️ Duplicate comments (3)
flake.nix (1)
**100-103: Remove hard-coded Nix store driver path from LD_LIBRARY_PATH; use the derivation path instead.** The pinned store path is non-portable and will rot on driver/kernel updates. This was already flagged earlier. Use `nvidia_x11`'s derivation path and gate to Linux.

```diff
-          # Set up library paths with NVIDIA driver libraries prioritized for JAX
-          # Use the specific NVIDIA driver version that matches the kernel (570.153.02)
-          export LD_LIBRARY_PATH="/nix/store/6rf8qnkp4m5h5x9byf5srphcdjdp5r5j-nvidia-x11-570.153.02-6.12.35/lib:${pkgs.lib.makeLibraryPath (systemLibs ++ cudaLibs)}:''${LD_LIBRARY_PATH:-}"
+          # Set up library paths with NVIDIA driver libraries prioritized for JAX (Linux only)
+          export LD_LIBRARY_PATH="${
+            pkgs.lib.optionalString pkgs.stdenv.isLinux "${pkgs.lib.makeLibraryPath [ pkgs.linuxPackages.nvidia_x11 ]}:"
+          }${pkgs.lib.makeLibraryPath (systemLibs ++ cudaLibs)}:''${LD_LIBRARY_PATH:-}"
```

mo_net/model/layer/output.py (1)
**171-176: Previous negative-sample indexing bug is fixed.** Using `jnp.repeat(jnp.arange(batch_size), num_neg)` resolves the earlier out-of-bounds issue when negatives > 1. Nice.
mo_net/samples/word2vec/__main__.py (1)
**462-471: LGTM!** Good updates to use the model type in the training run name and proper casting for the regulariser attachment. Note: this addresses the previous review comment about the hardcoded "cbow_run_" name by using the dynamic `model_type` variable.
🧹 Nitpick comments (2)
mo_net/model/layer/output.py (2)
**159-165: Minor: shorten error strings (Ruff TRY003) or silence per-line.** If you don't want custom exception types, keep messages terse and add `# noqa: TRY003` where needed (as shown in the diff above).
**159-165: Fix signature mismatch in `OutputLayer._backward_prop`.** The abstract `OutputLayer._backward_prop` is declared without `Y_negative` (mo_net/model/layer/output.py:45) but the concrete override adds `Y_negative` (mo_net/model/layer/output.py:159); static type checkers (mypy/pyright) will flag this as an incompatible override.

Options:
- Preferred: keep the abstract signature unchanged and add a private helper (e.g., `_backward_prop_with_negative`) used by the method that accepts `Y_negative`.
- Alternative: widen the abstract signature to accept `Y_negative: jnp.ndarray | None = None` and update all overrides.
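The second option would look roughly like this (a sketch with simplified list types; the real signatures use `jnp.ndarray` and more parameters):

```python
from abc import ABC, abstractmethod
from typing import Optional


class OutputLayer(ABC):
    @abstractmethod
    def _backward_prop(
        self, Y_true: list[int], Y_negative: Optional[list[int]] = None
    ) -> list[float]:
        ...


class SparseSoftmaxOutput(OutputLayer):
    # Same parameters and defaults as the abstract method, so the
    # override is compatible under mypy/pyright.
    def _backward_prop(
        self, Y_true: list[int], Y_negative: Optional[list[int]] = None
    ) -> list[float]:
        grad = [float(y) for y in Y_true]
        if Y_negative is not None:
            grad += [-float(y) for y in Y_negative]
        return grad


layer = SparseSoftmaxOutput()
assert layer._backward_prop([1, 2]) == [1.0, 2.0]
assert layer._backward_prop([1], [3]) == [1.0, -3.0]
```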
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `flake.nix` (2 hunks)
- `mo_net/model/layer/output.py` (1 hunks)
- `mo_net/samples/word2vec/__main__.py` (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
mo_net/samples/word2vec/__main__.py (6)
- mo_net/model/layer/embedding.py (2): `Embedding` (117-307), `vocab_size` (306-307)
- mo_net/model/model.py (5): `output` (435-436), `Model` (53-436), `loss_contributors` (269-270), `output_module` (290-291), `hidden_modules` (276-277)
- mo_net/model/layer/output.py (4): `OutputLayer` (22-52), `SparseCategoricalSoftmaxOutputLayer` (125-201), `backward_prop` (41-42), `backward_prop_with_negative` (151-157)
- mo_net/regulariser/weight_decay.py (4): `EmbeddingWeightDecayRegulariser` (71-112), `embedding_layer` (68), `attach` (47-63), `attach` (101-112)
- mo_net/samples/word2vec/vocab.py (3): `Vocab` (26-121), `get_training_set` (124-143), `english_sentences` (115-121)
- mo_net/train/trainer/trainer.py (2): `train` (209-266), `get_optimiser` (51-89)
mo_net/model/layer/output.py (2)
- mo_net/protos.py (1): `D` (18-26)
- mo_net/model/layer/base.py (1): `_backward_prop` (107-111)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py
125-125: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
mo_net/model/layer/output.py
166-166: Avoid specifying long messages outside the exception class
(TRY003)
185-187: Avoid specifying long messages outside the exception class
(TRY003)
192-192: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: typecheck
- GitHub Check: Cursor Bugbot
🔇 Additional comments (16)
mo_net/model/layer/output.py (1)
**151-158: Document shapes for `backward_prop_with_negative`.** Add a concise docstring on mo_net/model/layer/output.py::backward_prop_with_negative describing the expected shapes:

- `Y_true`: jnp.ndarray[int], shape (N,), where N = number of targets (e.g. batch, or batch*context when flattened).
- `Y_negative`: jnp.ndarray[int], shape (N, K), with K negative samples per target (or documented as flattened (N*K,) if callers use that).

Caller evidence: mo_net/samples/word2vec/__main__.py:210 calls `backward_prop_with_negative(Y_true=Y_true.flatten(), Y_negative=jax.random.choice(...))`. Keep the original `backward_prop` for existing callers.
mo_net/samples/word2vec/__main__.py (15)
**7-7: LGTM!** Good addition of the `assert_never` import to handle exhaustive pattern matching for the new model type selection.

**17-17: LGTM!** Correct import of the `LossFn` type for the new `compute_loss` method signature.

**20-20: LGTM!** Smart aliasing of `Hidden` as `HiddenLayer` for compatibility with the new `SkipGramModel` constructor.

**23-23: LGTM!** Good import addition of `OutputLayer` for the new backward propagation functionality.

**26-28: LGTM!** Excellent modularization: vocabulary utilities are extracted to a separate module, and the regulariser is imported for weight-decay functionality.

**36-36: LGTM!** Good update to the standardized British spelling of "optimiser" for consistency across the codebase.
**279-363: LGTM!** The CLI option additions are well-structured and comprehensive. The new options for `--model-type`, `--negative-samples`, and other parameters properly support both CBOW and Skip-gram training workflows.

**371-388: LGTM!** The function signature properly incorporates all the new parameters needed for both model types, including the `model_type` discriminator and `negative_samples` for skip-gram training.

**394-398: LGTM!** Smart data preparation logic that correctly swaps X and Y for skip-gram training. This aligns with the skip-gram model, where you predict context words from a center word, the opposite of CBOW.

**409-430: LGTM!** Excellent model creation logic using pattern matching to handle both model types. The parameters are correctly passed to each model's factory method.

**458-461: LGTM!** Correct data reshaping for skip-gram input format, ensuring the input has the expected shape of `(-1, 1)` for single-word inputs.

**479-479: LGTM!** Consistent use of the British spelling "optimiser" throughout the training setup.

**486-488: LGTM!** Good logging message that reflects the actual model type being trained.

**502-503: LGTM!** Proper use of `assert_never` with a unique variable name to handle exhaustive pattern matching.
**121-234: Approve; confirm removal of `noqa`.**

- Remove the unused `# noqa: F821` on line 125 only after confirming `Hidden`/`HiddenLayer` are defined or imported (a quick repo search didn't locate their definitions).
- Negative sampling: verified; negatives are sampled as batch_size * context_size * negative_samples (k negatives per positive) and passed to the output layer, which aligns with expected skip-gram negative sampling.
- Model architecture: verified; `input_dimensions=(1,)` is correct for skip-gram; training flips X/Y for skipgram and reshapes inputs to (-1, 1).

```diff
- hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+ hidden: Sequence[Hidden | HiddenLayer],
```
Force-pushed 2577878 to 0315e1f.
Actionable comments posted: 3
🧹 Nitpick comments (6)
mo_net/samples/word2vec/__main__.py (6)
**130-139: Remove unnecessary noqa and tighten type hint.** `# noqa: F821` is unnecessary because `HiddenLayer` is imported. Clean it up.

```diff
- hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+ hidden: Sequence[Hidden | HiddenLayer],
```
**417-435: Unify `tokenized_sentences` type across branches to satisfy the type checker.** The CI error at line 417 shows a mismatch (`Collection[Sequence[int]]` vs `list[list[int]]`). The annotation in the previous comment resolves it. If you prefer a concrete type, convert `Vocab.english_sentences(...)[1]` to `list[list[int]]`.

```diff
-    vocab, tokenized_sentences = Vocab.english_sentences(max_vocab_size=vocab_size)
+    vocab, tokenized = Vocab.english_sentences(max_vocab_size=vocab_size)
+    tokenized_sentences: Collection[Sequence[int]] = tokenized
```
**480-486: Regulariser attach relies on subclass typing.** This will fail when loading a base `Model`. After applying the subclass-aware load fix above, this cast is correct. Consider annotating `model: CBOWModel | SkipGramModel` when created/loaded to help type checkers.

```diff
-    optimiser = get_optimiser("adam", model, training_parameters)
+    model = cast(CBOWModel | SkipGramModel, model)
+    optimiser = get_optimiser("adam", model, training_parameters)
```
**564-586: Exception chaining and logging fix in `sample`; also satisfy the type checker.** Chain exceptions to preserve the cause (B904) and declare the union type for `model` to avoid the mypy error at line 576. Also avoid passing `e` to `logger.exception` unnecessarily.

```diff
     try:
         with zipfile.ZipFile(model_path, "r") as zf:
             vocab_bytes = zf.read(VOCAB_ZIP_INTERNAL_PATH)
             with zf.open(METADATA_ZIP_INTERNAL_PATH) as md:
                 metadata = json.loads(md.read().decode("utf-8"))
             model_type = metadata.get("type", "cbow")
+            model: CBOWModel | SkipGramModel
             with zf.open(MODEL_ZIP_INTERNAL_PATH) as mf:
                 match model_type:
                     case "skipgram":
                         model = SkipGramModel.load(mf, training=False)
                     case "cbow":
                         model = CBOWModel.load(mf, training=False)
                     case never:
                         assert_never(never)
     except KeyError as e:
         raise click.ClickException(
             f"Missing file in zip: {e.args[0]}. Expected {MODEL_ZIP_INTERNAL_PATH}, {VOCAB_ZIP_INTERNAL_PATH} and optionally {METADATA_ZIP_INTERNAL_PATH}"
-        )
+        ) from e
     except zipfile.BadZipFile as e:
-        logger.exception(f"Invalid zip file: {model_path}", e)
-        raise click.ClickException(f"Invalid zip file: {model_path}")
+        logger.exception(f"Invalid zip file: {model_path}")
+        raise click.ClickException(f"Invalid zip file: {model_path}") from e
```
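For reference, `raise ... from e` attaches the original error as `__cause__`, which is what B904 asks for (a stand-in `ClickException` is used here so the snippet is self-contained):

```python
class ClickException(Exception):
    """Stand-in for click.ClickException."""


def load(path: str):
    try:
        raise KeyError("model.pkl")  # simulate a missing zip member
    except KeyError as e:
        # B904: chain so the traceback shows the KeyError as the cause
        raise ClickException(f"Missing file in zip: {e.args[0]}") from e


try:
    load("weights.zip")
except ClickException as exc:
    # The original KeyError is preserved on the chained exception
    assert isinstance(exc.__cause__, KeyError)
```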
**589-597: Minor efficiency nit in sampling.** `list(vocab.vocab)` is created repeatedly; cache it once for readability (a micro-optimisation).

```diff
-    random_words = [list(vocab.vocab)[int(i)] for i in word_indices]
+    vocab_list = list(vocab.vocab)
+    random_words = [vocab_list[int(i)] for i in word_indices]
```
**217-234: Optional: avoid sampling positives as negatives.** Uniform negatives can include true labels; usually OK, but you can resample to exclude positives to marginally stabilise training. Would you like a small helper to draw negatives excluding `Y_true` per row?
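One simple rejection-sampling helper, sketched in NumPy (hypothetical, not part of the codebase), that redraws any negative colliding with its row's positive:

```python
import numpy as np


def sample_negatives(rng, positives, vocab_size, num_neg):
    """Draw negatives uniformly, resampling any that equal the row's
    positive label. positives: (batch,); returns (batch, num_neg)."""
    batch = positives.shape[0]
    neg = rng.integers(0, vocab_size, size=(batch, num_neg))
    collisions = neg == positives[:, None]
    while collisions.any():
        # Redraw only the colliding entries
        neg[collisions] = rng.integers(0, vocab_size, size=int(collisions.sum()))
        collisions = neg == positives[:, None]
    return neg


rng = np.random.default_rng(0)
positives = np.array([1, 4, 2])
neg = sample_negatives(rng, positives, vocab_size=10, num_neg=5)
assert not (neg == positives[:, None]).any()
```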
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `mo_net/samples/word2vec/__main__.py` (10 hunks)
- `mo_net/samples/word2vec/vocab.py` (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- mo_net/samples/word2vec/vocab.py
🧰 Additional context used
🧬 Code graph analysis (1)
mo_net/samples/word2vec/__main__.py (4)
- mo_net/model/model.py (4): `output` (435-436), `output_module` (290-291), `hidden_modules` (276-277), `input_layer` (286-287)
- mo_net/model/layer/output.py (7): `OutputLayer` (22-52), `SparseCategoricalSoftmaxOutputLayer` (125-201), `backward_prop` (41-42), `backward_prop_with_negative` (151-157), `serialize` (52), `serialize` (94-95), `serialize` (200-201)
- mo_net/regulariser/weight_decay.py (4): `EmbeddingWeightDecayRegulariser` (71-112), `embedding_layer` (68), `attach` (47-63), `attach` (101-112)
- mo_net/samples/word2vec/vocab.py (6): `Vocab` (21-129), `get_english_sentences` (296-305), `get_training_set` (132-151), `from_bytes` (119-129), `english_sentences` (110-116), `serialize` (27-34)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py
134-134: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
580-582: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
580-582: Avoid specifying long messages outside the exception class
(TRY003)
585-585: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
585-585: Avoid specifying long messages outside the exception class
(TRY003)
🪛 GitHub Actions: CI
mo_net/samples/word2vec/__main__.py
[error] 417-417: uv run task typecheck -v: Incompatible types in assignment (expression has type "Collection[Sequence[int]]", variable has type "list[list[int]]") [assignment]
[error] 576-576: uv run task typecheck -v: Incompatible types in assignment (expression has type "CBOWModel", variable has type "SkipGramModel") [assignment]
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Cursor Bugbot
🔇 Additional comments (4)
mo_net/samples/word2vec/__main__.py (4)
**473-476: Reshape centers only for Skip-gram: LGTM.** This ensures X has shape `(batch, 1)` for the embedding layer. No action needed.

**501-503: Nice training log message.** Clear and parameterised by `model_type`.

**509-526: ZIP save flow: LGTM.** Buffered write, deterministic internal paths, and ensuring parent dirs exist are good practices.
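The described flow can be sketched with the standard library (the internal path constants here are illustrative stand-ins for the module's own):

```python
import io
import json
import tempfile
import zipfile
from pathlib import Path

# Hypothetical internal paths; the real module defines its own constants.
MODEL_PATH = "model.pkl"
VOCAB_PATH = "vocab.json"
METADATA_PATH = "metadata.json"


def save_bundle(out: Path, model_bytes: bytes, vocab_bytes: bytes, model_type: str) -> None:
    out.parent.mkdir(parents=True, exist_ok=True)  # ensure parent dirs exist
    buf = io.BytesIO()  # buffered write, then a single flush to disk
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr(MODEL_PATH, model_bytes)
        zf.writestr(VOCAB_PATH, vocab_bytes)
        zf.writestr(METADATA_PATH, json.dumps({"type": model_type}))
    out.write_bytes(buf.getvalue())


tmp = Path(tempfile.mkdtemp()) / "runs" / "w2v.zip"
save_bundle(tmp, b"model", b"vocab", "skipgram")
with zipfile.ZipFile(tmp) as zf:
    meta = json.loads(zf.read(METADATA_PATH))
```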
**321-349: CLI options: good coverage.** Options for optimiser regularisation, negatives, and logging are clear. No changes required.
Actionable comments posted: 4
🧹 Nitpick comments (8)
mo_net/samples/word2vec/vocab.py (3)
**63-93: Token normalisation is missing when building the vocab and tokenization.** `from_sentences` counts tokens and maps them without calling `clean_token`. As a result, punctuation/case variants inflate the vocabulary and "The" won't be filtered by lower-case stopwords.

Apply:

```diff
-        most_common_tokens = [
-            token
-            for token, _ in Counter(
-                token for sentence in sentences if sentence for token in sentence
-            ).most_common(max_size)
-        ]
+        most_common_tokens = [
+            token
+            for token, _ in Counter(
+                t
+                for sentence in sentences
+                if sentence
+                for raw in sentence
+                if (t := cls.clean_token(raw))
+            ).most_common(max_size)
+        ]
@@
-            [
-                [vocab[token] for token in sentence]
-                for sentence in sentences
-                if sentence
-            ],
+            [
+                [vocab[cls.clean_token(raw)] for raw in sentence if cls.clean_token(raw)]
+                for sentence in sentences
+                if sentence
+            ],
```
**95-101: Clarify length semantics.** `__len__` returns `len(self.vocab) + 1` (to include the unknown token). This is fine, but it's surprising. Add a short docstring to avoid misuse where callers iterate `range(len(vocab))` expecting only known tokens.

```diff
-    def __len__(self) -> int:
-        return len(self.vocab) + 1
+    def __len__(self) -> int:
+        """Vocabulary size including <unknown>."""
+        return len(self.vocab) + 1
```
132-152: Memory pressure: build context/target lazily or preallocate.
The list comprehension materialises all pairs before zipping; on large corpora this is heavy. Example generator-based approach:

```diff
-    context, target = zip(
-        *[
-            (
-                tuple(
-                    chain(
-                        sentence[i - context_size : i],
-                        sentence[i + 1 : i + context_size + 1],
-                    )
-                ),
-                sentence[i],
-            )
-            for sentence in tokenized_sentences
-            for i in range(context_size, len(sentence) - context_size)
-        ],
-        strict=True,
-    )
+    pairs = (
+        (
+            tuple(
+                chain(
+                    sentence[i - context_size : i],
+                    sentence[i + 1 : i + context_size + 1],
+                )
+            ),
+            sentence[i],
+        )
+        for sentence in tokenized_sentences
+        for i in range(context_size, len(sentence) - context_size)
+    )
+    context, target = zip(*pairs, strict=True)
```

mo_net/samples/word2vec/__main__.py (5)
134-134: Remove unused `# noqa: F821`.
`HiddenLayer` is imported (Line 24), so the noqa is unnecessary and flagged by Ruff.

```diff
-    hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+    hidden: Sequence[Hidden | HiddenLayer],
```
205-212: Ensure scalar loss type.
Depending on `loss_fn`, `compute_loss` may return a JAX array; downstream code often expects a Python `float`.

```diff
-        return loss_fn(Y_pred, Y_true.flatten()) + sum(
+        return float(loss_fn(Y_pred, Y_true.flatten())) + sum(
             contributor() for contributor in self.loss_contributors
         )
```
439-441: Double-check CBOW/Skip-gram train set swap.
The swap is correct for Skip‑gram. Add a brief comment so future edits don’t regress this.

```diff
-    if model_type == "skipgram":
+    if model_type == "skipgram":
+        # For Skip-gram, inputs are centers and labels are contexts.
         Y_train, X_train = X_train, Y_train
```
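The role reversal can be illustrated with toy pairs (plain lists, not the project's loader): CBOW maps contexts to centers; Skip-gram simply exchanges the roles.

```python
# Toy (context, center) pairs as a CBOW-style loader would produce them.
X_train = [("the", "sat"), ("cat", "on")]  # contexts
Y_train = ["cat", "the"]                   # centers

model_type = "skipgram"
if model_type == "skipgram":
    # For Skip-gram, inputs are centers and labels are contexts.
    Y_train, X_train = X_train, Y_train

print(X_train)  # centers are now the inputs
```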
473-476: Reshape only when needed and keep dtypes.
The reshape for Skip‑gram inputs is required (centers as (B, 1)). Confirm `X_train.dtype` is integer; cast if needed to avoid JAX implicit promotion.

```diff
-    if model_type == "skipgram":
-        X_train_split = X_train_split.reshape(-1, 1)
-        X_val = X_val.reshape(-1, 1)
+    if model_type == "skipgram":
+        X_train_split = jnp.asarray(X_train_split, dtype=jnp.int32).reshape(-1, 1)
+        X_val = jnp.asarray(X_val, dtype=jnp.int32).reshape(-1, 1)
```
580-585: Chain exceptions and fix logger.exception usage.
Use `raise ... from e` (B904) and don’t pass `e` as a positional arg to `logger.exception` (it’s implicit).

```diff
-    except KeyError as e:
+    except KeyError as e:
         raise click.ClickException(
             f"Missing file in zip: {e.args[0]}. Expected {MODEL_ZIP_INTERNAL_PATH}, {VOCAB_ZIP_INTERNAL_PATH} and optionally {METADATA_ZIP_INTERNAL_PATH}"
-        )
-    except zipfile.BadZipFile as e:
-        logger.exception(f"Invalid zip file: {model_path}", e)
-        raise click.ClickException(f"Invalid zip file: {model_path}")
+        ) from e
+    except zipfile.BadZipFile as e:
+        logger.exception("Invalid zip file: {}", model_path)
+        raise click.ClickException(f"Invalid zip file: {model_path}") from e
```
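Why `from e` matters can be shown in isolation: the original error is preserved on `__cause__`, so tracebacks show it as the direct cause rather than "another exception occurred during handling". `AppError` and `load` here are illustrative names only.

```python
class AppError(Exception):
    pass

def load(path: str) -> None:
    try:
        raise KeyError(path)
    except KeyError as e:
        # "from e" records the KeyError as the direct cause in tracebacks.
        raise AppError(f"Missing file in zip: {e.args[0]}") from e

caught = None
try:
    load("model.pkl")
except AppError as err:
    caught = err

print(caught, "caused by", repr(caught.__cause__))
```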
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- mo_net/samples/word2vec/__main__.py (10 hunks)
- mo_net/samples/word2vec/vocab.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
mo_net/samples/word2vec/__main__.py (4)
- mo_net/model/model.py (4): output (435-436), output_module (290-291), hidden_modules (276-277), input_layer (286-287)
- mo_net/model/layer/output.py (7): OutputLayer (22-52), SparseCategoricalSoftmaxOutputLayer (125-201), backward_prop (41-42), backward_prop_with_negative (151-157), serialize (52-52), serialize (94-95), serialize (200-201)
- mo_net/regulariser/weight_decay.py (4): EmbeddingWeightDecayRegulariser (71-112), embedding_layer (68-68), attach (47-63), attach (101-112)
- mo_net/samples/word2vec/vocab.py (6): Vocab (21-129), get_english_sentences (296-305), get_training_set (132-151), from_bytes (119-129), english_sentences (110-116), serialize (27-34)

mo_net/samples/word2vec/vocab.py (1)
- mo_net/resources.py (1): get_resource (16-47)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py
134-134: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
580-582: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
580-582: Avoid specifying long messages outside the exception class
(TRY003)
585-585: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
585-585: Avoid specifying long messages outside the exception class
(TRY003)
🪛 GitHub Actions: CI
mo_net/samples/word2vec/__main__.py
[error] 417-417: Mypy: Incompatible types in assignment. Expression has type 'Collection[Sequence[int]]', variable has type 'list[list[int]]' (assignment). Step: 'uv run task typecheck -v'.
[error] 576-576: Mypy: Incompatible types in assignment. Expression has type 'CBOWModel', variable has type 'SkipGramModel' (assignment). Step: 'uv run task typecheck -v'.
🔇 Additional comments (3)
mo_net/samples/word2vec/__main__.py (3)
477-486: OK: run naming and regulariser wiring.
Run names reflect model type; the weight‑decay regulariser is correctly attached to the embedding layer.
Consider persisting `negative_samples` in metadata for reproducibility.
501-531: Good: self-contained zip artefact.
Writing model.pkl, vocab.msgpack and metadata.json into a single zip is clean and reproducible; defaulting the output path to DATA_DIR is sensible.
213-226: Incorrect — no shape mismatch; labels and outputs are flattened consistently.
`forward_prop` returns predictions flattened across context positions (`compute_loss` calls `Y_pred = self.forward_prop(X)` and compares to `Y_true.flatten()`); `backward_prop` passes `Y_true.flatten()` and flattened negatives, and `SparseCategoricalSoftmaxOutputLayer._backward_prop` indexes rows by `Y_true.shape[0]` against the cached `output_activations`, so the shapes align.
Likely an incorrect or invalid review comment.
Actionable comments posted: 3
♻️ Duplicate comments (4)
mo_net/samples/word2vec/vocab.py (2)
16-17: Avoid Python 3.12-only `type` aliases; use `TypeAlias`.
Apply:

```diff
-type Sentence = Sequence[str]
-type TokenizedSentence = Sequence[int]
+Sentence: TypeAlias = Sequence[str]
+TokenizedSentence: TypeAlias = Sequence[int]
```
296-305: Normalize tokens before stopword filtering.
Apply:

```diff
-def get_english_sentences(limit: int = 100000) -> Collection[Sentence]:
+def get_english_sentences(limit: int = 100000) -> Collection[Sentence]:
     stop_words = get_stop_words()
     return [
-        [word for word in sentence.split() if word not in stop_words]
+        [
+            tok
+            for tok in (Vocab.clean_token(w) for w in sentence.split())
+            if tok and tok not in stop_words
+        ]
         for sentence in (
             get_resource("s3://mo-net-resources/english-sentences.txt")
             .read_text()
             .split("\n")[:limit]
         )
     ]
```

mo_net/samples/word2vec/__main__.py (2)
275-281: Skip‑gram loss: shape mismatch between logits and flattened labels.
`Y_true.flatten()` has length `batch_size * num_pos` but `Y_pred` is `(batch_size, vocab)`. Repeat logits per positive target before loss. Apply:

```diff
     def compute_loss(
         self, X: jnp.ndarray, Y_true: jnp.ndarray, loss_fn: LossFn
     ) -> float:
-        Y_pred = self.forward_prop(X)
-        return loss_fn(Y_pred, Y_true.flatten()) + sum(
+        Y_pred = self.forward_prop(X)
+        if Y_true.ndim == 2:
+            num_pos = Y_true.shape[1]
+            Y_pred = jnp.repeat(Y_pred, repeats=num_pos, axis=0)
+            y = Y_true.reshape(-1)
+        else:
+            y = Y_true
+        return loss_fn(Y_pred, y) + sum(
             contributor() for contributor in self.loss_contributors
         )
```
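The shape fix can be checked with NumPy standing in for `jnp` (toy sizes; names mirror the suggestion above but nothing here is the project's code):

```python
import numpy as np

batch_size, num_pos, vocab = 4, 3, 10
Y_pred = np.random.rand(batch_size, vocab)  # one logit row per center word
Y_true = np.random.randint(0, vocab, size=(batch_size, num_pos))

# Flattening labels alone gives 12 labels against 4 logit rows — mismatched.
assert Y_true.reshape(-1).shape[0] != Y_pred.shape[0]

# Repeating each logit row once per positive context restores alignment.
Y_pred_rep = np.repeat(Y_pred, repeats=num_pos, axis=0)
print(Y_pred_rep.shape)  # (12, 10) — matches the 12 flattened labels
```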
283-303: Skip‑gram backprop: row indexing overruns due to flattened labels.
Output-layer gradient uses `jnp.arange(Y_true.shape[0])`. Flattening makes rows exceed logits rows. Aggregate per positive after repeating. Apply:

```diff
-    def backward_prop(self, Y_true: jnp.ndarray) -> D[Activations]:
-        batch_size, context_size = Y_true.shape
-        self._key, subkey = jax.random.split(self._key)
-
-        dZ = cast(
-            SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
-        ).backward_prop_with_negative(
-            Y_true=Y_true.flatten(),
-            Y_negative=jax.random.choice(
-                subkey,
-                self.embedding_layer.vocab_size,
-                shape=(batch_size * context_size * self._negative_samples,),
-            ),
-        )
+    def backward_prop(self, Y_true: jnp.ndarray) -> D[Activations]:
+        if Y_true.ndim == 2:
+            batch_size, num_pos = Y_true.shape
+            self._key, subkey = jax.random.split(self._key)
+            dZ_rep = cast(
+                SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
+            ).backward_prop_with_negative(
+                Y_true=Y_true.reshape(-1),
+                Y_negative=jax.random.choice(
+                    subkey,
+                    self.embedding_layer.vocab_size,
+                    shape=(batch_size * num_pos * self._negative_samples,),
+                ),
+            )
+            # Sum per example across its positive contexts → shape (batch_size, vocab)
+            dZ = dZ_rep.reshape(batch_size, num_pos, -1).sum(axis=1)
+        else:
+            batch_size = Y_true.shape[0]
+            self._key, subkey = jax.random.split(self._key)
+            dZ = cast(
+                SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
+            ).backward_prop_with_negative(
+                Y_true=Y_true,
+                Y_negative=jax.random.choice(
+                    subkey,
+                    self.embedding_layer.vocab_size,
+                    shape=(batch_size * self._negative_samples,),
+                ),
+            )
```
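Why the aggregation step works: after repeating, the per-row gradient has `batch_size * num_pos` rows; summing each example's positives recovers a `(batch_size, vocab)` gradient. A NumPy shape sketch with dummy gradient values:

```python
import numpy as np

batch_size, num_pos, vocab = 4, 3, 10
# Gradient computed on repeated rows: one row per (example, positive) pair.
dZ_rep = np.random.rand(batch_size * num_pos, vocab)

# Sum per example across its positive contexts → (batch_size, vocab).
dZ = dZ_rep.reshape(batch_size, num_pos, vocab).sum(axis=1)
print(dZ.shape)  # (4, 10)
```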
🧹 Nitpick comments (7)
bin/mnist (1)
1-2: Use `exec`, quote paths, and prefer module invocation (`-m`).
Replace path execution with `python -m mo_net.samples.mnist` to avoid relying on `git` and breaking on paths with spaces. Also use `exec` so the wrapper forwards signals and the exit code. Apply:

```diff
-uv run python $(git rev-parse --show-toplevel)/mo_net/samples/mnist "$@"
+exec uv run python -m mo_net.samples.mnist "$@"
```

mo_net/samples/word2vec/vocab.py (3)
38-40: Msgpack decoding: ensure string keys/values are decoded.
Be explicit with `raw=False` to avoid `bytes` when unpacking across environments. Apply:

```diff
-        data = msgpack.unpackb(f.read())
+        data = msgpack.unpackb(f.read(), raw=False)
```
119-121: Msgpack decoding: be explicit here as well.
Apply:

```diff
-        obj = msgpack.unpackb(data)
+        obj = msgpack.unpackb(data, raw=False)
```
132-151: get_training_set: empty-corpus safety (unpacking `zip(...)` raises).
For short corpora or a large `context_size`, the pair list may be empty, and unpacking `context, target = zip(*pairs)` then raises `ValueError`. Guard for empty. Apply:

```diff
-def get_training_set(
+def get_training_set(
     tokenized_sentences: Collection[TokenizedSentence], context_size: int
 ) -> tuple[jnp.ndarray, jnp.ndarray]:
-    context, target = zip(
-        *[
-            (
-                tuple(
-                    chain(
-                        sentence[i - context_size : i],
-                        sentence[i + 1 : i + context_size + 1],
-                    )
-                ),
-                sentence[i],
-            )
-            for sentence in tokenized_sentences
-            for i in range(context_size, len(sentence) - context_size)
-        ],
-        strict=True,
-    )
-    return jnp.array(context), jnp.array(list(target))
+    pairs = [
+        (
+            tuple(
+                chain(
+                    sentence[i - context_size : i],
+                    sentence[i + 1 : i + context_size + 1],
+                )
+            ),
+            sentence[i],
+        )
+        for sentence in tokenized_sentences
+        for i in range(context_size, len(sentence) - context_size)
+    ]
+    if not pairs:
+        return (
+            jnp.empty((0, context_size * 2), dtype=jnp.int32),
+            jnp.empty((0,), dtype=jnp.int32),
+        )
+    context, target = zip(*pairs)
+    return jnp.array(context), jnp.array(target)
```

mo_net/samples/word2vec/__main__.py (3)
204-204: Remove unused noqa.
`# noqa: F821` is unnecessary here; the imports satisfy the name. Apply:

```diff
-    hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+    hidden: Sequence[Hidden | HiddenLayer],
```
713-719: Exception chaining and logging.
Use `raise ... from e` and avoid passing exception objects to `logger.exception` (it logs the active exception automatically). Apply:

```diff
-    except KeyError as e:
-        raise click.ClickException(
-            f"Missing file in zip: {e.args[0]}. Expected {MODEL_ZIP_INTERNAL_PATH}, {VOCAB_ZIP_INTERNAL_PATH} and optionally {METADATA_ZIP_INTERNAL_PATH}"
-        )
-    except zipfile.BadZipFile as e:
-        logger.exception(f"Invalid zip file: {model_path}", e)
-        raise click.ClickException(f"Invalid zip file: {model_path}")
+    except KeyError as e:
+        raise click.ClickException(
+            f"Missing file in zip: {e.args[0]}. Expected {MODEL_ZIP_INTERNAL_PATH}, {VOCAB_ZIP_INTERNAL_PATH} and optionally {METADATA_ZIP_INTERNAL_PATH}"
+        ) from e
+    except zipfile.BadZipFile as e:
+        logger.exception(f"Invalid zip file: {model_path}")
+        raise click.ClickException(f"Invalid zip file: {model_path}") from e
```
739-751: Cast JAX scalars to `float` for sorting/formatting.
`jnp.array` scalars don’t always support Python float formatting and comparisons for sorting; convert to `float`. Apply:

```diff
-    similarities = []
+    similarities: list[tuple[str, float]] = []
     for other_word in vocab.vocab:
         if other_word != word:
             other_id = vocab[other_word]
             other_embedding = model.embeddings[other_id]
             similarity = jnp.dot(word_embedding, other_embedding) / (
                 jnp.linalg.norm(word_embedding) * jnp.linalg.norm(other_embedding)
             )
-            similarities.append((other_word, similarity))
+            similarities.append((other_word, float(similarity)))
```
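The cast keeps the list as plain Python floats, which sort and format without surprises. A minimal cosine-similarity sketch with NumPy standing in for `jnp` (toy embeddings, hypothetical words):

```python
import numpy as np

word_embedding = np.array([1.0, 0.0])
others = {"king": np.array([1.0, 0.0]), "banana": np.array([0.0, 1.0])}

similarities: list[tuple[str, float]] = []
for other_word, other_embedding in others.items():
    similarity = np.dot(word_embedding, other_embedding) / (
        np.linalg.norm(word_embedding) * np.linalg.norm(other_embedding)
    )
    # float(...) makes sorting and f-string formatting plain Python.
    similarities.append((other_word, float(similarity)))

similarities.sort(key=lambda t: t[1], reverse=True)
print(f"{similarities[0][0]}: {similarities[0][1]:.3f}")  # king: 1.000
```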
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- bin/mnist (1 hunks)
- bin/word2vec (1 hunks)
- mo_net/samples/word2vec/__main__.py (1 hunks)
- mo_net/samples/word2vec/vocab.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- bin/word2vec
🧰 Additional context used
🧬 Code graph analysis (2)
mo_net/samples/word2vec/vocab.py (1)
- mo_net/resources.py (1): get_resource (16-47)
mo_net/samples/word2vec/__main__.py (10)
- mo_net/__init__.py (1): print_device_info (139-145)
- mo_net/functions.py (1): sparse_cross_entropy (20-23)
- mo_net/log.py (2): LogLevel (12-18), setup_logging (21-23)
- mo_net/model/layer/embedding.py (2): Embedding (117-307), vocab_size (306-307)
- mo_net/model/model.py (6): output (435-436), Model (53-436), hidden_modules (276-277), output_module (290-291), input_layer (286-287), loss_contributors (269-270)
- mo_net/model/layer/output.py (15): OutputLayer (22-52), SparseCategoricalSoftmaxOutputLayer (125-201), Serialized (57-69), Serialized (104-114), Serialized (127-139), output_dimensions (91-92), output_dimensions (197-198), serialize (52-52), serialize (94-95), serialize (200-201), deserialize (60-69), deserialize (107-114), deserialize (130-139), backward_prop (41-42), backward_prop_with_negative (151-157)
- mo_net/samples/word2vec/vocab.py (4): get_training_set (132-151), serialize (27-34), deserialize (37-48), from_bytes (119-129)
- mo_net/train/trainer/trainer.py (4): train (209-266), BasicTrainer (98-401), TrainingFailed (38-42), TrainingSuccessful (33-34)
- mo_net/config.py (1): TrainingParameters (8-51)
- mo_net/train/run.py (2): TrainingRun (6-70), seed (30-31)
🪛 Ruff (0.12.2)
mo_net/samples/word2vec/__main__.py
174-174: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
176-176: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
178-178: Prefer TypeError exception for invalid type
(TRY004)
178-178: Avoid specifying long messages outside the exception class
(TRY003)
204-204: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
344-344: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
346-346: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
348-348: Prefer TypeError exception for invalid type
(TRY004)
348-348: Avoid specifying long messages outside the exception class
(TRY003)
690-690: Avoid specifying long messages outside the exception class
(TRY003)
713-715: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
713-715: Avoid specifying long messages outside the exception class
(TRY003)
718-718: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
718-718: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Cursor Bugbot
- GitHub Check: typecheck
Actionable comments posted: 2
♻️ Duplicate comments (7)
mo_net/samples/word2vec/vocab.py (3)
306-315: Normalize and clean tokens before stopword filtering.
Otherwise capitalized tokens and punctuation slip through.

```diff
-    return [
-        [word for word in sentence.split() if word not in stop_words]
+    return [
+        [
+            tok
+            for tok in (Vocab.clean_token(w) for w in sentence.split())
+            if tok and tok not in stop_words
+        ]
         for sentence in (
             get_resource("s3://mo-net-resources/english-sentences.txt")
             .read_text()
             .split("\n")[:limit]
         )
     ]
```
16-17: Avoid Python 3.12-only `type` alias syntax; use `TypeAlias`.
Keeps <=3.11 compatibility.

```diff
-type Sentence = Sequence[str]
-type TokenizedSentence = Sequence[int]
+Sentence: TypeAlias = Sequence[str]
+TokenizedSentence: TypeAlias = Sequence[int]
```
9-9: Bug: `typing.Counter` used at runtime; import from `collections`.
`typing.Counter` is a deprecated typing alias, not intended for runtime construction; import the real `Counter` from `collections` when building the vocab.
Apply:

```diff
-from typing import Counter, Self
+from collections import Counter
+from typing import Self, TypeAlias
```

mo_net/train/trainer/trainer.py (1)
7-8: Fix Python 3.12-only `type` aliases and include missing "rmsprop".
Use `TypeAlias` and add "rmsprop" to the union to match the `match` cases.

```diff
-from typing import Any, ContextManager, Final, Literal, assert_never
+from typing import Any, ContextManager, Final, Literal, TypeAlias, assert_never
@@
-type TrainingResult = TrainingSuccessful | TrainingFailed
+TrainingResult: TypeAlias = TrainingSuccessful | TrainingFailed
@@
-type OptimizerType = Literal["adam", "none"]
+OptimizerType: TypeAlias = Literal["adam", "none", "rmsprop"]
```

Also applies to: 47-52
mo_net/samples/word2vec/__main__.py (3)
545-556: Keep `model_type` consistent with loaded ZIP metadata.
After loading, align CLI `model_type` to `loaded_model_type` so data prep and logging match the artifact.

```diff
     if loaded_model_type == "skipgram":
         model: CBOWModel | SkipGramModel = SkipGramModel.load(
             mf, training=True, key=key
         )
     else:
         model = CBOWModel.load(mf, training=True)
     vocab = Vocab.from_bytes(zf.read(VOCAB_ZIP_INTERNAL_PATH))
     sentences = get_english_sentences()
+    model_type = loaded_model_type
```
275-281: Skip‑gram loss: logits/label shape mismatch.
`Y_true.flatten()` has length `batch_size * num_pos` but `Y_pred` has `batch_size` rows. Repeat logits per positive or collapse per column before loss.

```diff
     def compute_loss(
         self, X: jnp.ndarray, Y_true: jnp.ndarray, loss_fn: LossFn
     ) -> float:
-        Y_pred = self.forward_prop(X)
-        return loss_fn(Y_pred, Y_true.flatten()) + sum(
+        Y_pred = self.forward_prop(X)
+        if Y_true.ndim == 2:
+            # repeat each row once per positive context
+            Y_pred = jnp.repeat(Y_pred, repeats=Y_true.shape[1], axis=0)
+            y = Y_true.reshape(-1)
+        else:
+            y = Y_true
+        return loss_fn(Y_pred, y) + sum(
             contributor() for contributor in self.loss_contributors
         )
```
283-303: Skip‑gram backward: positive indices overrun row count.
Flattening positives yields `batch_size * num_pos` rows while the output layer has only `batch_size` rows. Aggregate per positive column.

```diff
-    def backward_prop(self, Y_true: jnp.ndarray) -> D[Activations]:
-        batch_size, context_size = Y_true.shape
-        self._key, subkey = jax.random.split(self._key)
-
-        dZ = cast(
-            SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
-        ).backward_prop_with_negative(
-            Y_true=Y_true.flatten(),
-            Y_negative=jax.random.choice(
-                subkey,
-                self.embedding_layer.vocab_size,
-                shape=(batch_size * context_size * self._negative_samples,),
-            ),
-        )
+    def backward_prop(self, Y_true: jnp.ndarray) -> D[Activations]:
+        batch_size, num_pos = Y_true.shape
+        # Split RNG: one for state, one per positive column
+        splits = jax.random.split(self._key, num_pos + 1)
+        self._key = splits[0]
+        dZ = 0
+        for i in range(num_pos):
+            negatives = jax.random.choice(
+                splits[i + 1],
+                self.embedding_layer.vocab_size,
+                shape=(batch_size * self._negative_samples,),
+            )
+            dZ_i = cast(
+                SparseCategoricalSoftmaxOutputLayer, self.output.output_layer
+            ).backward_prop_with_negative(
+                Y_true=Y_true[:, i],
+                Y_negative=negatives,
+            )
+            dZ = dZ + dZ_i
```
🧹 Nitpick comments (4)
mo_net/samples/word2vec/vocab.py (1)
70-81: Make vocab ordering deterministic (don’t go through a set).Using a set loses frequency order and makes IDs unstable between runs. Keep
most_commonorder and prependforced_wordsstably.- most_common_tokens = { - token - for token, _ in Counter( - token for sentence in sentences if sentence for token in sentence - ).most_common(max_size) - } - for word in forced_words: - most_common_tokens.add(word) - - vocab_tuple = tuple(most_common_tokens) + counts = Counter( + token for sentence in sentences if sentence for token in sentence + ).most_common(max_size) + most_common = [token for token, _ in counts] + # forced_words first; dict.fromkeys preserves order and de-dupes + vocab_list = list(dict.fromkeys([*forced_words, *most_common])) + vocab_tuple = tuple(vocab_list) unknown_token_id = len(vocab_tuple)Also applies to: 79-81
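`dict.fromkeys` gives a stable, order-preserving de-dupe, so token IDs are reproducible across runs. Toy tokens below; `"<pad>"` is a hypothetical forced word, not necessarily one the project uses:

```python
forced_words = ["<pad>", "cat"]
most_common = ["the", "cat", "sat"]  # already in frequency order

# Forced words come first; the later duplicate "cat" is dropped, order kept.
vocab_list = list(dict.fromkeys([*forced_words, *most_common]))
print(vocab_list)  # ['<pad>', 'cat', 'the', 'sat']
```

A set-based union would produce an arbitrary (hash-dependent) ordering instead.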
mo_net/train/trainer/trainer.py (2)
160-166: Silence unused-arg warnings in the SIGINT handler.
Prevents ARG002 from Ruff.

```diff
     def _sigint_handler(self, signum: int, frame: Any) -> None:
         """Handle SIGINT (Ctrl+C) by setting interrupt flag."""
+        del signum, frame
         self._interrupt_requested = True
         self._logger.info(
             "SIGINT received. Training will be interrupted at the next safe point."
         )
```
178-206: Don’t block on an interactive prompt in non‑TTY/CI; auto‑continue.
InquirerPy will fail or hang without a TTY. Gate on `stdin.isatty()`.

```diff
     def _handle_interrupt(self) -> TrainingResult | None:
         """Handle interrupt request by prompting user for action."""
         if not self._interrupt_requested:
             return None
         self._interrupt_requested = False  # Reset flag
         self._logger.info("Training interrupted by user. Prompting for action...")
-        try:
+        try:
+            import sys
+            if not sys.stdin or not sys.stdin.isatty():
+                self._logger.warning("No interactive TTY detected; continuing training.")
+                return None
             choice = inquirer.select(
                 message="Training has been interrupted. What would you like to do?",
```

If preferred, add a constructor flag (e.g., `interactive_prompts: bool = True`) and disable prompts when False.

mo_net/samples/word2vec/__main__.py (1)
204-209: Remove unused `# noqa: F821`.
`Hidden` and `HiddenLayer` are imported; the directive is unnecessary.

```diff
-    hidden: Sequence[Hidden | HiddenLayer],  # noqa: F821
+    hidden: Sequence[Hidden | HiddenLayer],
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- mo_net/samples/word2vec/__main__.py (1 hunks)
- mo_net/samples/word2vec/vocab.py (1 hunks)
- mo_net/train/trainer/trainer.py (15 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
mo_net/train/trainer/trainer.py (5)
- mo_net/optimiser/base.py (12): config (96-97), Base (18-104), set_model (84-85), restore (90-93), snapshot (87-88), learning_rate (82-82), learning_rate (123-124), report (78-78), report (119-120), training_step (27-32), training_step (35-42), training_step (44-72)
- mo_net/optimiser/adam.py (5): AdaM (29-115), restore (104-115), snapshot (95-102), learning_rate (89-90), report (92-93)
- mo_net/optimiser/rmsprop.py (5): RMSProp (27-96), restore (87-96), snapshot (81-85), learning_rate (75-76), report (78-79)
- mo_net/optimiser/scheduler.py (2): CosineScheduler (72-120), WarmupScheduler (123-179)
- mo_net/train/run.py (1): log_iteration (33-49)
mo_net/samples/word2vec/vocab.py (1)
- mo_net/resources.py (1): get_resource (16-47)
mo_net/samples/word2vec/__main__.py (12)
- mo_net/__init__.py (1): print_device_info (139-145)
- mo_net/functions.py (1): sparse_cross_entropy (20-23)
- mo_net/log.py (2): LogLevel (12-18), setup_logging (21-23)
- mo_net/model/layer/average.py (2): Average (12-92), axis (84-85)
- mo_net/model/layer/embedding.py (2): Embedding (117-307), vocab_size (306-307)
- mo_net/model/layer/linear.py (1): Linear (182-395)
- mo_net/model/model.py (6): output (435-436), Model (53-436), hidden_modules (276-277), output_module (290-291), input_layer (286-287), loss_contributors (269-270)
- mo_net/model/layer/output.py (15): OutputLayer (22-52), SparseCategoricalSoftmaxOutputLayer (125-201), Serialized (57-69), Serialized (104-114), Serialized (127-139), output_dimensions (91-92), output_dimensions (197-198), serialize (52-52), serialize (94-95), serialize (200-201), deserialize (60-69), deserialize (107-114), deserialize (130-139), backward_prop (41-42), backward_prop_with_negative (151-157)
- mo_net/regulariser/weight_decay.py (4): EmbeddingWeightDecayRegulariser (71-112), embedding_layer (68-68), attach (47-63), attach (101-112)
- mo_net/samples/word2vec/vocab.py (6): get_english_sentences (306-315), get_training_set (142-161), serialize (27-34), deserialize (37-48), from_bytes (129-139), english_sentences (115-126)
- mo_net/train/trainer/trainer.py (4): train (291-348), BasicTrainer (100-488), TrainingFailed (40-44), TrainingSuccessful (35-36)
- mo_net/config.py (1): TrainingParameters (8-51)
🪛 Ruff (0.12.2)
mo_net/train/trainer/trainer.py
160-160: Unused method argument: signum
(ARG002)
160-160: Unused method argument: frame
(ARG002)
mo_net/samples/word2vec/__main__.py
174-174: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
176-176: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
178-178: Prefer TypeError exception for invalid type
(TRY004)
178-178: Avoid specifying long messages outside the exception class
(TRY003)
204-204: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
344-344: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
346-346: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
348-348: Prefer TypeError exception for invalid type
(TRY004)
348-348: Avoid specifying long messages outside the exception class
(TRY003)
701-701: Avoid specifying long messages outside the exception class
(TRY003)
724-726: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
724-726: Avoid specifying long messages outside the exception class
(TRY003)
729-729: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
729-729: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Cursor Bugbot
- GitHub Check: typecheck
- Fix double mean computation of loss.
- Fix incorrect calculation in negative sample code.
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Chores