
Conversation

@cgarciae cgarciae commented Sep 16, 2024

What does this PR do?

Adds the Transforms guide.

Preview

Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks.

Base automatically changed from nnx-transforms-guide to main September 16, 2024 20:32
@cgarciae cgarciae force-pushed the nnx-real-transforms-guide branch 7 times, most recently from f0d5f35 to f46c9ea Compare September 19, 2024 08:55

@IvyZX IvyZX left a comment


Thanks for making this guide! Super helpful and cool. Just a few nits on wordings.

Collaborator

JAX models inputs to transformations as trees, Flax NNX models inputs as graphs to allow for sharing references.

A bit hard to read - maybe:
JAX transformations see inputs as trees of arrays, and Flax NNX sees inputs as graphs of Python references.

However, to express most of Python's object model Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local (updates to globals inside transforms are not supported).

This line is also a bit verbose. Maybe just:
Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the input graph (updates to globals inside transforms are not supported).

Collaborator

Not here, but I was hoping to see an example of transforming and using an nnx.Module method to showcase that it works and can be a natural pattern for users to take, since most transforms happen not at top level but in-between two layer definitions.

Collaborator Author

It's a good point. I'll add a variation of the first example using vmap over __call__ so users know that it's possible.

Collaborator

nit:

For example, jit expects the structure of the inputs to be stable in order to cache the compiled function, changing the graph structure inside a nnx.jit-ed function cause continuous recompilations and performance degradation, scan on the other hand only allows a fixed carry structure, so adding/removing substates declared as carry will cause an error.

For example, jit expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside a nnx.jit-ed function causes continuous recompilations and performance degradation. On the other hand, scan only allows a fixed carry structure, so adding/removing substates declared as carry will cause an error.

Collaborator

It's probably better to only call vmap once when only one call is needed, to avoid confusion. Same for the example below.

state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
@nnx.vmap(in_axes=(state_axes, 0), out_axes=1)
def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias

y = stateful_vector_dot(weights, x)
y = stateful_vector_dot(weights, x)

Collaborator

In practice Modules usually keep that need random state simply need a references to a Rngs object that is passed to them during initialization, and use it to generate a unique key for each random operation.

What about:
In practice Modules simply need to keep a reference to a Rngs object that is passed to them during initialization, and use it to generate a unique key for each random operation.

@cgarciae cgarciae force-pushed the nnx-real-transforms-guide branch from f46c9ea to 59acf38 Compare September 20, 2024 21:47
@cgarciae cgarciae force-pushed the nnx-real-transforms-guide branch from 59acf38 to fb1a9cc Compare September 22, 2024 05:54
@cgarciae
Collaborator Author

Thanks @IvyZX for the detailed feedback. I've integrated all the suggestions.

@copybara-service copybara-service bot merged commit b2277ab into main Sep 23, 2024
19 checks passed
@copybara-service copybara-service bot deleted the nnx-real-transforms-guide branch September 23, 2024 16:35