Conversation
Walkthrough

This update migrates the entire codebase from NumPy to JAX for all numerical operations, array types, and random number generation. It introduces JAX-based loss functions, device selection, and random-key management throughout the model, training, and data modules. Additional features include support for sparse cross-entropy loss, new CLI options, a new MNIST CNN sample, asynchronous batched logging, and extensive test updates.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI/User
    participant Trainer
    participant Model
    participant JAX
    participant Device
    CLI/User->>Trainer: Start training (with --loss-fn-name, --device, --seed)
    Trainer->>JAX: Set device, initialize PRNG key with seed
    Trainer->>Model: Initialize model (pass JAX key)
    Trainer->>Trainer: Create batcher (with JAX key)
    loop Each epoch/batch
        Trainer->>JAX: Shuffle/select batch indices
        Trainer->>Model: Forward pass (JAX arrays)
        Model->>JAX: Compute activations/loss
        Trainer->>Model: Backward pass (JAX arrays)
        Model->>JAX: Compute gradients
        Trainer->>Trainer: Update model parameters
    end
    Trainer->>Device: Optionally print device info
    Trainer->>CLI/User: Report training result
```
```mermaid
sequenceDiagram
    participant User
    participant CLI
    participant DataLoader
    participant Model
    participant Trainer
    participant Logger
    User->>CLI: Run training command (selects device, loss function, seed)
    CLI->>DataLoader: Load data (optionally one-hot or sparse labels)
    CLI->>Model: Initialize model with JAX key
    CLI->>Trainer: Pass model, data, loss function, device, key
    Trainer->>Logger: Log training progress (batched, async)
    Trainer->>Model: Train using JAX arrays, random keys
    Model->>JAX: All math, random, device ops via JAX
```
Force-pushed from 8f6ac6e to 9a82352.
Actionable comments posted: 11
🔭 Outside diff range comments (2)
mo_net/tests/trainer/test_parallel.py (1)
206-249: Fix type incompatibility: use JAX arrays instead of NumPy arrays.

The pipeline failure indicates that these NumPy arrays are incompatible with the expected JAX `Array` types. Since the codebase has migrated to JAX, the test gradients should also use JAX arrays.
Replace NumPy arrays with JAX arrays:
```diff
+import jax.numpy as jnp

 gradient1=Linear.Parameters(
-    weights=np.array([[1.0], [2.0]], dtype=np.float32),
-    biases=np.array([3.0], dtype=np.float32),
+    weights=jnp.array([[1.0], [2.0]], dtype=jnp.float32),
+    biases=jnp.array([3.0], dtype=jnp.float32),
 ),
 gradient2=Linear.Parameters(
-    weights=np.array([[4.0], [5.0]], dtype=np.float32),
-    biases=np.array([6.0], dtype=np.float32),
+    weights=jnp.array([[4.0], [5.0]], dtype=jnp.float32),
+    biases=jnp.array([6.0], dtype=jnp.float32),
 ),
```

Apply similar changes to all other gradient arrays in the test cases.
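For context on the switch, a small standalone sketch of how JAX arrays differ from NumPy arrays in tests like these (the variable names here are illustrative, not from the test file):

```python
import numpy as np
import jax.numpy as jnp

w = jnp.array([[1.0], [2.0]], dtype=jnp.float32)

# JAX arrays are immutable: functional updates via .at[...] return a new array
w2 = w.at[0, 0].set(9.0)
print(float(w[0, 0]), float(w2[0, 0]))  # the original array is unchanged

# jnp arrays interoperate with NumPy when explicitly converted
print(np.asarray(w).tolist())
```

The immutability is why test fixtures can safely share JAX arrays without defensive copies.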
mo_net/model/layer/layer_norm.py (1)
235-256: Wrap cached output in `Activations` to satisfy the TypedDict

The `Cache` TypedDict in `layer_norm.py` declares `output_activations: Activations | None`, but the code assigns a raw `jnp.ndarray` (`normalized`) directly. That causes a mypy error assigning `jnp.ndarray` to `Activations`. Wrap `normalized` in the `Activations` NewType so it matches the declared type.

- File: `mo_net/model/layer/layer_norm.py`
- Location: inside `_forward_prop`, around line 253
- Change:

```diff
 if self._store_output_activations or self._training:
-    self._cache["output_activations"] = normalized
+    # wrap in Activations for TypedDict compatibility
+    self._cache["output_activations"] = Activations(normalized)
```
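The `Activations` wrap is a zero-cost `typing.NewType` call; a minimal standalone sketch of the pattern (the `normalize` function is illustrative, not the layer's actual code):

```python
from typing import NewType

import jax.numpy as jnp

# NewType creates a distinct static type; at runtime the call is the identity
Activations = NewType("Activations", jnp.ndarray)

def normalize(x: jnp.ndarray) -> Activations:
    out = (x - x.mean()) / (x.std() + 1e-8)
    return Activations(out)  # satisfies the checker, no copy at runtime

a = normalize(jnp.array([1.0, 2.0, 3.0]))
```

Because the wrap is free at runtime, it can be applied wherever the TypedDict expects `Activations` without any performance cost.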
🧹 Nitpick comments (7)
README.md (1)
52-130: Excellent GPU/MPS setup guide addition.

The comprehensive device setup guide is well structured and covers all essential aspects, including installation, usage, performance tips, and troubleshooting. This will be very helpful for users adopting the JAX migration.
Address the minor documentation issues:
````diff
-To share GPU with other processes:
+To share the GPU with other processes:

-Check JAX CUDA version matches your CUDA installation
+Check JAX CUDA version matches the CUDA installation

-```
+```bash
````

mo_net/scripts/check_device.py (1)

10-29: Well-implemented device checking script.

The script effectively demonstrates JAX device availability and functionality. The test operation using matrix multiplication is a good choice for verifying device capabilities.

Consider adding basic error handling around the JAX operations to provide more informative messages if device operations fail:

```diff
 logger.info("\nTesting JAX operation...")
-x = jax.numpy.ones((1000, 1000))
-y = jax.numpy.dot(x, x)
-logger.info(f"Matrix multiplication result shape: {y.shape}")
-logger.info(f"Operation ran on: {y.device}")
+try:
+    x = jax.numpy.ones((1000, 1000))
+    y = jax.numpy.dot(x, x)
+    logger.info(f"Matrix multiplication result shape: {y.shape}")
+    logger.info(f"Operation ran on: {y.device}")
+except Exception as e:
+    logger.error(f"JAX operation failed: {e}")
```

mo_net/model/layer/dropout.py (1)
8-9: Consider the import-style suggestion.

The static analysis tool suggests using `from jax import random` instead of `import jax.random as random`, which would make the code cleaner:

```diff
 import jax.numpy as jnp
-import jax.random as random
+from jax import random
```

mo_net/device.py (2)
39-76: Simplify control flow by removing unnecessary elif/else after return statements.

The device selection logic is correct, but the code structure can be simplified as suggested by static analysis:

```diff
 if device_type == "auto":
     # Priority order: CUDA GPU > Metal/MPS > CPU
     if "gpu" in available:
         device = available["gpu"][0]
         logger.info(f"Auto-selected CUDA GPU: {device}")
         return device
-    elif "metal" in available:
+    if "metal" in available:
         device = available["metal"][0]
         logger.info(f"Auto-selected Metal/MPS device: {device}")
         return device
-    else:
-        device = available.get("cpu", jax.devices())[0]
-        logger.info(f"Auto-selected CPU: {device}")
-        return device
+    device = available.get("cpu", jax.devices())[0]
+    logger.info(f"Auto-selected CPU: {device}")
+    return device

-elif device_type == "gpu":
+if device_type == "gpu":
     if "gpu" in available:
         device = available["gpu"][0]
         logger.info(f"Selected CUDA GPU: {device}")
         return device
-    else:
-        raise RuntimeError("No CUDA GPU available")
+    raise RuntimeError("No CUDA GPU available")

-elif device_type == "mps":
+if device_type == "mps":
     if "metal" in available:
         device = available["metal"][0]
         logger.info(f"Selected Metal/MPS device: {device}")
         return device
-    else:
-        raise RuntimeError("No Metal/MPS device available")
+    raise RuntimeError("No Metal/MPS device available")
```
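The same priority logic, flattened into early returns, might look like this standalone sketch (`pick_device` is a hypothetical name, not the function in `mo_net/device.py`):

```python
import jax

def pick_device(kind: str = "auto"):
    # group available devices by platform ("cpu", "gpu", ...)
    by_platform: dict[str, list] = {}
    for dev in jax.devices():
        by_platform.setdefault(dev.platform, []).append(dev)

    if kind == "auto":
        # priority order: GPU > Metal/MPS > CPU
        for platform in ("gpu", "metal", "cpu"):
            if platform in by_platform:
                return by_platform[platform][0]
        return jax.devices()[0]

    if kind in by_platform:
        return by_platform[kind][0]
    raise RuntimeError(f"No {kind} device available")

print(pick_device("auto").platform)
```

Each branch exits via `return` or `raise`, so no `elif`/`else` chain is needed and the linter warnings disappear.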
147-159: Consider adding a type-ignore comment for the optional jax_metal import.

The optional import pattern is correct, but mypy reports an import error. Since this is an optional dependency, consider suppressing the mypy error:

```diff
 try:
     # This will be available if jax-metal is installed
-    import jax_metal  # noqa: F401
+    import jax_metal  # noqa: F401  # type: ignore[import-not-found]
     logger.info("JAX Metal plugin loaded successfully")
```

mo_net/model/layer/pool.py (2)
93-141: Backward propagation correctly implements gradient routing with proper tie handling.

The implementation correctly distributes gradients to max positions and handles ties by normalizing the mask. The use of JAX's `.at[].add()` for indexed updates is appropriate.

Fix the whitespace style issues in slice notation:

```diff
-pooled_vals = pooled_output[:, :, h_out : h_out + 1, w_out : w_out + 1]
+pooled_vals = pooled_output[:, :, h_out:h_out + 1, w_out:w_out + 1]

 grad_out = cast(jnp.ndarray, dZ)[
-    :, :, h_out : h_out + 1, w_out : w_out + 1
+    :, :, h_out:h_out + 1, w_out:w_out + 1
 ]
```
109-139: Consider optimizing the backward pass with vectorized operations.

The current implementation uses nested Python loops, which may not be efficiently JIT-compiled by JAX. For better performance, especially with large feature maps, consider vectorizing this operation.

Would you like me to help implement a vectorized version using JAX's `vmap` or other functional transformations? This could significantly improve performance for large inputs.
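As one possible direction, a reshape-based vectorization for the common non-overlapping case (stride equal to the pool size) removes the Python loops entirely; this is a standalone sketch under those assumptions, not the layer's actual API:

```python
import jax.numpy as jnp

def max_pool_backward(x, grad_out, pool: int = 2):
    """Route grad_out back to the max positions of x (NCHW, non-overlapping windows)."""
    n, c, h, w = x.shape
    # view each pool x pool window on its own pair of axes
    windows = x.reshape(n, c, h // pool, pool, w // pool, pool)
    pooled = windows.max(axis=(3, 5), keepdims=True)
    mask = (windows == pooled).astype(x.dtype)
    # normalize so ties share the gradient instead of duplicating it
    mask = mask / mask.sum(axis=(3, 5), keepdims=True)
    grad = mask * grad_out[:, :, :, None, :, None]
    return grad.reshape(n, c, h, w)

x = jnp.arange(16.0).reshape(1, 1, 4, 4)
g = max_pool_backward(x, jnp.ones((1, 1, 2, 2)))
# gradient lands only on each window's max element (bottom-right here)
```

For overlapping windows a `vmap` over window coordinates, or `jax.lax.reduce_window`-style primitives, would be needed instead.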
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (19)
- README.md (1 hunks)
- mo_net/__init__.py (1 hunks)
- mo_net/cli.py (4 hunks)
- mo_net/constants.py (1 hunks)
- mo_net/device.py (1 hunks)
- mo_net/functions.py (6 hunks)
- mo_net/model/layer/base.py (2 hunks)
- mo_net/model/layer/batch_norm/batch_norm.py (12 hunks)
- mo_net/model/layer/convolution.py (8 hunks)
- mo_net/model/layer/dropout.py (5 hunks)
- mo_net/model/layer/layer_norm.py (7 hunks)
- mo_net/model/layer/linear.py (9 hunks)
- mo_net/model/layer/output.py (3 hunks)
- mo_net/model/layer/pool.py (2 hunks)
- mo_net/optimizer/adam.py (2 hunks)
- mo_net/protos.py (3 hunks)
- mo_net/scripts/check_device.py (1 hunks)
- mo_net/tests/trainer/test_parallel.py (3 hunks)
- pyproject.toml (3 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (4)
mo_net/scripts/check_device.py (1)
- mo_net/device.py (2): `get_available_devices` (12-23), `print_device_info` (134-143)

mo_net/model/layer/pool.py (7)
- mo_net/model/layer/base.py (2): `Hidden` (98-108), `_backward_prop` (104-108)
- mo_net/protos.py (2): `D` (30-38), `d` (162-163)
- mo_net/model/layer/dropout.py (2): `Cache` (35-37), `_backward_prop` (82-93)
- mo_net/model/layer/convolution.py (2): `Cache` (23-26), `_backward_prop` (320-428)
- mo_net/model/layer/linear.py (2): `Cache` (178-180), `_backward_prop` (307-339)
- mo_net/model/layer/activation.py (2): `Cache` (34-35), `_backward_prop` (56-59)
- mo_net/model/layer/reshape.py (2): `_backward_prop` (31-32), `_backward_prop` (60-61)

mo_net/model/layer/output.py (4)
- mo_net/model/layer/base.py (2): `backward_prop` (99-101), `_backward_prop` (104-108)
- mo_net/model/model.py (1): `backward_prop` (268-273)
- mo_net/model/base.py (1): `backward_prop` (37-37)
- mo_net/protos.py (1): `D` (30-38)

mo_net/model/layer/batch_norm/batch_norm.py (5)
- mo_net/model/layer/convolution.py (1): `from_bytes` (190-198)
- mo_net/model/layer/base.py (1): `input_dimensions` (79-80)
- mo_net/model/layer/layer_norm.py (1): `from_bytes` (159-167)
- mo_net/model/layer/linear.py (1): `from_bytes` (165-172)
- mo_net/train/trainer/parallel.py (2): `IO` (35-174), `read` (43-54)
🪛 GitHub Actions: CI
mo_net/tests/trainer/test_parallel.py
[error] 206-249: mypy: Multiple argument type incompatibilities: Argument "weights" and "biases" have incompatible type "ndarray[...]"; expected "Array" [arg-type]
mo_net/cli.py
[error] 395-395: mypy: Unexpected keyword argument "device" for "TrainingParameters" [call-arg]
mo_net/model/layer/dropout.py
[error] 45-80: mypy: Function "jax._src.random.PRNGKey" is not valid as a type; incompatible return value types for functions at lines 69 and 80 (expected "Activations", got "Array") [valid-type, return-value]
mo_net/model/layer/layer_norm.py
[error] 254-291: mypy: Incompatible types for "output_activations" and return values; type variable constraints violated [typeddict-item, return-value, type-var]
mo_net/device.py
[error] 15-15: mypy: Need type annotation for "device_map" (hint: "device_map: dict[, ] = ...") [var-annotated]
[error] 151-151: mypy: Cannot find implementation or library stub for module named "jax_metal" [import-not-found]
mo_net/functions.py
[error] 11-23: mypy: Incompatible return value types (got "Array", expected "float" and "_X") [return-value]
mo_net/model/layer/linear.py
[error] 101-245: mypy: Function "jax._src.random.PRNGKey" is not valid as a type; Name "random.PRNGKey" is not defined [valid-type, name-defined]
🪛 LanguageTool
README.md
[uncategorized] ~106-~106: You might be missing the article “the” here.
Context: ...locates GPU memory by default. To share GPU with other processes: ```python f...
(AI_EN_LECTOR_MISSING_DETERMINER_THE)
[uncategorized] ~119-~119: You might be missing the article “the” here.
Context: ... CUDA installation - Reinstall JAX with correct CUDA version ### Metal/MPS Not Working...
(AI_EN_LECTOR_MISSING_DETERMINER_THE)
🪛 markdownlint-cli2 (0.17.2)
README.md
130-130: Fenced code blocks should have a language specified
null
(MD040, fenced-code-language)
🪛 Pylint (3.3.7)
mo_net/model/layer/dropout.py
[refactor] 9-9: Use 'from jax import random' instead
(R0402)
mo_net/model/layer/pool.py
[error] 9-9: No name 'layer' in module 'mo_net.model'
(E0611)
[error] 10-10: Cannot import 'mo_net.protos' due to 'invalid syntax (mo_net.protos, line 56)'
(E0001)
[error] 10-10: No name 'protos' in module 'mo_net'
(E0611)
[refactor] 93-93: Too many local variables (22/15)
(R0914)
mo_net/device.py
[refactor] 41-52: Unnecessary "elif" after "return", remove the leading "el" from "elif"
(R1705)
[refactor] 55-60: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it
(R1705)
[refactor] 63-68: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it
(R1705)
mo_net/functions.py
[error] 7-7: Cannot import 'mo_net.protos' due to 'invalid syntax (mo_net.protos, line 56)'
(E0001)
[error] 7-7: No name 'protos' in module 'mo_net'
(E0611)
[refactor] 36-41: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it
(R1705)
[refactor] 51-56: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it
(R1705)
[refactor] 74-79: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it
(R1705)
[refactor] 97-102: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it
(R1705)
[refactor] 123-128: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it
(R1705)
🪛 Flake8 (7.2.0)
mo_net/model/layer/pool.py
[error] 121-121: whitespace before ':'
(E203)
[error] 121-121: whitespace before ':'
(E203)
[error] 135-135: whitespace before ':'
(E203)
[error] 135-135: whitespace before ':'
(E203)
🔇 Additional comments (25)
mo_net/constants.py (1)
4-4: Correct alignment with JAX's float32 default.

The change from 8 to 4 bytes per float aligns well with JAX's default float32 precision, which is optimal for GPU performance.
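A quick way to confirm the 4-byte default (JAX uses float32 unless the `jax_enable_x64` flag is turned on):

```python
import jax.numpy as jnp

x = jnp.ones((3,))
# float32 by default, i.e. 4 bytes per element
print(x.dtype, x.dtype.itemsize)
```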
mo_net/model/layer/base.py (2)
10-10: JAX migration import change looks good.

The import change from NumPy to JAX NumPy is correct and consistent with the broader codebase migration.
67-67: Correct usage of JAX NumPy function.

The change from `np.atleast_2d` to `jnp.atleast_2d` maintains the same functionality while using JAX arrays, which is appropriate for the migration.

mo_net/__init__.py (1)
1-1: Good default JAX platform configuration.

Setting `JAX_PLATFORM_NAME` to "cpu" by default provides a sensible fallback that ensures consistent behavior across different environments. This complements the new device management system well.

Also applies to: 9-9
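The env-var default only takes effect if it is set before JAX initializes its backends, which is why the package `__init__.py` is a natural place for it; a minimal sketch of the pattern:

```python
import os

# must run before jax's first import anywhere in the process;
# setdefault lets users override via the environment
os.environ.setdefault("JAX_PLATFORM_NAME", "cpu")

import jax

print([d.platform for d in jax.devices()])
```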
mo_net/optimizer/adam.py (2)
5-5: LGTM! Clean JAX import migration.

The import change from NumPy to JAX NumPy is correctly implemented and aligns with the broader codebase migration.
49-50: Type annotations correctly updated for JAX.

The type annotations have been properly updated from `np.ndarray` to `jnp.ndarray`, maintaining consistency with the JAX migration while preserving the original logic.

pyproject.toml (3)
32-33: JAX dependencies correctly added.

The core JAX dependencies are properly specified with appropriate version constraints.
37-39: Device-specific dependency groups well organized.

The separate dependency groups for CUDA and MPS support provide clean organization for different hardware-acceleration options.
Also applies to: 51-53
84-85: New script entries properly configured.

The script entries for `check_device` and `cross_validate` are correctly defined and match the expected file structure.

mo_net/cli.py (3)
23-23: Device imports correctly added.

The import of device management functions is properly placed and follows the existing import organization.
257-262: Device selection option well implemented.

The CLI option provides appropriate choices and clear help text for device selection.
380-384: Device configuration properly placed.

Setting the default device before data loading is the correct approach, and conditional device-info printing respects the quiet flag.
mo_net/model/layer/dropout.py (1)
67-69: Good addition of edge-case handling.

Adding explicit handling for very small keep probabilities to avoid NaN values is a thoughtful improvement that enhances robustness.

Also applies to: 86-88
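The hazard being guarded against is the inverted-dropout rescale `x / keep_prob` blowing up as `keep_prob → 0`; a standalone sketch of the guard (the threshold value and function are illustrative, not the layer's actual code):

```python
import jax.numpy as jnp
import jax.random as random

def dropout(key, x, keep_prob: float, eps: float = 1e-7):
    # for vanishingly small keep probabilities, drop everything
    # instead of dividing by a near-zero scale (which yields inf/NaN)
    if keep_prob < eps:
        return jnp.zeros_like(x)
    mask = random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

key = random.PRNGKey(0)
x = jnp.ones((4,))
print(dropout(key, x, 1.0))  # keeps everything, unscaled
print(dropout(key, x, 0.0))  # all zeros, no NaN
```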
mo_net/model/layer/layer_norm.py (4)
5-9: Import and type annotation updates look good.

The migration from NumPy to JAX NumPy is correctly implemented for imports and type annotations.
Also applies to: 28-29, 33-33
155-157: Array creation methods correctly migrated to JAX.

The use of `jnp.ones` and `jnp.zeros` is appropriate for JAX array initialization.
273-291: Backward propagation correctly migrated to JAX.

The gradient computations are properly implemented using JAX NumPy operations. The explicit array conversion at line 279 is necessary for JAX compatibility.
296-298: Gradient initialization correctly uses JAX operations.

The use of `jnp.zeros_like` is appropriate for creating gradient arrays with matching shapes and dtypes.

mo_net/model/layer/output.py (2)
6-6: Import correctly updated to JAX NumPy.

The migration from NumPy to JAX NumPy is properly implemented.
85-85: Array operation correctly migrated to JAX.

The use of `jnp.atleast_1d` is appropriate for ensuring the gradient has at least one dimension.

mo_net/protos.py (1)
5-7: Protocol and type definitions correctly migrated to JAX.

All type annotations and protocol definitions have been properly updated for JAX compatibility. The `d_op` helper function documentation clearly explains the type-casting requirement, which is helpful for maintainability.

Also applies to: 10-10, 12-12, 49-49, 53-53, 166-173
mo_net/device.py (1)
79-123: Excellent handling of Metal backend compatibility issues.

The implementation includes comprehensive testing for Metal device compatibility with proper fallback to CPU. The error handling and logging are well designed to handle known JAX Metal backend limitations.
mo_net/model/layer/pool.py (1)
6-7: Import and type updates correctly implemented for JAX.

The migration from NumPy to JAX is properly reflected in imports and type annotations.

Also applies to: 15-15
mo_net/model/layer/batch_norm/batch_norm.py (1)
7-239: Clean JAX migration for batch normalization layer.

The migration from NumPy to JAX is complete and consistent throughout the file. All array operations, type annotations, and parameter initialization have been properly updated to use JAX equivalents.
mo_net/model/layer/convolution.py (1)
303-318: Excellent use of JAX's native convolution operations.

The migration to `lax.conv_general_dilated` for forward propagation is a significant improvement over the manual implementation, providing better performance and correctness guarantees.

mo_net/model/layer/linear.py (1)
293-305: Good use of JIT compilation for performance.

The addition of JIT compilation for the linear operations and gradient computations will provide significant performance improvements, especially for larger matrices.

Also applies to: 316-336
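For reference, the JIT pattern being applauded amounts to wrapping the pure array functions in `jax.jit`; a generic sketch (the function names are illustrative, not the module's actual functions):

```python
import jax
import jax.numpy as jnp

@jax.jit
def linear_forward(x, w, b):
    return x @ w + b

@jax.jit
def linear_grads(x, dz):
    # dW and dB for y = x @ w + b, given upstream gradient dz
    return x.T @ dz, dz.sum(axis=0)

x = jnp.ones((8, 3))
w = jnp.zeros((3, 2))
b = jnp.ones((2,))
y = linear_forward(x, w, b)  # compiled on first call, cached afterwards
```

Subsequent calls with the same shapes and dtypes reuse the compiled kernel, which is where the speedup for larger matrices comes from.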
```python
def backward_prop(self, *, Y_true: jnp.ndarray) -> D[Activations]:
    return self._backward_prop(Y_true=Y_true)
```
💡 Verification agent
🧩 Analysis chain
Type compatibility issue with base model interface.
The `Y_true` parameter type has been changed to `jnp.ndarray`, but the base model classes (`mo_net/model/base.py` and `mo_net/model/model.py`) still expect `np.ndarray` for the `backward_prop` method. This creates a type mismatch that could cause issues.
Run the following script to verify the interface compatibility:
Consider either:

- updating the base model interfaces to use `jnp.ndarray`,
- ensuring `jnp.ndarray` is compatible with `np.ndarray` type annotations, or
- adding type conversions where needed.
Also applies to: 48-48, 81-81
🏁 Script executed:

```shell
#!/bin/bash
# Description: Check the base model backward_prop method signatures

# Find backward_prop method signatures in base model classes
ast-grep --pattern 'def backward_prop($$$Y_true: $_$$$) -> $_: $$$'

# Check if there are any conversions between np.ndarray and jnp.ndarray
rg -A 2 'Y_true.*ndarray'
```

Length of output: 1859
Ensure consistent Y_true types across model interfaces

The `backward_prop` and `compute_loss` signatures in the core model layers still use `np.ndarray`, while `mo_net/model/layer/output.py` now expects a `jnp.ndarray`. This discrepancy will break static typing and may cause runtime errors when mixing JAX arrays with NumPy annotations.
Please update the following methods to use `jnp.ndarray` (or add conversion logic) so that every layer and model interface agrees on the same array type:

- mo_net/model/base.py
  - `def backward_prop(self, *, Y_true: np.ndarray) -> None`
  - `def compute_loss(self, X: np.ndarray, Y_true: np.ndarray) -> float`
- mo_net/model/module/base.py
  - `def backward_prop(self, *, Y_true: np.ndarray) -> D[Activations]`
- mo_net/model/model.py
  - `def backward_prop(self, Y_true: np.ndarray) -> None`
  - `def compute_loss(self, X: np.ndarray, Y_true: np.ndarray) -> float`
Options:

- Change all `Y_true: np.ndarray` to `Y_true: jnp.ndarray`.
- Keep NumPy types and explicitly convert JAX arrays at the layer boundary.
🤖 Prompt for AI Agents
In mo_net/model/base.py (lines with backward_prop and compute_loss),
mo_net/model/module/base.py (backward_prop), and mo_net/model/model.py
(backward_prop and compute_loss), update the type annotations for the Y_true
parameter from np.ndarray to jnp.ndarray to ensure consistent typing with
mo_net/model/layer/output.py. This involves changing all occurrences of Y_true:
np.ndarray to Y_true: jnp.ndarray in the method signatures. Alternatively, if
you prefer to keep np.ndarray types, add explicit conversion from jnp.ndarray to
np.ndarray at the layer boundaries, but the preferred fix is to unify all to
jnp.ndarray for consistency.
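If the NumPy-typed interfaces are kept, the boundary conversion is a one-liner, since `jnp.asarray` accepts both array types; a sketch of that option (the helper name is hypothetical):

```python
import numpy as np
import jax.numpy as jnp

def as_jax(y_true) -> jnp.ndarray:
    # accepts np.ndarray or jnp.ndarray; no copy if already a JAX array
    return jnp.asarray(y_true)

a = as_jax(np.array([1, 0, 1]))
b = as_jax(jnp.array([1, 0, 1]))
```

Unifying on `jnp.ndarray` everywhere remains the cleaner fix; the helper is only needed if NumPy annotations stay at the model boundary.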
The lines under review:

```python
self._cache["dP"] = d(self.Parameters(weights=dW, biases=dB))
return dZ @ self._parameters.weights.T
return dX
```
Wrap return value to match expected type.
The backward propagation should return `D[Activations]` but currently returns a bare array.

Apply this diff to fix the return type:

```diff
-    return dX
+    return d(Activations(dX))
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
return d(Activations(dX))
```
🤖 Prompt for AI Agents
In mo_net/model/layer/linear.py at line 339, the backward method currently
returns a bare array dX, but it should return a wrapped object of type
D[Activations]. Modify the return statement to wrap dX appropriately so that the
return type matches the expected D[Activations] type.
Actionable comments posted: 10
🔭 Outside diff range comments (2)
mo_net/model/layer/average.py (1)
6-79: JAX migration is correct; consider adding `strict=True` to `zip()`

The migration to JAX NumPy is well done. The backward propagation logic correctly implements gradient computation through averaging operations using JAX-idiomatic patterns.

To ensure the two iterables have matching lengths and catch potential bugs early, add `strict=True` to the `zip()` call:

```diff
 for ax, factor in zip(
     sorted(ax + 1 for ax in self._axis),
     tuple(self._cache["input_shape"][ax + 1] for ax in self._axis),
+    strict=True,
 ):
```

mo_net/samples/cbow.py (1)
4-16: Fix module name collision

The import `import jax.random as random` shadows Python's standard `random` module imported on line 4, which could lead to confusion.

Use a different alias to avoid the name collision:

```diff
-import jax.random as random
+import jax.random as jrandom
```

Then update all usages accordingly:

- Line 591: `jrandom.choice`
- Line 592: `jrandom.PRNGKey`
- Line 633: `jrandom.choice`
- Line 634: `jrandom.PRNGKey`
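With a distinct alias, the stdlib and JAX RNGs coexist cleanly and JAX's explicit key-splitting stays readable; a small standalone sketch of the renamed usage:

```python
import random  # stdlib RNG, no longer shadowed

import jax.numpy as jnp
import jax.random as jrandom

key = jrandom.PRNGKey(0)
key, subkey = jrandom.split(key)
# sample 3 distinct indices, as a cbow-style sampler might
idx = jrandom.choice(subkey, jnp.arange(10), shape=(3,), replace=False)

print(random.__name__, sorted(int(i) for i in idx))
```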
♻️ Duplicate comments (3)
mo_net/model/layer/linear.py (2)
106-107: Fix the type annotation for PRNGKey parameters.

The type annotation `random.PRNGKey` is not valid as a type. Use `jnp.ndarray` instead:

```diff
-    cls, dim_in: Dimensions, dim_out: Dimensions, key: random.PRNGKey | None = None
+    cls, dim_in: Dimensions, dim_out: Dimensions, key: jnp.ndarray | None = None
```

Apply this change to all of the affected methods: `random`, `xavier`, `he`, `appropriate`, and the `__init__` parameter.

Also applies to: 120-121, 133-134, 150-151, 252-252
346-346: Wrap the return value to match the expected type.

The backward propagation should return `D[Activations]` but currently returns a bare array:

```diff
-    return dX
+    return d(Activations(dX))
```

mo_net/cli.py (1)
386-386: Fix the missing `device` field in `TrainingParameters`.

Based on previous review feedback, `TrainingParameters` doesn't have a `device` parameter, causing pipeline failures. Either remove the `device=device` argument or add a `device` field to `TrainingParameters` in `mo_net/config.py`.
🧹 Nitpick comments (4)
README.md (1)
52-130: Excellent GPU/MPS setup documentation, with minor grammar fixes needed.

The comprehensive setup guide is well structured and covers all essential aspects. Please address the grammar issues identified by static analysis:

````diff
-1. **Memory Management**: JAX pre-allocates GPU memory by default. To share GPU with other processes:
+1. **Memory Management**: JAX pre-allocates GPU memory by default. To share the GPU with other processes:

-- Check JAX CUDA version matches your CUDA installation
-- Reinstall JAX with correct CUDA version
+- Check JAX CUDA version matches your CUDA installation
+- Reinstall JAX with the correct CUDA version

-```
+```bash
+# Example commands for checking device status
+```
````

mo_net/tests/model/layer/test_activation.py (1)
16-16: Update type annotations for consistency.

The type annotations still reference `Activations`, which may not be the correct type for JAX arrays. Consider updating to use `jnp.ndarray` consistently:

```diff
-def test_relu_forward_prop(X: Activations, expected: jnp.ndarray):
+def test_relu_forward_prop(X: jnp.ndarray, expected: jnp.ndarray):

-def test_relu_backward_prop(X: Activations, expected: jnp.ndarray):
+def test_relu_backward_prop(X: jnp.ndarray, expected: jnp.ndarray):
```

Also applies to: 31-31
mo_net/tests/model/layer/test_dropout.py (2)
16-27: Consider testing stochastic dropout behavior.

The current tests use deterministic probabilities (1.0 and near-zero), which avoids randomness but doesn't test the stochastic nature of dropout. Consider adding tests with intermediate probabilities using JAX's explicit random keys:

```python
def test_dropout_stochastic_behavior():
    import jax.random as random

    dropout = Dropout(
        input_dimensions=(3,),
        keep_prob=0.5,
        training=True,
    )
    # Test with an explicit random key
    key = random.PRNGKey(42)
    # Add test logic here
```
20-20: Update type annotations for consistency.

The type annotations still reference `Activations`, which may not be the correct type for JAX arrays. Consider updating `test_dropout_forward_prop` and `test_dropout_backward_prop` to use `jnp.ndarray` consistently.

Also applies to: 36-36
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (51)

- README.md (1 hunks)
- mo_net/__init__.py (1 hunks)
- mo_net/cli.py (8 hunks)
- mo_net/constants.py (1 hunks)
- mo_net/data.py (3 hunks)
- mo_net/functions.py (1 hunks)
- mo_net/model/base.py (3 hunks)
- mo_net/model/layer/activation.py (4 hunks)
- mo_net/model/layer/average.py (2 hunks)
- mo_net/model/layer/base.py (2 hunks)
- mo_net/model/layer/batch_norm/batch_norm.py (13 hunks)
- mo_net/model/layer/batch_norm/batch_norm_2d.py (13 hunks)
- mo_net/model/layer/convolution.py (8 hunks)
- mo_net/model/layer/dropout.py (5 hunks)
- mo_net/model/layer/embedding.py (5 hunks)
- mo_net/model/layer/input.py (1 hunks)
- mo_net/model/layer/layer_norm.py (7 hunks)
- mo_net/model/layer/linear.py (8 hunks)
- mo_net/model/layer/output.py (3 hunks)
- mo_net/model/layer/pool.py (2 hunks)
- mo_net/model/model.py (6 hunks)
- mo_net/model/module/base.py (2 hunks)
- mo_net/model/module/convolution.py (2 hunks)
- mo_net/optimizer/adam.py (2 hunks)
- mo_net/optimizer/base.py (3 hunks)
- mo_net/optimizer/rmsprop.py (2 hunks)
- mo_net/protos.py (3 hunks)
- mo_net/quickstart.py (3 hunks)
- mo_net/regulariser/weight_decay.py (2 hunks)
- mo_net/samples/cbow.py (8 hunks)
- mo_net/samples/mnist/__main__.py (2 hunks)
- mo_net/samples/mnist/features.py (7 hunks)
- mo_net/scripts/trace_statistics.py (8 hunks)
- mo_net/scripts/validation.py (2 hunks)
- mo_net/tests/model/layer/test_activation.py (1 hunks)
- mo_net/tests/model/layer/test_average.py (1 hunks)
- mo_net/tests/model/layer/test_batch_norm.py (15 hunks)
- mo_net/tests/model/layer/test_convolution.py (23 hunks)
- mo_net/tests/model/layer/test_dropout.py (1 hunks)
- mo_net/tests/model/layer/test_linear.py (24 hunks)
- mo_net/tests/model/layer/test_pool.py (5 hunks)
- mo_net/tests/model/layer/test_reshape.py (2 hunks)
- mo_net/tests/model/test_model.py (5 hunks)
- mo_net/tests/trainer/test_parallel.py (11 hunks)
- mo_net/train/augment.py (5 hunks)
- mo_net/train/batcher.py (4 hunks)
- mo_net/train/monitor.py (6 hunks)
- mo_net/train/tracer.py (8 hunks)
- mo_net/train/trainer/parallel.py (7 hunks)
- mo_net/train/trainer/trainer.py (4 hunks)
- pyproject.toml (2 hunks)
✅ Files skipped from review due to trivial changes (11)
- mo_net/regulariser/weight_decay.py
- mo_net/quickstart.py
- mo_net/tests/model/layer/test_pool.py
- mo_net/tests/model/layer/test_reshape.py
- mo_net/optimizer/rmsprop.py
- mo_net/model/module/base.py
- mo_net/model/base.py
- mo_net/train/trainer/trainer.py
- mo_net/model/layer/batch_norm/batch_norm_2d.py
- mo_net/data.py
- mo_net/model/module/convolution.py
🚧 Files skipped from review as they are similar to previous changes (12)
- mo_net/constants.py
- mo_net/model/layer/base.py
- pyproject.toml
- mo_net/optimizer/adam.py
- mo_net/tests/trainer/test_parallel.py
- mo_net/model/layer/dropout.py
- mo_net/protos.py
- mo_net/model/layer/output.py
- mo_net/model/layer/layer_norm.py
- mo_net/model/layer/pool.py
- mo_net/model/layer/convolution.py
- mo_net/model/layer/batch_norm/batch_norm.py
🧰 Additional context used
🧬 Code Graph Analysis (9)
mo_net/model/layer/input.py (4)
- mo_net/model/layer/output.py (3): `backward_prop` (41-42), `_backward_prop` (45-49), `_backward_prop` (81-88)
- mo_net/model/model.py (1): `backward_prop` (270-274)
- mo_net/model/layer/base.py (2): `backward_prop` (102-104), `_backward_prop` (107-111)
- mo_net/protos.py (1): `D` (24-32)

mo_net/model/layer/average.py (3)
- mo_net/model/layer/embedding.py (1): `_backward_prop` (203-218)
- mo_net/model/layer/batch_norm/batch_norm.py (1): `_backward_prop` (294-341)
- mo_net/protos.py (1): `D` (24-32)

mo_net/tests/model/layer/test_linear.py (7)
- mo_net/model/layer/linear.py (7): `random` (105-116), `Linear` (191-409), `parameters` (384-385), `Parameters` (37-177), `cache` (380-381), `update_parameters` (369-374), `empty_gradient` (348-354)
- mo_net/model/layer/embedding.py (6): `random` (77-79), `parameters` (254-255), `Parameters` (26-105), `cache` (250-251), `update_parameters` (239-244), `empty_gradient` (220-223)
- mo_net/model/layer/convolution.py (6): `random` (116-137), `parameters` (449-450), `Parameters` (33-198), `cache` (445-446), `update_parameters` (452-457), `empty_gradient` (436-442)
- mo_net/model/layer/base.py (2): `backward_prop` (102-104), `forward_prop` (66-76)
- mo_net/model/model.py (4): `backward_prop` (270-274), `forward_prop` (263-268), `update_parameters` (276-279), `output` (398-399)
- mo_net/protos.py (5): `parameters` (120-120), `cache` (123-123), `forward_prop` (132-132), `update_parameters` (93-93), `empty_gradient` (127-127)
- mo_net/model/layer/layer_norm.py (5): `parameters` (326-327), `Parameters` (27-167), `cache` (322-323), `update_parameters` (308-313), `empty_gradient` (297-303)

mo_net/samples/mnist/__main__.py (3)
- mo_net/data.py (3): `load_data` (101-103), `load_data` (105-107), `load_data` (110-118)
- mo_net/model/model.py (1): `predict` (320-321)
- mo_net/model/base.py (1): `predict` (16-16)

mo_net/model/layer/embedding.py (4)
- mo_net/model/layer/batch_norm/batch_norm_2d.py (2): `from_bytes` (122-130), `empty_gradient` (331-337)
- mo_net/model/layer/batch_norm/batch_norm.py (3): `from_bytes` (159-167), `Parameters` (27-167), `empty_gradient` (343-349)
- mo_net/model/layer/layer_norm.py (3): `from_bytes` (159-167), `Parameters` (27-167), `empty_gradient` (297-303)
- mo_net/protos.py (3): `d` (157-158), `empty_gradient` (127-127), `D` (24-32)

mo_net/samples/cbow.py (3)
- mo_net/model/layer/embedding.py (3): `random` (77-79), `parameters` (254-255), `vocab_size` (281-282)
- mo_net/model/model.py (1): `forward_prop` (263-268)
- mo_net/model/base.py (1): `forward_prop` (34-34)

mo_net/train/batcher.py (3)
- mo_net/model/layer/linear.py (1): `random` (105-116)
- mo_net/model/layer/embedding.py (1): `random` (77-79)
- mo_net/model/layer/convolution.py (1): `random` (116-137)
mo_net/optimizer/base.py (2)
mo_net/protos.py (1)
SupportsGradientOperations(73-88)mo_net/functions.py (1)
cross_entropy(9-12)
mo_net/tests/model/layer/test_batch_norm.py (8)
mo_net/model/layer/linear.py (6)
random(105-116)parameters(384-385)Parameters(37-177)cache(380-381)serialize(361-367)deserialize(223-236)mo_net/model/layer/embedding.py (6)
random(77-79)parameters(254-255)Parameters(26-105)cache(250-251)serialize(230-237)deserialize(133-141)mo_net/model/layer/batch_norm/batch_norm.py (6)
BatchNorm(184-407)parameters(372-373)Parameters(27-167)cache(368-369)serialize(375-383)deserialize(199-214)mo_net/model/model.py (5)
input_dimensions(70-71)output(398-399)backward_prop(270-274)forward_prop(263-268)serialize(329-334)mo_net/protos.py (6)
input_dimensions(176-176)parameters(120-120)cache(123-123)forward_prop(132-132)serialize(149-149)deserialize(139-144)mo_net/model/layer/base.py (3)
input_dimensions(82-83)backward_prop(102-104)forward_prop(66-76)mo_net/model/layer/layer_norm.py (5)
parameters(326-327)Parameters(27-167)cache(322-323)serialize(329-334)deserialize(195-207)mo_net/model/base.py (2)
backward_prop(37-37)forward_prop(34-34)
🪛 Ruff (0.11.9)
mo_net/model/layer/average.py
73-76: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
mo_net/samples/cbow.py
16-16: Redefinition of unused random from line 4
(F811)
mo_net/model/model.py
12-12: Redefinition of unused Callable from line 4
(F811)
mo_net/__init__.py
1-1: os imported but unused
(F401)
🪛 LanguageTool
README.md
[uncategorized] ~106-~106: You might be missing the article “the” here.
Context: ...locates GPU memory by default. To share GPU with other processes: ```python f...
(AI_EN_LECTOR_MISSING_DETERMINER_THE)
[uncategorized] ~119-~119: You might be missing the article “the” here.
Context: ... CUDA installation - Reinstall JAX with correct CUDA version ### Metal/MPS Not Working...
(AI_EN_LECTOR_MISSING_DETERMINER_THE)
🪛 markdownlint-cli2 (0.17.2)
README.md
130-130: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: typecheck
🔇 Additional comments (50)
mo_net/scripts/trace_statistics.py (2)
8-9: LGTM! Clean JAX migration for imports. The migration from NumPy to JAX NumPy is correctly implemented. The `inquirer` import change is also noted.
114-114: LGTM! Consistent JAX NumPy usage for array operations. All `np.diff(bins)` calls have been correctly updated to `jnp.diff(bins)`, maintaining consistent usage of JAX NumPy throughout the function. The `diff` function behaves identically in NumPy and JAX NumPy, so this change preserves functionality while enabling JAX optimizations.
Also applies to: 129-129, 148-148, 167-167, 188-188, 211-211, 232-232
mo_net/model/layer/input.py (1)
23-24: LGTM! Consistent backward propagation interface. The addition of the public `backward_prop` method follows the established pattern seen in other layers (as shown in `mo_net/model/layer/base.py` lines 102-111). This provides a consistent interface for the model's backward propagation workflow, where the model calls `backward_prop` on each layer (see `mo_net/model/model.py` line 274).
mo_net/samples/mnist/__main__.py (4)
6-6: LGTM! JAX NumPy import for migration. The import change from NumPy to JAX NumPy is correctly implemented as part of the broader JAX migration.
100-100: LGTM! Consistent JAX NumPy usage for label extraction. The `argmax` operations are correctly migrated to `jnp.argmax`, maintaining the same functionality while enabling JAX optimizations. This is consistent with the model's predict method, which returns `jnp.ndarray` (as seen in `mo_net/model/model.py` lines 320-321).
Also applies to: 107-107
102-102: LGTM! Proper JAX NumPy usage for accuracy calculations. The accuracy and metric calculations are correctly migrated to use `jnp.sum` for computing matches and ratios. The logic remains identical while benefiting from JAX's performance optimizations.
Also applies to: 109-109, 112-113
120-120: LGTM! Consistent JAX NumPy usage for data analysis. All array operations for sampling, counting, and plotting have been correctly migrated to JAX NumPy functions (`jnp.where`, `jnp.sum`, `jnp.arange`). The functionality remains identical while enabling JAX optimizations.
Also applies to: 128-129, 131-131, 136-136
mo_net/scripts/validation.py (2)
11-11: LGTM! JAX NumPy import for migration. The import change from NumPy to JAX NumPy is correctly implemented as part of the broader JAX migration.
197-200: LGTM! Consistent JAX NumPy usage for statistical calculations. The statistical functions (`min`, `max`, `mean`, `std`) are correctly migrated to their JAX NumPy equivalents. These functions have identical APIs and behavior, ensuring the validation statistics are computed correctly while benefiting from JAX's performance optimizations.
mo_net/tests/model/layer/test_average.py (3)
1-2: LGTM! Clean migration to JAX with proper testing compatibility. The mixed approach of using JAX NumPy for computations (`jnp`) while keeping NumPy testing utilities (`np_testing`) is a good pattern that maintains compatibility with existing test infrastructure while leveraging JAX's capabilities.
26-29: Consistent JAX array usage in multi-axis averaging. The use of `jnp.arange` and `jnp.mean` for multi-axis operations is correct. JAX handles multi-axis operations efficiently, and the test logic remains sound.
10-13: Average layer forward_prop returns JAX arrays. The `Average._forward_prop` method wraps the result of `jnp.mean` in the `Activations` NewType (which is a `jnp.ndarray`), and the base `forward_prop` also returns an `Activations` object. All operations remain in JAX territory, so the test using `jnp.mean` is valid and no changes are required.
mo_net/tests/model/layer/test_activation.py (2)
1-2: LGTM! Proper JAX imports for activation testing. The import structure correctly brings in both JAX and JAX NumPy for activation functions and array operations.
19-21: Excellent migration to JAX's native ReLU. Using `jax.nn.relu` instead of a custom implementation is a significant improvement. JAX's built-in activation functions are optimized and more reliable than custom implementations.
mo_net/tests/model/test_model.py (4)
4-4: LGTM! Clean import migration to JAX NumPy. The import change from `numpy` to `jax.numpy as jnp` is correct and consistent with the project's JAX migration.
40-43: Verify JAX array compatibility with `atleast_2d`. The migration from `np.atleast_2d` to `jnp.atleast_2d` looks correct. JAX's `atleast_2d` should behave identically to NumPy's version for this use case.
71-95: Complex mathematical operations properly migrated. The comprehensive mathematical operations in the linear model test have been correctly migrated to use JAX arrays. The logic remains intact while leveraging JAX's numerical capabilities.
100-112: Serialization compatibility maintained with JAX. The serialization/deserialization test properly uses JAX arrays and maintains compatibility with the model's pickle-based serialization. The `jnp.allclose` assertion ensures numerical precision is preserved.
mo_net/tests/model/layer/test_dropout.py (2)
4-4: LGTM! Proper JAX NumPy import. The import change from `numpy` to `jax.numpy as jnp` is correct and consistent with the project's JAX migration.
10-10: JAX array type correctly used for test input. The test input definition using `jnp.array` is correct and maintains type consistency with the JAX migration.
mo_net/tests/model/layer/test_linear.py (1)
4-590: Comprehensive JAX migration looks good! The systematic replacement of NumPy with JAX throughout the test suite is correct and well executed. The random number generation properly uses JAX's PRNG key-splitting pattern (e.g., lines 475-476), and all array operations and type annotations have been appropriately updated.
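The key-splitting pattern the comment refers to can be sketched as follows (shapes and names here are illustrative, not taken from the test file):

```python
import jax

# JAX PRNG keys are explicit values: split a parent key into independent
# subkeys rather than reusing one key for several draws.
key = jax.random.PRNGKey(42)
key, w_key, b_key = jax.random.split(key, 3)

weights = jax.random.normal(w_key, (4, 3))
biases = jax.random.normal(b_key, (3,))

# Reusing the same subkey reproduces the same sample exactly.
same = jax.random.normal(w_key, (4, 3))
```

Each draw consumes its own subkey, which keeps the randomness reproducible and independent across call sites.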
mo_net/train/monitor.py (1)
7-132: Clean JAX migration! All NumPy operations and type annotations have been correctly updated to use JAX NumPy.
mo_net/train/tracer.py (1)
8-43: JAX migration is correct; verify whether the deterministic sampling is intentional. The migration to JAX is well done. However, using a fixed `PRNGKey(0)` means the sampling will always produce the same sequence.
Is the deterministic sampling behavior intentional for reproducibility? If truly random sampling is desired, consider using a different key strategy (e.g., deriving from the iteration number or maintaining a key state).
mo_net/samples/cbow.py (1)
189-657: JAX migration is comprehensive and correct. All NumPy operations, type annotations, and array manipulations have been properly converted to use JAX NumPy. The random sampling correctly uses JAX's PRNG keys.
mo_net/tests/model/layer/test_batch_norm.py (5)
3-12: LGTM! Correct JAX import pattern. The migration from NumPy to JAX imports is done correctly, and creating a global random key with a fixed seed is appropriate for reproducible tests.
18-25: Type annotations correctly updated to JAX arrays. All dataclass field types have been consistently updated from `np.ndarray` to `jnp.ndarray`, which is the correct approach for the JAX migration.
Also applies to: 139-145, 232-240
35-95: Array creation correctly migrated to JAX. All test data arrays are properly created using `jnp.array()`, and the random data generation on line 87 correctly uses JAX's random API with the predefined key.
114-132: Assertions correctly use JAX functions. All NumPy assertion functions have been properly replaced with their JAX equivalents (`jnp.allclose`, `jnp.mean`, `jnp.std`).
Also applies to: 126-132, 282-288, 334-336, 343-345, 395-404
439-449: Correct use of JAX random for gradient flow testing.The random gradient generation is appropriate here since the test only verifies shapes and finiteness rather than specific values. The key splitting follows JAX best practices.
mo_net/train/batcher.py (2)
4-5: Import and type changes correctly migrated to JAX. The imports and type annotations have been properly updated to use JAX arrays.
Also applies to: 12-15
26-28: Array operations correctly migrated to JAX. All array splitting and type annotations have been properly updated to use JAX equivalents.
Also applies to: 40-40, 49-51, 62-66, 71-75, 80-81, 86-91
mo_net/model/layer/activation.py (3)
4-9: Improved activation function interface with JAX. The change from a specific `ActivationFn` type to `Callable[[jnp.ndarray], jnp.ndarray]` provides better flexibility and aligns with JAX's functional approach.
Also applies to: 43-43
59-83: Excellent optimization with manual derivatives and an autodiff fallback. The implementation correctly computes derivatives for common activation functions manually (which is faster) and falls back to JAX's automatic differentiation for unknown functions. All derivative formulas are correct.
92-101: Serialization updated for dictionary-based activation functions. The implementation correctly finds the activation function's name by reference equality. Note that this approach only works for functions registered in `ACTIVATION_FUNCTIONS`; custom activation functions would need to be registered there to be serializable.
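The reference-equality lookup described above can be sketched in plain Python. The registry contents and helper name here are illustrative, not the project's actual table:

```python
# Hypothetical registry mapping names to activation callables. Serialization
# recovers a function's name by identity (`is`), so only registered
# functions round-trip; anything else fails loudly.
def relu(x):
    return max(x, 0.0)

def identity(x):
    return x

ACTIVATIONS = {"relu": relu, "identity": identity}

def name_of(fn):
    try:
        return next(name for name, f in ACTIVATIONS.items() if f is fn)
    except StopIteration:
        raise ValueError("activation not registered; cannot serialize") from None

found = name_of(relu)

lookup_failed = False
try:
    name_of(lambda x: x * 2)  # unregistered: identity lookup fails
except ValueError:
    lookup_failed = True
```

Failing fast on unregistered functions is usually preferable to silently serializing something that cannot be deserialized later.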
3-4: Correct migration to JAX activation functions. The import changes and the switch from the `ReLU` class to the `jax.nn.relu` function align with the new functional approach for activation functions.
Also applies to: 203-203
22-23: Comprehensive and correct JAX migration. All NumPy arrays, operations, and assertions have been systematically replaced with their JAX equivalents throughout the test file.
Also applies to: 34-35, 42-43, 50-51, 58-59, 66-67, 74-75, 98-99, 107-108, 120-122, 130-132, 143-147, 158-162, 173-187, 215-215, 225-227, 239-242, 250-262, 270-281, 289-302, 310-313, 321-324, 332-335, 350-352, 368-373, 384-391, 398-407, 414-421, 437-439, 445-451
mo_net/train/trainer/parallel.py (4)
15-16: JAX imports correctly added for parallel training. The migration to JAX numpy and random modules is consistent with the rest of the codebase.
460-465: Shared memory arrays correctly created with JAX. The migration properly handles creating JAX arrays from shared memory buffers with the appropriate dtypes.
Also applies to: 490-499
529-534: Smart use of worker ID for random batch sampling. Using `PRNGKey(worker_id)` ensures each worker gets different random batches, which is essential for proper parallel training. This is a better approach than using fixed keys.
739-751: Data sharing implementation correctly uses JAX arrays. The code properly creates JAX array views of shared memory buffers and uses `jnp.copyto` for data transfer. Type annotations are also correctly updated.
Also applies to: 815-816
mo_net/functions.py (2)
9-12: LGTM! The `cross_entropy` implementation correctly uses JAX operations and properly averages over the batch dimension. The return-type issue from the previous review has been addressed.
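A batch-averaged cross-entropy of the kind described can be sketched as follows; the exact signature in `mo_net/functions.py` may differ, and the epsilon guard is an assumption added here to avoid `log(0)`:

```python
import jax.numpy as jnp

def cross_entropy(y_pred: jnp.ndarray, y_true: jnp.ndarray, eps: float = 1e-12) -> jnp.ndarray:
    # Per-example cross-entropy summed over classes, then averaged over the batch.
    return -jnp.mean(jnp.sum(y_true * jnp.log(y_pred + eps), axis=-1))

probs = jnp.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
one_hot = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
loss = cross_entropy(probs, one_hot)  # mean of -log(0.7) and -log(0.8)
```

Averaging over the batch dimension (rather than summing) keeps the loss scale independent of batch size, which is what the review comment highlights.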
15-26: Clean refactor to JAX-native activation functions. The migration to JAX's built-in activation functions is excellent. It eliminates the previous LeakyReLU implementation bug and provides optimized implementations.
mo_net/samples/mnist/features.py (2)
20-25: Correct usage of JAX's immutable array syntax. The implementation properly uses JAX's `.at[].set()` syntax for updating immutable arrays.
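For readers coming from NumPy, the `.at[].set()` idiom mentioned above works like this (a standalone sketch, not the features.py code):

```python
import jax.numpy as jnp

# JAX arrays are immutable: in-place NumPy updates like a[0] = 9 are
# replaced by functional updates that return a *new* array.
a = jnp.zeros(4)
b = a.at[0].set(9.0)     # new array with index 0 replaced
c = b.at[1:3].add(1.0)   # sliced update; a and b are untouched
```

The original array is never modified, which is what allows JAX to trace and transform these operations safely.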
300-304: Fix file I/O for saving JAX arrays. JAX's `jnp` namespace does not include file-I/O functions (e.g. `savez`), so using `jnp.savez` will result in a runtime error. Replace it with NumPy's `savez` (or `savez_compressed`) to persist JAX arrays.
• File: mo_net/samples/mnist/features.py
• Lines: ~300-304

```diff
+import numpy as np

 if args.output:
     logger.info(f"Saving trained input vector to: {args.output}")
-    jnp.savez(
+    np.savez(
         args.output,
         values=input_vector,
         input_dimensions=input_vector.shape[1:],
     )
```

Please confirm that `jnp.savez` is indeed unavailable in your JAX installation and that replacing it with `np.savez` works as expected.
mo_net/train/augment.py (4)
3-33: JAX migration looks good for imports and the scale function! The imports are correctly updated to use JAX, and the scale function properly uses JAX's immutable array operations with `.at[].set()`.
36-63: Rotate function correctly migrated to JAX. The function properly uses JAX operations and immutable array updates.
66-94: Shear function properly uses JAX operations. The migration maintains the correct logic while using JAX's immutable paradigm.
97-123: Translate function correctly implements JAX patterns. The function properly handles immutable array operations.
mo_net/model/layer/linear.py (1)
300-341: Excellent use of JIT compilation for performance optimization! The JIT compilation is properly applied to computationally intensive operations (forward pass, gradient computation, and gradient clipping), which will significantly improve performance on GPUs/TPUs.
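The pattern being praised, JIT-compiling a dense forward pass, can be sketched like this (a minimal standalone example, not the Linear layer's actual code):

```python
import jax
import jax.numpy as jnp

# The first call with given input shapes traces and compiles the function;
# subsequent calls with the same shapes reuse the compiled kernel.
@jax.jit
def linear_forward(x, w, b):
    return x @ w + b

x = jnp.ones((2, 3))
w = jnp.ones((3, 4))
b = jnp.zeros(4)
out = linear_forward(x, w, b)  # each element sums three 1*1 products
```

Because the function is pure (no side effects, no Python-level branching on values), it is a good candidate for `jax.jit`; the same reasoning applies to gradient computation and clipping.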
mo_net/cli.py (1)
317-317: JAX integration in the CLI looks good! The random state initialization and array operations are correctly migrated to use JAX.
Also applies to: 411-413
Actionable comments posted: 2
🔭 Outside diff range comments (1)
mo_net/samples/cbow.py (1)
4-4: Fix the variable name conflict for the `random` module. The `random` module from the Python standard library (line 4) is being shadowed by JAX's random module (line 16), which could lead to confusion and potential bugs when using random functions elsewhere in the code.

```diff
 import functools
 import random
 import re
 import time
 from collections import Counter, defaultdict
 from collections.abc import Collection, Iterator, Sequence
 from dataclasses import dataclass
 from itertools import chain
 from pathlib import Path
 from typing import Callable, ParamSpec, TypeVar, assert_never, cast

 import click
 import jax.numpy as jnp
-import jax.random as random
+import jax.random as jax_random
 import msgpack  # type: ignore[import-untyped]
```

This change requires updating all JAX random calls to use `jax_random` instead of `random`.
Also applies to: 14-16
🧹 Nitpick comments (1)
mo_net/tests/model/layer/test_embedding.py (1)
404-420: Address the skipped serialization test. This test is marked as skipped due to layer name collision issues. Consider implementing a proper fixture or test isolation mechanism to enable this important test.
Would you like me to help implement a solution for the layer name collision issue, perhaps using pytest fixtures with unique layer IDs or a test cleanup mechanism?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- mo_net/model/layer/embedding.py (5 hunks)
- mo_net/samples/cbow.py (8 hunks)
- mo_net/tests/model/layer/test_embedding.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- mo_net/model/layer/embedding.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
mo_net/tests/model/layer/test_embedding.py (2)
- mo_net/model/layer/embedding.py (16): random (77-79), Embedding (119-287), vocab_size (286-287), parameters (259-260), Parameters (26-105), of (98-99), cache (255-256), update_parameters (244-249), xavier (82-87), he (90-95), empty_gradient (221-224), serialize (235-242), deserialize (133-141), write_serialized_parameters (266-270), read_serialized_parameters (272-279), reinitialise (226-233)
- mo_net/model/layer/base.py (2): BadLayerId (51-51), layer_id (90-92)
🪛 Ruff (0.11.9)
mo_net/samples/cbow.py
16-16: Redefinition of unused random from line 4
(F811)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: typecheck
🔇 Additional comments (5)
mo_net/samples/cbow.py (4)
65-71: LGTM! Correct migration from NumPy to JAX for the regularization loss computation.
187-206: LGTM! Proper migration to JAX arrays for the training-set generation, including correct type annotations and one-hot encoding.
267-270: LGTM! Correct type annotation update for JAX arrays.
655-657: LGTM! Correct migration to JAX operations for the cosine similarity calculation.
mo_net/tests/model/layer/test_embedding.py (1)
1-644: Excellent comprehensive test coverage! This test suite provides thorough coverage of the Embedding layer functionality with:
- Well-structured parametrized tests using dataclasses
- Proper JAX array usage throughout
- Good coverage of edge cases and error conditions
- Tests for all major features including gradient clipping, parameter freezing, and caching
The only minor issue is the skipped serialization test that needs attention.
mo_net/samples/cbow.py
Outdated
```python
# Use random choice to select random words
word_indices = random.choice(
    random.PRNGKey(42),
    len(vocab.vocab),
    shape=(min(num_words, len(vocab)),),
    replace=False,
)
random_words = [list(vocab.vocab)[int(i)] for i in word_indices]
```
🛠️ Refactor suggestion
Update random module references and improve readability
After fixing the import conflict, update these references. Also, the comment could be more accurate.
```diff
-    # Use random choice to select random words
-    word_indices = random.choice(
-        random.PRNGKey(42),
+    # Use JAX random choice to select random words
+    word_indices = jax_random.choice(
+        jax_random.PRNGKey(42),
         len(vocab.vocab),
         shape=(min(num_words, len(vocab)),),
         replace=False,
     )
     random_words = [list(vocab.vocab)[int(i)] for i in word_indices]
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Use JAX random choice to select random words
word_indices = jax_random.choice(
    jax_random.PRNGKey(42),
    len(vocab.vocab),
    shape=(min(num_words, len(vocab)),),
    replace=False,
)
random_words = [list(vocab.vocab)[int(i)] for i in word_indices]
```
🤖 Prompt for AI Agents
In mo_net/samples/cbow.py around lines 632 to 639, update the random module
references to use the correct imported random module (likely from JAX or numpy)
instead of the standard Python random. Also, revise the comment to more
accurately describe the operation, such as "Use random.choice with PRNGKey to
select unique random word indices." Ensure the code uses the appropriate random
function signature and parameters for selecting unique indices without
replacement.
- Batcher copies entire arrays; use index batcher.
- Add SparseCategoricalCrossentropyOutputLayer
- Remove branching in batcher
Actionable comments posted: 1
🔭 Outside diff range comments (1)
mo_net/samples/cbow.py (1)
4-4: Fix the import conflict by aliasing the JAX random module. There's a naming conflict between the standard-library `random` module (line 4) and JAX's `random` module (line 16). This shadows the standard-library import and causes the static analysis error.

```diff
 import random
-import jax.random as random
+import jax.random as jax_random
```

Then update all JAX random usages accordingly throughout the file.
Also applies to: 16-16
♻️ Duplicate comments (5)
mo_net/samples/cbow.py (2)
593-595: Update random module references after fixing the import conflict. After renaming the JAX random import to avoid shadowing, update these references accordingly.

```diff
-    predicted_token_id = random.choice(
-        random.PRNGKey(42), valid_indices, shape=(), p=valid_probs
+    predicted_token_id = jax_random.choice(
+        jax_random.PRNGKey(42), valid_indices, shape=(), p=valid_probs
     )
```
634-641: Update random module references and improve readability. After fixing the import conflict, update these references. Also, the comment could be more accurate.

```diff
-    # Use random choice to select random words
-    word_indices = random.choice(
-        random.PRNGKey(42),
+    # Use JAX random choice to select random words
+    word_indices = jax_random.choice(
+        jax_random.PRNGKey(42),
         len(vocab.vocab),
         shape=(min(num_words, len(vocab)),),
         replace=False,
     )
```

mo_net/model/layer/convolution.py (1)
126-137: Consider parameterizing PRNG keys for better reproducibility. The random initialization methods use hardcoded PRNG keys (0, 1), which could lead to identical initializations across different model instances or training runs.
Consider accepting an optional PRNG key parameter and splitting it appropriately for weights and biases initialization.
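A sketch of the parameterization suggested above: accept an optional key and split it for weights and biases. The function and parameter names here are hypothetical, not the Convolution layer's actual API:

```python
import jax
import jax.numpy as jnp

# Hypothetical initializer: callers pass a key for reproducibility; splitting
# it yields independent streams for weights and biases instead of the
# hardcoded PRNGKey(0)/PRNGKey(1) the review flags.
def init_conv_params(key=None, *, n_filters=8, fan_in=9):
    if key is None:
        key = jax.random.PRNGKey(0)  # fallback only; callers should pass a key
    w_key, b_key = jax.random.split(key)
    weights = jax.random.normal(w_key, (n_filters, fan_in)) * jnp.sqrt(2.0 / fan_in)
    biases = jnp.zeros(n_filters)
    return weights, biases

w1, b1 = init_conv_params(jax.random.PRNGKey(7))
w2, b2 = init_conv_params(jax.random.PRNGKey(8))  # different key, different weights
```

Threading the key through the constructor lets two model instances share or diverge in their initialization explicitly.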
mo_net/model/layer/linear.py (2)
13-14: Fix the PRNGKey type annotation to `jax.Array`. The JAX PRNG key should be annotated as `jax.Array`, not `random.PRNGKey`.

```diff
-from jax import jit
+from jax import jit, Array
```

```diff
-    key: random.PRNGKey | None = None
+    key: Array | None = None
```

Apply this change to all method signatures that use PRNG keys.
Also applies to: 104-105, 118-119, 131-132, 148-149, 263-263
351-351: Wrap the return value to match the expected type. The backward propagation should return `D[Activations]` but currently returns a bare array.

```diff
-        return dX
+        return d(Activations(dX))
```
🧹 Nitpick comments (2)
mo_net/train/trainer/trainer.py (1)
309-316: Consider simplifying the nested if statements. The static analysis tool correctly identifies that the nested if statements can be combined for better readability.

```diff
-        if self._training_parameters.monotonic:
-            if self._last_update is not None:
-                self._revert_training_step()
+        if self._training_parameters.monotonic and self._last_update is not None:
+            self._revert_training_step()
```

mo_net/model/layer/output.py (1)
168-168: Remove the unused variable. The static analysis tool correctly identifies that `vocab_size` is assigned but never used.

```diff
-        vocab_size = self._input_dimensions[0]
-        # Use advanced indexing to efficiently compute the gradient
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- mo_net/cli.py (7 hunks)
- mo_net/model/layer/convolution.py (9 hunks)
- mo_net/model/layer/embedding.py (7 hunks)
- mo_net/model/layer/linear.py (8 hunks)
- mo_net/model/layer/output.py (4 hunks)
- mo_net/samples/cbow.py (10 hunks)
- mo_net/train/batcher.py (2 hunks)
- mo_net/train/trainer/trainer.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- mo_net/train/batcher.py
- mo_net/model/layer/embedding.py
- mo_net/cli.py
🧰 Additional context used
🧬 Code Graph Analysis (4)
mo_net/model/layer/output.py (2)
- mo_net/model/model.py (1): backward_prop (270-274)
- mo_net/protos.py (1): D (24-32)

mo_net/model/layer/convolution.py (6)
- mo_net/model/layer/embedding.py (5): random (77-79), Parameters (26-105), _clip_gradient_impl (127-131), _no_clip_gradient (134-136), _backward_prop (220-230)
- mo_net/model/layer/linear.py (5): random (103-114), Parameters (35-175), _clip_gradient_impl (197-201), _no_clip_gradient (204-206), _backward_prop (329-351)
- mo_net/model/layer/batch_norm/batch_norm_2d.py (1): _backward_prop (266-323)
- mo_net/model/layer/base.py (1): _backward_prop (107-111)
- mo_net/model/layer/dropout.py (1): _backward_prop (88-99)
- mo_net/protos.py (1): D (24-32)

mo_net/model/layer/linear.py (8)
- mo_net/model/layer/embedding.py (10): random (77-79), Parameters (26-105), xavier (82-87), he (90-95), of (98-99), from_bytes (101-105), _clip_gradient_impl (127-131), _no_clip_gradient (134-136), parameters (270-271), empty_gradient (232-235)
- mo_net/model/layer/convolution.py (10): random (117-138), Parameters (34-199), xavier (154-171), he (174-189), from_bytes (191-199), _clip_gradient_impl (212-216), _no_clip_gradient (219-221), parameters (453-454), ones (141-151), empty_gradient (440-446)
- mo_net/functions.py (1): identity (15-16)
- mo_net/model/layer/base.py (3): ParametrisedHidden (114-147), input_dimensions (82-83), output_dimensions (86-87)
- mo_net/protos.py (8): D (24-32), GradLayer (118-127), SupportsGradientOperations (73-88), d (157-158), input_dimensions (176-176), output_dimensions (179-179), parameters (120-120), empty_gradient (127-127)
- mo_net/model/layer/batch_norm/batch_norm.py (4): Parameters (27-167), from_bytes (159-167), parameters (372-373), empty_gradient (343-349)
- mo_net/model/layer/layer_norm.py (4): Parameters (27-167), from_bytes (159-167), parameters (326-327), empty_gradient (297-303)
- mo_net/model/layer/batch_norm/batch_norm_2d.py (3): from_bytes (122-130), parameters (340-341), empty_gradient (331-337)

mo_net/samples/cbow.py (4)
- mo_net/model/layer/embedding.py (3): random (77-79), parameters (270-271), vocab_size (297-298)
- mo_net/model/model.py (3): output (398-399), input_dimensions (70-71), forward_prop (263-268)
- mo_net/model/layer/output.py (1): SparseCategoricalCrossentropyOutputLayer (125-188)
- mo_net/model/base.py (1): forward_prop (34-34)
🪛 Ruff (0.11.9)
mo_net/model/layer/output.py
168-168: Local variable vocab_size is assigned to but never used
Remove assignment to unused variable vocab_size
(F841)
mo_net/samples/cbow.py
16-16: Redefinition of unused random from line 4
(F811)
mo_net/train/trainer/trainer.py
309-310: Use a single if statement instead of nested if statements
(SIM102)
323-323: Use a context manager for opening files
(SIM115)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: typecheck
🔇 Additional comments (10)
mo_net/train/trainer/trainer.py (2)
8-8: Excellent JAX migration implementation. The migration to JAX arrays and the transition from `Batcher` to `IndexBatcher` is well implemented. The addition of transform-function support and proper type annotations using `jnp.ndarray` aligns perfectly with the broader codebase migration to JAX.
Also applies to: 21-21, 95-95, 109-112, 123-123, 125-128
284-289: Well-implemented batch processing with transforms. The index-based batching with optional transform application is a clean implementation that maintains flexibility while integrating with the new `IndexBatcher`.
mo_net/samples/cbow.py (1)
15-15: Excellent JAX migration and new output layer integration. The migration to JAX arrays and functions is well executed. The integration of `SparseCategoricalCrossentropyOutputLayer` is appropriate for the CBOW model's sparse categorical cross-entropy loss. The numerical operations now use JAX primitives correctly.
Also applies to: 27-27, 69-69, 189-189, 206-206, 270-270, 259-261, 577-597, 657-659
mo_net/model/layer/output.py (2)
6-7: Excellent JAX migration implementation. The migration to JAX imports and array types is correctly implemented. The type annotations are consistent with the broader codebase migration.
Also applies to: 41-42, 48-48, 77-77, 84-84, 88-88
125-188: Well-designed sparse categorical cross-entropy layer. The new `SparseCategoricalCrossentropyOutputLayer` is excellently implemented:
- Efficiently computes gradients using JAX advanced indexing
- Avoids creating full one-hot matrices
- Proper caching of both logits and softmax outputs
- Correct gradient computation: `softmax - one_hot(targets)`
mo_net/model/layer/convolution.py (3)
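The indexing trick the review praises can be sketched as follows: instead of materializing a one-hot matrix, subtract 1 at each row's target column. This is a standalone sketch, not the layer's actual code:

```python
import jax
import jax.numpy as jnp

# Gradient of sparse categorical cross-entropy w.r.t. logits:
# softmax(logits) - one_hot(targets), computed without the one-hot matrix.
logits = jnp.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]])
targets = jnp.array([0, 1])
n = logits.shape[0]

probs = jax.nn.softmax(logits, axis=-1)
grad_sparse = probs.at[jnp.arange(n), targets].add(-1.0)

# Reference: the dense one-hot formulation gives the same result.
grad_dense = probs - jax.nn.one_hot(targets, logits.shape[-1])
```

For a large vocabulary, the sparse form touches only one entry per row, which is exactly the memory saving the bullet points describe.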
7-8: Excellent JAX migration with a modern convolution implementation. The migration to JAX is comprehensive and well executed:
- Proper use of `jax.lax.conv_general_dilated` for efficient convolutions
- JAX array types throughout
- JAX random initialization methods
Also applies to: 35-36, 124-137, 149-150, 162-171, 181-189
210-221: Well-implemented gradient clipping with JAX JIT. The gradient-clipping implementation follows the same pattern as other layers and properly uses JIT compilation for performance.
320-435: Sophisticated convolution and backpropagation implementation. The forward and backward propagation implementation using JAX's `lax.conv_general_dilated` and related functions is sophisticated and correct. The gradient computation for weights, biases, and inputs is properly implemented using JAX primitives.
11-14: Excellent JAX migration with performance optimizations.The migration to JAX is well-implemented with good performance optimizations:
- JIT compilation for forward and backward passes
- Proper gradient clipping implementation
- JAX array types throughout
- Efficient gradient computation
Also applies to: 36-37, 195-206, 315-324, 337-351
272-282: Clean default parameter initialization pattern.
The default parameter initialization using `partial(Parameters.xavier, key=key)` is a clean and flexible approach that maintains backward compatibility while supporting JAX's explicit random key management.
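The pattern can be illustrated with a standalone sketch; the free function `xavier` below is an assumption standing in for `Parameters.xavier`, which in the actual codebase is a method:

```python
import jax
import jax.numpy as jnp
from functools import partial

def xavier(shape, *, key):
    # Xavier/Glorot uniform: limit = sqrt(6 / (fan_in + fan_out)).
    fan_in, fan_out = shape
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, shape, minval=-limit, maxval=limit)

key = jax.random.PRNGKey(0)
init = partial(xavier, key=key)  # key fixed up front, shape supplied later
W = init((128, 64))
print(W.shape)  # (128, 64)
```

Binding the key via `partial` lets callers that only know the shape stay oblivious to JAX's explicit PRNG plumbing.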
mo_net/train/trainer/trainer.py
Outdated
```python
        pass
    case never:
        assert_never(never)
self._model.dump(open(self._model_checkpoint_path, "wb"))
```
🛠️ Refactor suggestion
Use context manager for file operations.
The static analysis tool correctly identifies that file operations should use context managers for proper resource management.
```diff
-        self._model.dump(open(self._model_checkpoint_path, "wb"))
+        with open(self._model_checkpoint_path, "wb") as f:
+            self._model.dump(f)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🧰 Tools
🪛 Ruff (0.11.9)
323-323: Use a context manager for opening files
(SIM115)
🤖 Prompt for AI Agents
In mo_net/train/trainer/trainer.py at line 323, the file is opened without a
context manager which can lead to resource leaks. Replace the open call with a
with statement to ensure the file is properly closed after dumping the model.
Use "with open(self._model_checkpoint_path, 'wb') as f:" and then call
self._model.dump(f) inside the block.
Actionable comments posted: 1
🔭 Outside diff range comments (1)
mo_net/samples/cbow.py (1)
4-16: Fix import naming conflict between Python random and JAX random.
The import on line 16 shadows the Python `random` module imported on line 4, which could cause confusion and potential runtime errors. Apply this diff to resolve the naming conflict:

```diff
 import random
+import jax.random as jax_random
-import jax.random as random
```

Then update all JAX random usage throughout the file to use `jax_random` instead of `random`.
♻️ Duplicate comments (4)
mo_net/train/trainer/trainer.py (1)
330-330: Use context manager for file operations.

mo_net/samples/cbow.py (2)
579-597: Update random module references after fixing import conflict.
636-643: Update random module references after fixing import conflict.

mo_net/model/model.py (1)
12-12: Remove duplicate import of Callable.
🧹 Nitpick comments (1)
mo_net/model/model.py (1)
263-268: Review direct cache access and consider encapsulation.
The explicit iteration approach is good, but accessing `_cache["output_activations"]` directly breaks encapsulation. Consider if this cached value is actually needed or if returning `activations` would suffice. If the cached activations are specifically required, consider adding a proper accessor method to the output layer:

```python
def forward_prop(self, X: jnp.ndarray) -> Activations:
    activations = self.input_layer.forward_prop(X)
    for module in self.hidden_modules:
        activations = module.forward_prop(input_activations=activations)
    return self.output_module.forward_prop(input_activations=activations)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- mo_net/cli.py (10 hunks)
- mo_net/functions.py (1 hunks)
- mo_net/model/base.py (3 hunks)
- mo_net/model/layer/embedding.py (8 hunks)
- mo_net/model/model.py (6 hunks)
- mo_net/samples/cbow.py (10 hunks)
- mo_net/train/trainer/trainer.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- mo_net/model/base.py
- mo_net/model/layer/embedding.py
- mo_net/cli.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
mo_net/model/model.py (6)
mo_net/functions.py (1)
- identity (34-35)

mo_net/protos.py (2)
- forward_prop (132-132)
- D (24-32)

mo_net/model/module/base.py (5)
- forward_prop (45-52)
- forward_prop (162-165)
- output_layer (199-200)
- backward_prop (109-114)
- backward_prop (167-172)

mo_net/model/layer/base.py (2)
- forward_prop (66-76)
- backward_prop (102-104)

mo_net/model/layer/output.py (1)
- backward_prop (41-42)

mo_net/model/layer/input.py (1)
- backward_prop (23-24)
🪛 Ruff (0.11.9)
mo_net/model/model.py
12-12: Redefinition of unused Callable from line 4
(F811)
mo_net/samples/cbow.py
16-16: Redefinition of unused random from line 4
(F811)
mo_net/train/trainer/trainer.py
312-313: Use a single if statement instead of nested if statements
(SIM102)
330-330: Use a context manager for opening files
(SIM115)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: typecheck
🔇 Additional comments (22)
mo_net/functions.py (5)
1-11: LGTM! Clean migration to JAX types and functions.
The import changes and type definitions properly establish JAX as the numerical backend. The `TransformFn` and `LossFn` type aliases provide clear interfaces for the refactored architecture.
19-22: LGTM! Well-implemented sparse cross entropy.
The sparse cross entropy function correctly handles integer labels by indexing into predictions. The implementation is efficient and follows JAX best practices.
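The indexing approach the comment praises can be sketched as below. This is an illustrative reconstruction, not the exact code in `mo_net/functions.py`; the `eps` guard against `log(0)` is an assumption:

```python
import jax.numpy as jnp

def sparse_cross_entropy(Y_pred, Y_true):
    # Y_pred: (batch, classes) probabilities; Y_true: (batch,) integer labels.
    # Index each row at its true class rather than building a one-hot matrix.
    eps = 1e-12
    picked = Y_pred[jnp.arange(Y_pred.shape[0]), Y_true]
    return -jnp.mean(jnp.log(picked + eps))

probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1]])
labels = jnp.array([0, 1])
print(float(sparse_cross_entropy(probs, labels)))
```

Advanced indexing touches only one probability per row, which is what makes this cheaper than materialising a `(batch, classes)` one-hot matrix.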
25-31: LGTM! Clear loss function factory with proper error handling.
The function provides a clean interface for loss function selection with appropriate error handling for unknown loss functions.
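A factory of this shape might look like the following sketch; the registry name `_LOSS_FNS` and the two concrete loss bodies are assumptions for illustration:

```python
import jax.numpy as jnp

def cross_entropy(Y_pred, Y_true):
    return -jnp.mean(jnp.sum(Y_true * jnp.log(Y_pred + 1e-12), axis=-1))

def sparse_cross_entropy(Y_pred, Y_true):
    picked = Y_pred[jnp.arange(Y_pred.shape[0]), Y_true]
    return -jnp.mean(jnp.log(picked + 1e-12))

_LOSS_FNS = {
    "cross_entropy": cross_entropy,
    "sparse_cross_entropy": sparse_cross_entropy,
}

def get_loss_fn(name):
    # Raise a plain ValueError (not a CLI-specific exception) for
    # unknown names, so the factory stays usable outside the CLI.
    try:
        return _LOSS_FNS[name]
    except KeyError:
        raise ValueError(f"Unknown loss function: {name}") from None

print(get_loss_fn("sparse_cross_entropy").__name__)
```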
34-52: LGTM! Excellent simplification of activation function handling.
The refactor from class-based to dictionary-based activation functions significantly reduces complexity while maintaining type safety. Using JAX native functions directly is more efficient and idiomatic.
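The dictionary-based pattern can be sketched as below; the dictionary name and the exact set of entries are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

# Map names directly to JAX's native activation functions
# instead of wrapping each one in a class.
ACTIVATIONS = {
    "relu": jax.nn.relu,
    "tanh": jnp.tanh,
    "sigmoid": jax.nn.sigmoid,
    "identity": lambda x: x,
}

x = jnp.array([-1.0, 0.0, 2.0])
print(ACTIVATIONS["relu"](x))  # [0. 0. 2.]
```

Because the values are plain callables, they compose directly with `jax.jit` and `jax.grad` without any unwrapping.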
55-67: LGTM! Improved error handling with standard exceptions.
Replacing `click.BadParameter` with standard `ValueError` makes the function more reusable outside CLI contexts while maintaining clear error messages.

mo_net/train/trainer/trainer.py (7)
8-22: LGTM! Proper JAX imports and updated dependencies.
The imports correctly establish JAX as the numerical backend and import the necessary types from the refactored functions module.
108-129: LGTM! Enhanced constructor with loss function flexibility.
The addition of the `loss_fn` parameter enables flexible loss function selection, and the switch to `IndexBatcher` aligns with the JAX migration for more efficient batch handling.
233-235: LGTM! Proper integration of configurable loss function.
The validation loss computation correctly uses the injected loss function, maintaining consistency with the training loss computation.
287-293: LGTM! Efficient batch handling with optional transforms.
The new batch handling approach using indices is more memory-efficient and allows for optional data transformations. The JAX array slicing is performant and idiomatic.
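The index-based batching idea can be sketched as a generator over shuffled index slices; `index_batches` is a hypothetical stand-in for the `IndexBatcher` used in the trainer:

```python
import jax
import jax.numpy as jnp

def index_batches(key, n, batch_size):
    # Shuffle once, then yield index slices; the caller gathers
    # X[idx], Y[idx], so no shuffled copy of the dataset is built upfront.
    perm = jax.random.permutation(key, n)
    for start in range(0, n, batch_size):
        yield perm[start:start + batch_size]

X = jnp.arange(10.0).reshape(10, 1)
Y = jnp.arange(10)
key = jax.random.PRNGKey(0)
for idx in index_batches(key, X.shape[0], batch_size=4):
    X_batch, Y_batch = X[idx], Y[idx]
    print(X_batch.shape, Y_batch.shape)
```

An optional transform would then be applied to `X_batch` after the gather, which is the ordering the comment describes.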
312-322: LGTM! Monotonic training logic properly integrated.
The monotonic training implementation correctly manages the last update state and integrates well with the training step context.
323-337: LGTM! Improved validation and checkpointing logic.
The validation loss computation and model checkpointing are properly scoped to epoch boundaries, improving training efficiency.
370-371: LGTM! Updated method signatures for JAX arrays.
The type annotations correctly reflect the migration to JAX arrays while maintaining clear interfaces.
mo_net/samples/cbow.py (5)
23-28: LGTM! Proper imports for JAX migration.
The imports correctly bring in the sparse cross entropy loss function and the new output layer type consistent with the JAX migration.
70-207: LGTM! Correct type annotations and JAX array usage.
The type changes from `np.ndarray` to `jnp.ndarray` and the use of `jnp.array` for array creation are consistent with the JAX migration.
260-262: LGTM! Appropriate output layer for sparse cross entropy.
The switch to `SparseCategoricalCrossentropyOutputLayer` aligns with the sparse cross entropy loss function used in training.
500-500: LGTM! Explicit loss function specification.
Passing the specific loss function to the trainer enables the flexible loss function architecture implemented in the trainer.
659-661: LGTM! Proper JAX operations for similarity computation.
The cosine similarity computation using JAX dot product and norm functions is correct and efficient.
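The computation referred to is the standard cosine similarity; a minimal sketch (not the file's exact code) looks like:

```python
import jax.numpy as jnp

def cosine_similarity(u, v):
    # cos(theta) = (u . v) / (||u|| * ||v||)
    return jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))

a = jnp.array([1.0, 0.0])
b = jnp.array([0.0, 1.0])
print(float(cosine_similarity(a, a)), float(cosine_similarity(a, b)))  # 1.0 0.0
```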
mo_net/model/model.py (5)
6-26: LGTM! Proper imports for JAX migration.
The import of `partial`, JAX NumPy, and the new functions (`LossFn`, `identity`) correctly support the refactored architecture.
83-110: LGTM! Updated activation function types for JAX callables.
The type change from protobuf enum to JAX callable correctly reflects the simplified activation function architecture. Using `identity` as the default is appropriate.
270-274: LGTM! Clean backward propagation with explicit iteration.
The explicit iteration approach is clearer than the previous functional reduce pattern and correctly propagates gradients backwards through the network.
320-321: LGTM! Correct JAX usage for prediction.
The switch to `jnp.argmax` is appropriate for the JAX migration and maintains the same prediction logic.
323-329: LGTM! Flexible loss computation with proper aggregation.
The refactored `compute_loss` method correctly accepts a loss function parameter and properly aggregates the primary loss with registered loss contributors. This enables the flexible loss function selection implemented throughout the codebase.
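The aggregation described above can be sketched as a free function; the parameter name `loss_contributors` and the zero-argument-callable shape of a contributor are assumptions for illustration:

```python
import jax.numpy as jnp

def compute_loss(Y_pred, Y_true, loss_fn, loss_contributors=()):
    # Primary data loss plus any registered contributors
    # (e.g. per-layer regularisation terms).
    total = loss_fn(Y_pred, Y_true)
    for contributor in loss_contributors:
        total = total + contributor()
    return total

mse = lambda p, t: jnp.mean((p - t) ** 2)
l2_penalty = lambda: 0.01  # stand-in for a regularisation term
pred = jnp.array([1.0, 2.0])
true = jnp.array([1.0, 1.0])
print(float(compute_loss(pred, true, mse, [l2_penalty])))  # 0.51
```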
Actionable comments posted: 5
🔭 Outside diff range comments (2)
mo_net/samples/cbow.py (2)
4-4: Fix import conflict - random module is redefined.
The `random` module from the Python standard library (line 4) is shadowed by JAX's random module (line 16), causing confusion and errors.

```diff
 import random
 # ... other imports ...
-import jax.random as random
+import jax.random as jax_random
```

Also applies to: 16-16
66-72: Convert JAX array to Python float for type consistency.
The method returns a JAX array but the return type annotation expects a float.

```diff
 def compute_regularisation_loss(self) -> float:
     return (
         0.5
         * self._lambda
         * jnp.sum(self._layer.parameters.embeddings**2)
         / self._batch_size
-    )
+    ).item()
```
♻️ Duplicate comments (3)
mo_net/train/trainer/trainer.py (1)
330-330: Use context manager for file operations.

mo_net/samples/cbow.py (2)
594-596: Update random module references after fixing import conflict
635-642: Update random module references and improve readability
🧹 Nitpick comments (1)
mo_net/train/trainer/trainer.py (1)
312-315: Simplify nested if statements.
Combine the nested conditions for better readability.

```diff
-if self._training_parameters.monotonic:
-    if self._last_update is not None:
-        self._revert_training_step()
+if self._training_parameters.monotonic and self._last_update is not None:
+    self._revert_training_step()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- mo_net/cli.py (11 hunks)
- mo_net/data.py (3 hunks)
- mo_net/model/layer/average.py (3 hunks)
- mo_net/samples/cbow.py (11 hunks)
- mo_net/samples/cnn.py (1 hunks)
- mo_net/train/trainer/trainer.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- mo_net/data.py
- mo_net/cli.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
mo_net/samples/cnn.py (20)
mo_net/device.py (2)
- print_device_info (169-177)
- set_default_device (116-157)

mo_net/functions.py (1)
- sparse_cross_entropy (19-22)

mo_net/log.py (2)
- LogLevel (12-18)
- setup_logging (21-23)

mo_net/model/layer/activation.py (1)
- Activation (19-100)

mo_net/model/layer/batch_norm/batch_norm_2d.py (1)
- BatchNorm2D (148-372)

mo_net/model/layer/convolution.py (1)
- Convolution2D (206-497)

mo_net/model/layer/linear.py (1)
- Linear (189-414)

mo_net/model/model.py (2)
- output (400-401)
- Model (50-401)

mo_net/model/layer/output.py (1)
- SparseCategoricalCrossentropyOutputLayer (125-188)

mo_net/model/layer/pool.py (1)
- MaxPooling2D (18-154)

mo_net/model/layer/reshape.py (1)
- Flatten (41-72)

mo_net/model/module/base.py (2)
- Output (126-200)
- output_layer (199-200)

mo_net/protos.py (1)
- NormalisationType (188-191)

mo_net/cli.py (5)
- train (325-485)
- training_options (78-262)
- wrapper (72-73)
- wrapper (259-260)
- cli (310-310)

mo_net/samples/cbow.py (9)
- train (405-517)
- get_name (212-213)
- get_name (277-278)
- get_description (216-217)
- get_description (281-282)
- create (220-264)
- training_options (316-394)
- wrapper (391-392)
- cli (398-400)

mo_net/config.py (1)
- TrainingParameters (8-50)

mo_net/train/backends/log.py (1)
- SqliteBackend (118-219)

mo_net/train/run.py (2)
- TrainingRun (6-70)
- seed (30-31)

mo_net/train/trainer/trainer.py (4)
- BasicTrainer (97-391)
- TrainingFailed (37-41)
- TrainingSuccessful (32-33)
- get_optimizer (50-88)

mo_net/data.py (5)
- SplitConfig (28-68)
- load_data (113-118)
- load_data (120-125)
- load_data (128-143)
- of (41-42)
mo_net/model/layer/average.py (3)
mo_net/model/layer/embedding.py (1)
- _backward_prop (223-233)

mo_net/model/layer/batch_norm/batch_norm.py (1)
- _backward_prop (294-341)

mo_net/protos.py (1)
- D (24-32)
🪛 Ruff (0.11.9)
mo_net/samples/cbow.py
16-16: Redefinition of unused random from line 4
(F811)
mo_net/model/layer/average.py
74-77: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
mo_net/train/trainer/trainer.py
312-313: Use a single if statement instead of nested if statements
(SIM102)
330-330: Use a context manager for opening files
(SIM115)
🪛 GitHub Actions: CI
mo_net/samples/cbow.py
[error] 16-16: mypy: Name "random" already defined (by an import) [no-redef]
[error] 68-68: mypy: Incompatible return value type (got "Array", expected "float") [return-value]
[error] 594-594: mypy: Too many arguments and unexpected keyword arguments "shape", "p" [call-arg]
[error] 595-595: mypy: Module has no attribute "PRNGKey" [attr-defined]
[error] 636-636: mypy: Too many arguments and unexpected keyword arguments "shape", "replace" [call-arg]
[error] 637-637: mypy: Module has no attribute "PRNGKey" [attr-defined]
mo_net/model/layer/average.py
[error] mypy internal error: AssertionError: Internal error: method must be called on parsed file only
mo_net/train/trainer/trainer.py
[error] 332-332: mypy: Incompatible types in assignment (expression has type "CheckFailed | None", variable has type "CheckFailed") [assignment]
[error] 355-355: mypy: "BasicTrainer" has no attribute "_compute_loss" [attr-defined]
🔇 Additional comments (12)
mo_net/model/layer/average.py (3)
6-6: LGTM! JAX import aligns with migration objectives.
The import change from NumPy to JAX NumPy is consistent with the PR's goal of migrating to JAX as the numerical backend.

23-25: LGTM! Method signature maintains API consistency.
The addition of the `freeze_parameters` parameter (even though unused) likely maintains consistency with other layer deserialization methods across the codebase. Explicitly deleting unused parameters is a good practice.

62-62: LGTM! JAX mean function correctly replaces NumPy.
The replacement of `np.mean` with `jnp.mean` is correct and maintains the same functionality with JAX's numerical backend.

mo_net/train/trainer/trainer.py (4)
8-14: LGTM! JAX imports are correctly added.
The migration from NumPy to JAX arrays is properly reflected in the imports.

106-112: Function signature properly updated for JAX migration.
The data arrays are correctly typed as `jnp.ndarray` and the addition of the configurable `loss_fn` parameter aligns with the broader framework changes.

126-129: Batching mechanism correctly updated to use IndexBatcher.
The switch to `IndexBatcher` is appropriate for JAX arrays as it yields indices for more efficient array slicing.

287-293: Batch indexing and transformation logic is correct.
The updated batching approach properly indexes into the JAX arrays and correctly applies the optional transform function after batching.
mo_net/samples/cbow.py (2)
422-423: Device setup correctly implemented.
The JAX device configuration is properly handled with the new device management functions.

499-499: Loss function correctly configured.
The sparse cross-entropy loss function is properly passed to the trainer, aligning with the new configurable loss function architecture.
mo_net/samples/cnn.py (3)
44-173: CNN architecture is well-designed and properly uses JAX layers.
The model follows best practices with three convolutional blocks (conv -> batch norm -> ReLU -> pooling), proper flattening, and dense layers. The use of JAX activation functions and sparse categorical cross-entropy output layer aligns with the framework migration.

291-330: Training configuration is properly set up.
The training parameters, optimizer selection, and trainer initialization correctly use the JAX-migrated framework components with appropriate loss function configuration.

335-344: Result handling is well-implemented.
Proper use of pattern matching to handle success and failure cases with appropriate logging and file management.
mo_net/train/trainer/trainer.py
Outdated
```python
if (check := self._post_epoch(L_val)) is not None:
    return TrainingFailed(
        model_checkpoint_path=self._model_checkpoint_path,
        message=check.message,
        model_checkpoint_save_epoch=self._L_val_min_epoch,
    )
```
Fix type annotation for check variable.
The check variable can be None as returned by _post_epoch, but the condition only handles the non-None case correctly.
🧰 Tools
🪛 GitHub Actions: CI
[error] 332-332: mypy: Incompatible types in assignment (expression has type "CheckFailed | None", variable has type "CheckFailed") [assignment]
🤖 Prompt for AI Agents
In mo_net/train/trainer/trainer.py around lines 332 to 337, the variable 'check'
is assigned the result of self._post_epoch(L_val), which can be None, but the
current type annotation does not reflect this possibility. Update the type
annotation of 'check' to indicate it can be either the expected return type of
_post_epoch or None, ensuring the code correctly handles both cases and the type
checker is satisfied.
```python
X_train, Y_train, _, __ = load_data(
    MNIST_TRAIN_URL,
    split=SplitConfig.of(train_split, 0),
    one_hot=False,
)
```
🛠️ Refactor suggestion
Remove redundant data splitting.
You're using SplitConfig to load pre-split data but then manually splitting again. Either use the split from load_data or load all data without splitting.
- X_train, Y_train, _, __ = load_data(
+ # Load full training data without split
+ X_train_full, Y_train_full = load_data(
MNIST_TRAIN_URL,
- split=SplitConfig.of(train_split, 0),
one_hot=False,
)
# ... reshape code ...
- train_size = int(0.8 * len(X_train))
- X_train_split = X_train[:train_size]
- Y_train_split = Y_train[:train_size]
- X_val = X_train[train_size:]
- Y_val = Y_train[train_size:]
+ train_size = int(train_split * len(X_train_full))
+ X_train = X_train_full[:train_size]
+ Y_train = Y_train_full[:train_size]
+ X_val = X_train_full[train_size:]
+ Y_val = Y_train_full[train_size:]

Alternatively, use the split data directly:
- X_train, Y_train, _, __ = load_data(
+ X_train, Y_train, X_val, Y_val = load_data(
MNIST_TRAIN_URL,
split=SplitConfig.of(train_split, 0),
one_hot=False,
)
# ... reshape code ...
- train_size = int(0.8 * len(X_train))
- X_train_split = X_train[:train_size]
- Y_train_split = Y_train[:train_size]
- X_val = X_train[train_size:]
- Y_val = Y_train[train_size:]
+ # Data is already split, reshape validation data too
+ X_val = X_val.reshape(-1, 1, MNIST_IMAGE_SIZE, MNIST_IMAGE_SIZE)

Also applies to: 310-314
🤖 Prompt for AI Agents
In mo_net/samples/cnn.py around lines 272 to 276, the code redundantly splits
data by using SplitConfig in load_data and then manually splitting again. To fix
this, remove the manual splitting and use the data as returned by load_data
directly, or alternatively load all data without specifying a split and then
split manually. Apply the same fix to lines 310 to 314 to ensure consistency.
Actionable comments posted: 3
🔭 Outside diff range comments (1)
mo_net/samples/cbow.py (1)
4-4: Fix import conflict between standard Python random and JAX random
The import of `jax.random as random` conflicts with the standard Python `random` module imported on line 4, causing redefinition errors and making both modules inaccessible by their intended names.

```diff
 import random
+import jax.random as jax_random
-import jax.random as random
```

This resolves the import conflict flagged by static analysis tools.
Also applies to: 16-16
♻️ Duplicate comments (3)
mo_net/samples/cbow.py (2)
594-596: Update random module references after fixing import conflict
After renaming the JAX random import to avoid shadowing, update these references accordingly. The function signature for `jax.random.choice` also needs correction.

```diff
-    predicted_token_id = random.choice(
-        random.PRNGKey(42), valid_indices, shape=(), p=valid_probs
+    predicted_token_id = jax_random.choice(
+        jax_random.PRNGKey(42), valid_indices, shape=(), p=valid_probs
     )
```
635-642: Update random module references and improve readability
After fixing the import conflict, update these references. The function signature for `jax.random.choice` also needs correction.

```diff
-    # Use random choice to select random words
-    word_indices = random.choice(
-        random.PRNGKey(42),
+    # Use JAX random choice to select random words
+    word_indices = jax_random.choice(
+        jax_random.PRNGKey(42),
         len(vocab.vocab),
         shape=(min(num_words, len(vocab)),),
         replace=False,
     )
```

mo_net/train/trainer/trainer.py (1)
330-330: Use context manager for file operations.
🧹 Nitpick comments (3)
mo_net/model/layer/average.py (1)
70-80: Excellent refactoring of backward propagation logic.
The new implementation using JAX idioms is mathematically correct and more efficient:
- Properly expands dimensions and divides by the averaging factor
- Uses `jnp.broadcast_to` for clean gradient broadcasting
- Maintains the correct gradient flow for the average operation

However, address the static analysis hint about the `zip()` function. Apply this diff:

```diff
 for ax, factor in zip(
     sorted(ax + 1 for ax in self._axis),
     tuple(self._cache["input_shape"][ax + 1] for ax in self._axis),
+    strict=True,
 ):
```

The `strict=True` parameter ensures that the iterables have the same length, which is the expected behavior here since both sequences are derived from the same `self._axis`.

mo_net/train/trainer/trainer.py (1)
312-315: Simplify nested if statements.
The nested if statements can be combined for better readability. Apply this diff to simplify:

```diff
-if self._training_parameters.monotonic:
-    if self._last_update is not None:
-        self._revert_training_step()
+if self._training_parameters.monotonic and self._last_update is not None:
+    self._revert_training_step()
```

mo_net/samples/cnn.py (1)
44-173: Well-structured CNN architecture for MNIST.
The model definition is clean with proper dimension tracking through variable assignments. The architecture (Conv→BN→ReLU→Pool)×3→Flatten→Dense→Output is appropriate for MNIST.

Consider parameterizing the architecture constants (kernel counts, sizes, dense units) to make the model more flexible:

```python
@classmethod
def create(
    cls,
    *,
    tracing_enabled: bool = False,
    conv_channels: tuple[int, ...] = (32, 64, 128),
    dense_units: int = 512,
) -> CNNModel:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- mo_net/cli.py (11 hunks)
- mo_net/data.py (3 hunks)
- mo_net/model/layer/average.py (3 hunks)
- mo_net/samples/cbow.py (11 hunks)
- mo_net/samples/cnn.py (1 hunks)
- mo_net/train/trainer/trainer.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- mo_net/data.py
- mo_net/cli.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
mo_net/model/layer/average.py (3)
mo_net/model/layer/embedding.py (1)
- _backward_prop (223-233)

mo_net/model/layer/batch_norm/batch_norm.py (1)
- _backward_prop (294-341)

mo_net/protos.py (1)
- D (24-32)
🪛 Ruff (0.11.9)
mo_net/model/layer/average.py
74-77: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
mo_net/samples/cbow.py
16-16: Redefinition of unused random from line 4
(F811)
mo_net/train/trainer/trainer.py
312-313: Use a single if statement instead of nested if statements
(SIM102)
330-330: Use a context manager for opening files
(SIM115)
🪛 GitHub Actions: CI
mo_net/model/layer/average.py
[error] mypy internal error: AssertionError: Internal error: method must be called on parsed file only
mo_net/samples/cbow.py
[error] 16-16: mypy: Name "random" already defined (by an import) [no-redef]
[error] 68-68: mypy: Incompatible return value type (got "Array", expected "float") [return-value]
[error] 594-594: mypy: Too many arguments and unexpected keyword arguments "shape" and "p" [call-arg]
[error] 595-595: mypy: Module has no attribute "PRNGKey" [attr-defined]
[error] 636-636: mypy: Too many arguments and unexpected keyword arguments "shape" and "replace" [call-arg]
[error] 637-637: mypy: Module has no attribute "PRNGKey" [attr-defined]
mo_net/train/trainer/trainer.py
[error] 332-332: mypy: Incompatible types in assignment (expression has type "CheckFailed | None", variable has type "CheckFailed") [assignment]
[error] 355-355: mypy: "BasicTrainer" has no attribute "_compute_loss" [attr-defined]
🔇 Additional comments (15)
mo_net/model/layer/average.py (3)
6-6: LGTM! JAX NumPy import is correct.
The migration from NumPy to JAX NumPy is consistent with the PR objective.

23-25: Good interface consistency and parameter handling.
The addition of the `freeze_parameters` parameter maintains consistency across the layer interface, and explicitly marking unused parameters is good practice.

62-62: Correct JAX migration for mean operation.
The change from `np.mean` to `jnp.mean` is appropriate for the JAX migration while maintaining the same functionality.

mo_net/samples/cbow.py (4)
190-190: LGTM: Type annotations and array operations correctly updated
The type annotations have been properly updated from `np.ndarray` to `jnp.ndarray`, and the array creation uses JAX's `jnp.array` function. This is consistent with the overall migration to JAX.
Also applies to: 207-207

260-262: LGTM: Output layer migration aligns with framework changes
The change from `SoftmaxOutputLayer` to `SparseCategoricalCrossentropyOutputLayer` is consistent with the broader framework migration and the use of sparse cross-entropy loss function.

578-596: LGTM: JAX array operations correctly implemented
The probability calculations and tensor operations have been properly converted to use JAX's NumPy interface. The logic for temperature scaling, top-p sampling, and probability normalization is mathematically correct.
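The temperature-scaled top-p (nucleus) sampling logic mentioned above can be sketched as follows. This is an illustrative reconstruction, not the file's actual code; the function name and the fixed key are assumptions:

```python
import jax
import jax.numpy as jnp

def top_p_sample(key, logits, temperature=1.0, top_p=0.9):
    # Temperature-scale, sort by probability, keep the smallest prefix
    # whose cumulative mass reaches top_p, renormalise, then sample.
    probs = jax.nn.softmax(logits / temperature)
    order = jnp.argsort(-probs)
    sorted_probs = probs[order]
    cumulative = jnp.cumsum(sorted_probs)
    keep = cumulative - sorted_probs < top_p  # mass accumulated before each token
    kept = jnp.where(keep, sorted_probs, 0.0)
    kept = kept / jnp.sum(kept)
    idx = jax.random.choice(key, order.shape[0], p=kept)
    return order[idx]

key = jax.random.PRNGKey(42)
logits = jnp.array([2.0, 1.0, 0.1, -1.0])
print(int(top_p_sample(key, logits)))
```

The first-ranked token is always kept (its preceding mass is zero), so the nucleus is never empty.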
658-660: LGTM: Cosine similarity calculation correctly converted to JAX
The cosine similarity computation using JAX's `jnp.dot` and `jnp.linalg.norm` functions is correctly implemented and maintains the same mathematical behavior as the NumPy version.

mo_net/train/trainer/trainer.py (4)
8-8: LGTM! Imports correctly updated for JAX migration.
The imports have been properly updated to use JAX NumPy and the new function types.
Also applies to: 14-14, 22-22
106-129: Constructor properly updated for JAX arrays and configurable loss function.
The changes correctly:
- Accept JAX arrays for training and validation data
- Add configurable loss function support
- Use `IndexBatcher` for batch index generation
- Rename `transform` to `transform_fn` for clarity
233-235: Loss computation correctly updated to use configurable loss function.
All loss computations now properly pass the `loss_fn` parameter, enabling flexible loss function selection.
Also applies to: 316-318, 324-326
332-337: Fix type handling for post-epoch check.
The `_post_epoch` method returns `CheckFailed | None`, but the type narrowing isn't handled properly. Apply this diff to fix the type handling:

```diff
-if (check := self._post_epoch(L_val)) is not None:
-    return TrainingFailed(
-        model_checkpoint_path=self._model_checkpoint_path,
-        message=check.message,
-        model_checkpoint_save_epoch=self._L_val_min_epoch,
-    )
+check = self._post_epoch(L_val)
+if check is not None:
+    return TrainingFailed(
+        model_checkpoint_path=self._model_checkpoint_path,
+        message=check.message,
+        model_checkpoint_save_epoch=self._L_val_min_epoch,
+    )
```

Likely an incorrect or invalid review comment.
mo_net/samples/cnn.py (4)
1-41: Imports and constants are well-organized.
Good separation of standard library, third-party, and project imports. Constants are appropriately defined for MNIST dataset dimensions.

176-249: Comprehensive training options with good defaults.
The CLI options cover all necessary training parameters with sensible defaults. The decorator pattern allows for easy reuse across different commands.

272-277: Verify data loading split configuration.
The code loads training data with a split configuration but then manually splits again later. This might lead to unexpected behavior. Clarify the intention here: should the `SplitConfig` in `load_data` handle the train/validation split, or should it be done manually? Currently, both approaches are mixed.

320-330: Training setup correctly uses JAX arrays and sparse cross-entropy loss.
The trainer initialization properly passes all required parameters including the sparse cross-entropy loss function, which is appropriate for the sparse categorical output layer.
- Propagate seeds and keys properly
- Jit compile affine_transform
- Rename to affine_transform2D
- Update batcher logic
Actionable comments posted: 0
🔭 Outside diff range comments (1)
mo_net/samples/cbow.py (1)
4-4: Remove unused import to avoid shadowing.
The standard library `random` import on line 4 is shadowed by the JAX random import on line 16 and is never used. Apply this diff to fix the import shadowing:

```diff
-import random
```

This removes the unused standard library import since all random operations use JAX's random module.
Also applies to: 16-16
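For reference, the `jax.random` API that the fixes above rely on can be sketched as below; the key-splitting discipline (one fresh key per independent draw) is the part the later comments enforce:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
# Keys are consumed, never reused: split before each independent draw.
key, choice_key, perm_key = jax.random.split(key, 3)

values = jnp.arange(10)
picked = jax.random.choice(choice_key, values, shape=(3,), replace=False)
shuffled = jax.random.permutation(perm_key, values)
print(picked.shape, shuffled.shape)  # (3,) (10,)
```

Unlike the stdlib `random` module, these functions are pure: the same key always yields the same values, which is why reusing a key across draws produces correlated results.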
♻️ Duplicate comments (5)
mo_net/train/augment.py (1)
148-160: Fix PRNG key reuse for independent random values.
Using the same key for multiple random operations produces identical values in JAX. Split additional keys for independent random values. Apply this diff to fix the key reuse:

```diff
-key1, key2, key3, key4, key5, key6 = random.split(key, 6)
+key1, key2, key3, key4, key5, key6, key7 = random.split(key, 7)
 rotation = random.uniform(
     key1, (), minval=min_rotation_radians, maxval=max_rotation_radians
 )
 x_shear = random.uniform(key2, (), minval=min_shear, maxval=max_shear)
 y_shear = random.uniform(key3, (), minval=min_shear, maxval=max_shear)
 scale_val = random.uniform(key4, (), minval=min_scale, maxval=max_scale)
 x_offset = random.randint(
     key5, (), min_translation_pixels, max_translation_pixels
 )
 y_offset = random.randint(
-    key5, (), min_translation_pixels, max_translation_pixels
+    key6, (), min_translation_pixels, max_translation_pixels
 )
```

And update the return statement:

```diff
-return X, key6
+return X, key7
```

mo_net/train/trainer/trainer.py (3)
337-337: Use context manager for file operations.
The static analysis tool correctly identifies that file operations should use context managers for proper resource management. Apply this diff:
Apply this diff:
- self._model.dump(open(self._model_checkpoint_path, "wb")) + with open(self._model_checkpoint_path, "wb") as f: + self._model.dump(f)
358-360: Fix incorrect method call.
The method
`_compute_loss` doesn't exist on `BasicTrainer`. It should be `self._model.compute_loss`. Apply this diff to fix:
```diff
-if L_val is None:
-    L_val = self._compute_loss(X=self._X_val, Y_true=self._Y_val)
+if L_val is None:
+    L_val = self._model.compute_loss(X=self._X_val, Y_true=self._Y_val, loss_fn=self._loss_fn)
```
236-236: Use context manager for file operations.
File operations should use context managers for proper resource management. Apply this diff:
Apply this diff:
- self._model.dump(open(self._model_checkpoint_path, "wb")) + with open(self._model_checkpoint_path, "wb") as f: + self._model.dump(f)mo_net/samples/cbow.py (1)
66-72: Fix return type mismatch in regularization computation.The function returns a JAX array but the return type annotation expects a Python float. JAX arrays don't automatically convert to Python scalars.
Apply this diff to convert the JAX scalar to a Python float:
```diff
 def compute_regularisation_loss(self) -> float:
-    return (
+    return float(
         0.5
         * self._lambda
         * jnp.sum(self._layer.parameters.embeddings**2)
         / self._batch_size
     )
```
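The distinction is easy to demonstrate in isolation: `jnp.sum` produces a zero-dimensional `jax.Array`, not a Python `float`, so an explicit conversion is needed to honour the annotation. A minimal sketch:

```python
import jax.numpy as jnp

# jnp.sum returns a zero-dimensional jax.Array, not a Python float
total = 0.5 * jnp.sum(jnp.array([1.0, 2.0]) ** 2)
assert not isinstance(total, float)

# Explicit conversion (float(...) or .item()) satisfies a `-> float` annotation
as_float = float(total)
```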
🧹 Nitpick comments (1)
mo_net/train/trainer/trainer.py (1)
319-322: Simplify nested if statements. Combine the nested if statements into a single condition for better readability.
Apply this diff:
```diff
-        if self._training_parameters.monotonic:
-            if self._last_update is not None:
-                self._revert_training_step()
+        if self._training_parameters.monotonic and self._last_update is not None:
+            self._revert_training_step()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- `mo_net/cli.py` (11 hunks)
- `mo_net/functions.py` (1 hunks)
- `mo_net/samples/cbow.py` (12 hunks)
- `mo_net/samples/cnn.py` (1 hunks)
- `mo_net/samples/mnist/__main__.py` (4 hunks)
- `mo_net/train/augment.py` (1 hunks)
- `mo_net/train/batcher.py` (1 hunks)
- `mo_net/train/trainer/trainer.py` (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- mo_net/samples/mnist/__main__.py
- mo_net/train/batcher.py
- mo_net/samples/cnn.py
- mo_net/functions.py
- mo_net/cli.py
🧰 Additional context used
🪛 Ruff (0.11.9)
mo_net/samples/cbow.py
16-16: Redefinition of unused random from line 4
(F811)
mo_net/train/trainer/trainer.py
319-320: Use a single if statement instead of nested if statements
(SIM102)
337-337: Use a context manager for opening files
(SIM115)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: typecheck
Actionable comments posted: 0
♻️ Duplicate comments (4)
mo_net/samples/mnist/cnn.py (4)
293-297: Redundant data splitting configuration. The `SplitConfig.of(train_split, 0)` suggests initial splitting, but lines 333-337 perform manual splitting again. This creates confusion and potential inconsistency. Consider using only one splitting approach:

```diff
-    X_train, Y_train, _, __ = load_data(
-        MNIST_TRAIN_URL,
-        split=SplitConfig.of(train_split, 0),
-        one_hot=False,
-    )
+    X_train, Y_train, _, __ = load_data(
+        MNIST_TRAIN_URL,
+        split=SplitConfig.of(1.0, 0),
+        one_hot=False,
+    )
```
304-306: Duplicate seed generation reduces reproducibility. The seed is generated twice (here and at line 339), which could lead to confusion and reduced reproducibility. Consider using a single seed throughout the function.

```diff
 seed = time.time_ns() // 1000
 logger.info(f"Seed: {seed}")
 key = jax.random.PRNGKey(seed)
 # ... model initialization ...
-seed = time.time_ns() // 1000
-run = TrainingRun(seed=seed, name=f"cnn_run_{seed}", backend=SqliteBackend())
+run = TrainingRun(seed=seed, name=f"cnn_run_{seed}", backend=SqliteBackend())
```
333-337: Add validation for `train_split` parameter. The manual data splitting doesn't validate that `train_split` is within reasonable bounds (0.0 to 1.0), which could lead to runtime errors or unexpected behavior.

```diff
+if not 0.0 < train_split < 1.0:
+    raise ValueError(f"train_split must be between 0.0 and 1.0, got {train_split}")
+
 train_size = int(train_split * len(X_train))
```
366-366: Add error handling for file operations. The file rename operation could fail and should be wrapped in error handling.

```diff
-    result.model_checkpoint_path.rename(model_output_path)
-    logger.info(f"Training completed. Model saved to: {model_output_path}")
+    try:
+        result.model_checkpoint_path.rename(model_output_path)
+        logger.info(f"Training completed. Model saved to: {model_output_path}")
+    except OSError as e:
+        logger.error(f"Failed to save model to {model_output_path}: {e}")
```
🧹 Nitpick comments (3)
mo_net/functions.py (1)
55-65: Consider simplifying error handling in `parse_activation_fn`. The function re-raises ValueError with the same message, which doesn't add value.

```diff
 try:
     return get_activation_fn(value)
 except ValueError as e:
-    raise ValueError(str(e)) from e
+    raise
```

mo_net/tests/model/layer/test_embedding.py (1)
407-408: Consider using a proper test fixture for layer name collision. The skip reason suggests this test has issues with layer name collisions. Consider using pytest fixtures with unique names or temporary directories to resolve this.
Would you like me to help design a fixture that ensures unique layer IDs for serialization tests?
mo_net/model/layer/convolution.py (1)
338-402: Comprehensive backward propagation with potential optimization opportunity. The implementation correctly uses JAX operations for gradient computation. However, the kernel loop (lines 387-399) could potentially be vectorized for better performance.
Consider vectorizing the kernel loop:
```diff
-    dX = jnp.zeros((batch_size, in_channels, in_height, in_width))
-
-    for k in range(self._n_kernels):
-        dZ_k = dZ_padded[:, k : k + 1, :, :]
-        weights_k = weights_flipped[k : k + 1, :, :, :]
-        weights_k_t = jnp.transpose(weights_k, (1, 0, 2, 3))
-
-        dX_k = lax.conv_general_dilated(
-            dZ_k,
-            weights_k_t,
-            window_strides=(1, 1),
-            padding="VALID",
-            dimension_numbers=("NCHW", "OIHW", "NCHW"),
-        )
-        dX += dX_k
+    # Vectorized approach using grouped convolution or batch processing
+    weights_transposed = jnp.transpose(weights_flipped, (1, 0, 2, 3))
+    dX = lax.conv_general_dilated(
+        dZ_padded,
+        weights_transposed,
+        window_strides=(1, 1),
+        padding="VALID",
+        dimension_numbers=("NCHW", "IOHW", "NCHW"),
+        feature_group_count=1,
+    )
```

This would eliminate the loop and leverage JAX's optimized convolution operations more effectively.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (19)
- `README.md` (2 hunks)
- `mo_net/cli.py` (14 hunks)
- `mo_net/device.py` (3 hunks)
- `mo_net/functions.py` (1 hunks)
- `mo_net/model/layer/activation.py` (5 hunks)
- `mo_net/model/layer/convolution.py` (10 hunks)
- `mo_net/model/layer/embedding.py` (10 hunks)
- `mo_net/model/layer/linear.py` (10 hunks)
- `mo_net/model/model.py` (10 hunks)
- `mo_net/model/module/dense.py` (3 hunks)
- `mo_net/model/module/norm.py` (3 hunks)
- `mo_net/samples/cbow.py` (20 hunks)
- `mo_net/samples/mnist/cnn.py` (1 hunks)
- `mo_net/tests/model/layer/test_embedding.py` (1 hunks)
- `mo_net/tests/model/layer/test_linear.py` (22 hunks)
- `mo_net/tests/model/test_model.py` (6 hunks)
- `mo_net/tests/trainer/test_parallel.py` (12 hunks)
- `mo_net/train/backends/log.py` (4 hunks)
- `pyproject.toml` (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (12)
- pyproject.toml
- mo_net/model/module/norm.py
- mo_net/train/backends/log.py
- mo_net/model/module/dense.py
- mo_net/tests/model/layer/test_linear.py
- mo_net/model/layer/embedding.py
- mo_net/tests/model/test_model.py
- mo_net/cli.py
- mo_net/model/layer/linear.py
- mo_net/tests/trainer/test_parallel.py
- mo_net/model/layer/activation.py
- mo_net/samples/cbow.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
mo_net/samples/mnist/cnn.py (16)
- `mo_net/device.py` (2)
  - `print_device_info` (162-170)
  - `set_default_device` (109-150)
- `mo_net/functions.py` (1)
  - `sparse_cross_entropy` (20-23)
- `mo_net/log.py` (2)
  - `LogLevel` (12-18)
  - `setup_logging` (21-23)
- `mo_net/model/layer/activation.py` (3)
  - `Activation` (18-92)
  - `input_dimensions` (78-79)
  - `output_dimensions` (82-83)
- `mo_net/model/layer/batch_norm/batch_norm_2d.py` (1)
  - `BatchNorm2D` (148-393)
- `mo_net/model/layer/convolution.py` (4)
  - `Convolution2D` (210-464)
  - `random` (116-139)
  - `Parameters` (33-203)
  - `he` (176-193)
- `mo_net/model/layer/linear.py` (5)
  - `Linear` (181-394)
  - `random` (101-112)
  - `Parameters` (33-167)
  - `he` (124-130)
  - `of` (157-158)
- `mo_net/model/layer/output.py` (3)
  - `SparseCategoricalCrossentropyOutputLayer` (125-172)
  - `output_dimensions` (91-92)
  - `output_dimensions` (168-169)
- `mo_net/model/layer/reshape.py` (1)
  - `Flatten` (41-72)
- `mo_net/model/module/base.py` (8)
  - `Hidden` (88-123)
  - `Output` (126-200)
  - `input_dimensions` (38-39)
  - `input_dimensions` (185-189)
  - `layers` (71-72)
  - `output_dimensions` (42-43)
  - `output_dimensions` (192-196)
  - `output_layer` (199-200)
- `mo_net/protos.py` (3)
  - `NormalisationType` (182-185)
  - `input_dimensions` (170-170)
  - `output_dimensions` (173-173)
- `mo_net/train/trainer/trainer.py` (4)
  - `train` (209-266)
  - `BasicTrainer` (98-403)
  - `TrainingFailed` (38-42)
  - `TrainingSuccessful` (33-34)
- `mo_net/train/augment.py` (1)
  - `affine_transform2D` (132-177)
- `mo_net/train/backends/log.py` (5)
  - `SqliteBackend` (133-287)
  - `create` (32-33)
  - `create` (82-83)
  - `create` (156-157)
  - `create` (298-299)
- `mo_net/train/run.py` (2)
  - `TrainingRun` (6-70)
  - `seed` (30-31)
- `mo_net/data.py` (5)
  - `SplitConfig` (28-68)
  - `load_data` (113-118)
  - `load_data` (120-125)
  - `load_data` (128-143)
  - `of` (41-42)
🪛 LanguageTool
README.md
[grammar] ~54-~54: Use proper spacing conventions.
Context: ...explains how to enable GPU acceleration. ## Supported Devices - CPU: (Default) ...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~56-~56: Use proper spacing conventions.
Context: ... GPU acceleration. ## Supported Devices - CPU: (Default) - CUDA GPU: NVIDIA ...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~59-~59: Use proper spacing conventions.
Context: ...*: (Default) - CUDA GPU: NVIDIA GPUs ## Installation ## Usage ### Command Line...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~61-~61: Use proper spacing conventions.
Context: ...CUDA GPU**: NVIDIA GPUs ## Installation ## Usage ### Command Line Use the `--devi...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~63-~63: Use proper spacing conventions.
Context: ...: NVIDIA GPUs ## Installation ## Usage ### Command Line Use the --device flag wh...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~65-~65: Use proper spacing conventions.
Context: ...Installation ## Usage ### Command Line Use the --device flag when training: ...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~67-~67: Use proper spacing conventions.
Context: ... Use the --device flag when training: bash uv run train --device auto # The default. Can omit uv run train --device gpu # NVIDIA GPU uv run train --device cpu # CPU only ### Python API ```python from mo_net.device...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~76-~76: Use proper spacing conventions.
Context: ...ce cpu # CPU only ### Python API python from mo_net.device import set_default_device, print_device_info # Set device before creating models or training set_default_device("auto") # or "gpu", "cpu" print_device_info() ``` ### Checking Device Availability Run the in...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~87-~87: Use proper spacing conventions.
Context: ...() ``` ### Checking Device Availability Run the included script to check your de...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~89-~89: Use proper spacing conventions.
Context: ...luded script to check your device setup: bash uv run check_device This will show: - Available JAX devices...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~95-~95: Use proper spacing conventions.
Context: ...uv run check_device ``` This will show: - Available JAX devices - Which device is ...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~97-~97: There might be a mistake here.
Context: ...e ``` This will show: - Available JAX devices - Which device is currently selected - ...
(QB_NEW_EN_OTHER)
[grammar] ~98-~98: There might be a mistake here.
Context: ...JAX devices - Which device is currently selected - A simple test computation ## Perform...
(QB_NEW_EN_OTHER)
[grammar] ~99-~99: Use proper spacing conventions.
Context: ...tly selected - A simple test computation ## Performance Tips 1. **Memory Management...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~101-~101: Use proper spacing conventions.
Context: ...le test computation ## Performance Tips 1. Memory Management: JAX pre-allocates G...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~103-~103: Use proper spacing conventions.
Context: ...ault. To share GPU with other processes: python from mo_net.device import enable_gpu_memory_growth enable_gpu_memory_growth() 2. Mixed Precision: JAX uses float32 by d...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~110-~110: Use proper spacing conventions.
Context: ...default, which is optimal for most GPUs. ## Troubleshooting ### CUDA GPU Not Detect...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~112-~112: Use proper spacing conventions.
Context: ...timal for most GPUs. ## Troubleshooting ### CUDA GPU Not Detected - Ensure CUDA is ...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~114-~114: Use proper spacing conventions.
Context: ...oubleshooting ### CUDA GPU Not Detected - Ensure CUDA is installed: nvidia-smi -...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~117-~117: Use articles correctly.
Context: ...CUDA is installed: nvidia-smi - Check JAX CUDA version matches your CUDA installa...
(QB_NEW_EN_OTHER_ERROR_IDS_000004)
[grammar] ~117-~117: There might be a mistake here.
Context: ...heck JAX CUDA version matches your CUDA installation - Reinstall JAX with correct CUDA versi...
(QB_NEW_EN_OTHER)
[grammar] ~118-~118: Use articles correctly.
Context: ... CUDA installation - Reinstall JAX with correct CUDA version ### Performance Issues -...
(QB_NEW_EN_OTHER_ERROR_IDS_000004)
[grammar] ~118-~118: Place a period at the end of declarative sentences.
Context: ... Reinstall JAX with correct CUDA version ### Performance Issues - Check device is ac...
(QB_NEW_EN_OTHER_ERROR_IDS_000178)
[grammar] ~120-~120: Use proper spacing conventions.
Context: ...ect CUDA version ### Performance Issues - Check device is actually being used: Loo...
(QB_NEW_EN_OTHER_ERROR_IDS_000007)
[grammar] ~122-~122: Use articles correctly.
Context: ...ersion ### Performance Issues - Check device is actually being used: Look for device...
(QB_NEW_EN_OTHER_ERROR_IDS_000004)
[grammar] ~122-~122: There might be a mistake here.
Context: ... used: Look for device info in training logs - Ensure batch size is appropriate for ...
(QB_NEW_EN_OTHER)
[grammar] ~123-~123: Use articles correctly.
Context: ...r device info in training logs - Ensure batch size is appropriate for your GPU memory...
(QB_NEW_EN_OTHER_ERROR_IDS_000004)
[grammar] ~123-~123: There might be a mistake here.
Context: ... batch size is appropriate for your GPU memory - Monitor GPU usage with nvidia-smi (...
(QB_NEW_EN_OTHER)
[grammar] ~124-~124: Place a period at the end of declarative sentences.
Context: ...mi` (NVIDIA) or Activity Monitor (macOS)
(QB_NEW_EN_OTHER_ERROR_IDS_000178)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: typecheck
🔇 Additional comments (26)
mo_net/functions.py (4)
14-17: LGTM! JAX migration and type safety handled correctly. The cross-entropy function correctly uses JAX operations and `.item()` to return a Python float, addressing the previous type mismatch issue.
20-23: New sparse cross-entropy implementation looks correct. The sparse cross-entropy function properly handles integer labels by indexing into predictions and applies the same clipping and averaging logic as the standard cross-entropy.
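As a hedged illustration of the indexing-based approach described above (a standalone sketch, not the module's actual implementation; the function name and `eps` default are assumptions):

```python
import jax.numpy as jnp

def sparse_cross_entropy_sketch(y_pred, y_true, eps=1e-15):
    # y_pred: (batch, classes) probabilities; y_true: (batch,) integer labels.
    # Pick each row's probability for its true class, clip, then average -log.
    p = jnp.clip(y_pred[jnp.arange(y_pred.shape[0]), y_true], eps, 1 - eps)
    return float(-jnp.mean(jnp.log(p)))

probs = jnp.array([[0.9, 0.1], [0.2, 0.8]])
labels = jnp.array([0, 1])
loss = sparse_cross_entropy_sketch(probs, labels)
```

Avoiding one-hot encoding this way saves a `(batch, classes)` intermediate, which matters for large vocabularies.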
26-32: Well-designed factory function with proper error handling. The `get_loss_fn` function provides a clean interface for retrieving loss functions with appropriate type hints and error messages.
39-45: Excellent use of JAX native activation functions. The activation functions dictionary leverages JAX's built-in implementations, which are optimized and JIT-compiled. This is much better than custom implementations.
mo_net/model/model.py (4)
84-87: Excellent API improvements with JAX integration. The method signature updates properly integrate JAX random keys, configurable activation functions, dropout probabilities, and regularizers. The default `identity` activation function is sensible.

Also applies to: 101-104, 115-118
137-149: Proper JAX random key management in factory functions. The use of `jax.random.split(key)` to generate subkeys for each module initialization follows JAX best practices for reproducible random operations.

Also applies to: 167-178
292-297: Clear and maintainable forward propagation logic. The refactoring from functional reduction to explicit iteration improves readability while maintaining the same computational flow.
352-358: Flexible loss function integration. The `compute_loss` method now accepts any loss function, making the model more flexible while properly handling loss contributors.

mo_net/tests/model/layer/test_embedding.py (4)
16-25: Excellent test structure with dataclasses. Using dataclasses for test cases provides clear, type-safe test organization and makes the tests highly readable.
Also applies to: 116-126, 219-229
27-114: Comprehensive forward propagation test coverage. The parametrized tests cover simple lookups, single tokens, and batch processing with clear expected outputs. The test data is well-designed to verify correct embedding behavior.
164-200: Thorough gradient accumulation testing. The test cases properly verify that gradients are accumulated for duplicate tokens, which is crucial for embedding layer correctness.
520-536: Excellent validation of gradient accumulation semantics. This test specifically verifies that gradients for the same token are properly accumulated (3.0 = 1.0 + 1.0 + 1.0), which is essential for correct embedding training.
README.md (2)
18-18: Updated training command reflects new sample structure. The change from `uv run train --quickstart mnist_mlp` to `uv run python samples/mnist/cnn.py train` aligns with the new sample organization mentioned in the AI summary.
52-124: Comprehensive GPU setup guide adds significant value. The new GPU setup section provides clear instructions for device selection, API usage, troubleshooting, and performance optimization. This documentation significantly improves the user experience for GPU acceleration.
mo_net/device.py (3)
13-14: Reasonable simplification of device types. Removing "mps" from explicit device types while maintaining Metal device support in auto-selection simplifies the API without breaking functionality for Apple Silicon users.
66-66: Documentation correctly reflects available device types. The docstring updates accurately reflect that only "cpu", "gpu", and "auto" are now explicitly supported, while Metal devices are still handled automatically.
Also applies to: 116-116
83-86: Metal device auto-selection preserved. The auto-selection logic still properly detects and uses Metal devices when available, ensuring Apple Silicon users don't lose functionality despite the removal of explicit "mps" support.
mo_net/samples/mnist/cnn.py (4)
1-48: LGTM: Well-organized imports and constants. The imports are comprehensive and properly structured. The use of JAX aligns with the PR objectives, and the constants are appropriately defined.
50-194: Excellent CNN architecture implementation. The model follows best practices:
- Proper JAX random key management with key splitting
- Well-documented architecture with clear layer progression
- Appropriate use of He initialization for conv layers and Xavier for dense layers
- Good separation of concerns with modular layer construction
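For reference, He initialization scales normal draws by `sqrt(2 / fan_in)`, which keeps activation variance stable under ReLU. A standalone sketch (not the project's `Parameters.he` helper, whose exact signature may differ):

```python
import jax
import jax.numpy as jnp

def he_init(key, fan_in, shape):
    # He/Kaiming initialization: std = sqrt(2 / fan_in), suited to ReLU.
    return jax.random.normal(key, shape) * jnp.sqrt(2.0 / fan_in)

# A conv kernel of 128 output channels, 64 input channels, 3x3 window:
# fan_in = in_channels * kernel_height * kernel_width = 64 * 3 * 3 = 576.
w = he_init(jax.random.PRNGKey(0), fan_in=576, shape=(128, 64, 3, 3))
```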
197-263: Comprehensive training options configuration. The decorator provides excellent coverage of training hyperparameters with sensible defaults and proper type annotations.
313-371: Well-implemented training pipeline. The training setup demonstrates good practices:
- Comprehensive training parameters configuration
- Proper use of JAX random keys for trainer initialization
- Appropriate data augmentation and loss function selection
- Modern pattern matching for result handling
- Good logging throughout the process
mo_net/model/layer/convolution.py (5)
1-36: Excellent JAX migration for imports and types. The migration from NumPy to JAX is properly executed:
- Appropriate JAX imports added
- Parameter types correctly updated to `jnp.ndarray`
- Clean removal of NumPy dependencies
115-193: Excellent PRNG key parameterization implementation. The previous concern about hardcoded PRNG keys has been fully addressed. All initialization methods now:
- Require explicit JAX PRNG keys as parameters
- Use proper key splitting with `jax.random.split`
- Ensure reproducible and varied initialization across model instances
214-225: Well-designed JIT-compiled gradient clipping. The gradient clipping implementation demonstrates good practices:
- JIT compilation for performance optimization
- Clean separation between clipping and no-op functions
- Proper use of JAX operations for numerical stability
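A standalone sketch of the same idea, norm-based clipping under `jax.jit` (the function name and threshold handling are illustrative, not the layer's actual code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_by_norm(grad, max_norm):
    # Scale the gradient down only when its L2 norm exceeds max_norm;
    # the epsilon guard keeps the division stable for near-zero gradients.
    norm = jnp.linalg.norm(grad)
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))
    return grad * scale

g = jnp.array([3.0, 4.0])       # L2 norm 5.0
clipped = clip_by_norm(g, 1.0)  # rescaled to unit norm
```

Expressing the clip as a smooth `minimum` rather than a Python `if` keeps the function traceable, so `jax.jit` compiles it once for all inputs.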
322-336: Efficient JAX-based forward propagation. The convolution implementation properly uses:
- `lax.conv_general_dilated` for efficient convolution
- Correct dimension numbers ("NCHW", "OIHW", "NCHW")
- Proper bias broadcasting with appropriate array indexing
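A minimal, self-contained example of `lax.conv_general_dilated` with these dimension numbers (shapes chosen purely for illustration):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 2, 5, 5))  # NCHW input: 1 image, 2 channels, 5x5
w = jnp.ones((3, 2, 3, 3))  # OIHW kernel: 3 out channels, 2 in, 3x3 window

y = lax.conv_general_dilated(
    x,
    w,
    window_strides=(1, 1),
    padding="VALID",
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
)
# VALID padding shrinks 5x5 to 3x3; with all-ones input and kernel,
# each output element sums 2 * 3 * 3 = 18 ones.
```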
404-465: Properly updated utility methods for JAX compatibility. The remaining methods have been correctly migrated:
- `empty_gradient` uses `jnp.zeros_like` for JAX compatibility
- Parameter access and update logic maintained
- Serialization methods properly handle JAX arrays
modiase
left a comment
Reviewed 1 of 8 files at r2, 14 of 52 files at r3, 1 of 3 files at r4, 2 of 7 files at r6, 31 of 42 files at r8, 1 of 2 files at r9, 11 of 12 files at r13, 12 of 13 files at r14, 1 of 1 files at r15, 4 of 4 files at r16, 6 of 6 files at r17, all commit messages.
Reviewable status: all files reviewed, 1 unresolved discussion (waiting on @modiase)
```python
if len(self._pending_entries) >= self._max_queue_size:
    logger.warning("Log queue is full. Dropping log entry.")
    return
```
We want non-blocking here.
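A non-blocking variant could lean on the stdlib queue, whose `put_nowait` raises `queue.Full` immediately instead of waiting (a sketch only; the names here are hypothetical, not the module's actual API):

```python
import logging
import queue

logger = logging.getLogger(__name__)

# Small capacity for demonstration; a real logger would size this generously.
_log_queue: "queue.Queue[str]" = queue.Queue(maxsize=2)

def enqueue_log(entry: str) -> bool:
    # put_nowait never blocks the caller: it raises queue.Full at capacity,
    # and we drop the entry with a warning rather than stall the hot path.
    try:
        _log_queue.put_nowait(entry)
        return True
    except queue.Full:
        logger.warning("Log queue is full. Dropping log entry.")
        return False
```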
Summary by CodeRabbit
New Features
Enhancements
Bug Fixes
Documentation
Tests
Chores