Skip to content

Commit f4bc868

Browse files
author
Flax Authors
committed
Merge pull request #4897 from google:pytree-guide
PiperOrigin-RevId: 803235639
2 parents bfa5d8e + 1f10b9d commit f4bc868

37 files changed

+1795
-362
lines changed

.github/workflows/flax_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ jobs:
8282
8383
tests:
8484
name: Run Tests
85-
needs: [pre-commit, commit-count, test-import]
85+
needs: [pre-commit, commit-count]
8686
runs-on: ubuntu-24.04-16core
8787
strategy:
8888
# Make sure to change `github_check_runs` in `copy.bara.sky` if you change the tests here.

docs_nnx/api_reference/flax.nnx/helpers.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,9 @@ helpers
77

88
.. autoclass:: Sequential
99
:members:
10+
.. autoclass:: List
11+
:members:
12+
.. autoclass:: Dict
13+
:members:
1014
.. autoclass:: TrainState
1115
:members:

docs_nnx/api_reference/flax.nnx/object.rst

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@ object
44
.. automodule:: flax.nnx
55
.. currentmodule:: flax.nnx
66

7+
.. autoclass:: Pytree
8+
:members:
79
.. autoclass:: Object
810
:members:
911
.. autofunction:: data
1012
.. autodata:: Data
1113
:annotation:
12-
.. autofunction:: is_data_type
13-
.. autofunction:: register_data_type
14+
.. autofunction:: static
15+
.. autodata:: Static
16+
:annotation:
17+
.. autofunction:: is_data
18+
.. autofunction:: register_data_type
19+
.. autofunction:: check_pytree

docs_nnx/api_reference/flax.nnx/rnglib.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ rnglib
99
.. autoclass:: RngStream
1010
:members:
1111
.. autofunction:: split_rngs
12+
.. autofunction:: fork_rngs
1213
.. autofunction:: reseed

docs_nnx/guides/array_ref.ipynb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@
150150
"print(f\"{variable.has_ref = }\")"
151151
]
152152
},
153+
{
154+
"cell_type": "markdown",
155+
"id": "839332be",
156+
"metadata": {},
157+
"source": [
158+
"Mention `nnx.use_refs` can be used as global flag"
159+
]
160+
},
153161
{
154162
"cell_type": "markdown",
155163
"id": "1b2632f1",

docs_nnx/guides/array_ref.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ with nnx.use_refs(True):
5858
print(f"{variable.has_ref = }")
5959
```
6060

61+
Mention `nnx.use_refs` can be used as global flag
62+
63+
+++
64+
6165
### Changing Status
6266

6367
```{code-cell} ipython3

0 commit comments

Comments
 (0)