Skip to content

[QwixNnxDataType] Register Qwix metadata and array structures as NNX data types.#218

Merged
copybara-service[bot] merged 1 commit intomainfrom
test_872561876
Feb 20, 2026
Merged

[QwixNnxDataType] Register Qwix metadata and array structures as NNX data types.#218
copybara-service[bot] merged 1 commit intomainfrom
test_872561876

Conversation

@copybara-service
Copy link
Copy Markdown

@copybara-service copybara-service Bot commented Feb 19, 2026

[QwixNnxDataType] Register Qwix metadata and array structures as NNX data types.

This change is required to maintain compatibility with the stricter trace-context
and graph validation logic introduced in Flax NNX.

Following that update, nnx.graph.flatten and other graph utilities now strictly
enforce the boundary between dynamic data and static metadata. Any attribute that is not an nnx.Variable, nnx.Module, or an explicitly registered data type
defaults to static. If such an attribute contains JAX arrays or tracers, NNX
raises a ValueError.

We register the following types as dynamic data nodes to ensure they are
correctly handled during graph traversal (nnx.split, nnx.state, nnx.update):

  • QArray (in qarray.py)
  • PaddedQArray (in padded_ptq.py)
  • WithAux (in ptq.py)
  • WithAwqScale (in awq.py)

@copybara-service copybara-service Bot changed the title [QwixNNX] Register Qwix metadata and array structures as NNX data types. [QwixNnxDataType] Register Qwix metadata and array structures as NNX data types. Feb 20, 2026
…data types.

This change is required to maintain compatibility with the stricter trace-context
and graph validation logic introduced in Flax NNX.

Following that update, `nnx.graph.flatten` and other graph utilities now strictly
enforce the boundary between dynamic `data` and `static` metadata. Any attribute that is not an `nnx.Variable`, `nnx.Module`, or an explicitly registered data type
defaults to `static`. If such an attribute contains JAX arrays or tracers, NNX
raises a ValueError.

We register the following types as dynamic data nodes to ensure they are
correctly handled during graph traversal (`nnx.split`, `nnx.state`, `nnx.update`):
- QArray (in qarray.py)
- PaddedQArray (in padded_ptq.py)
- WithAux (in ptq.py)
- WithAwqScale (in awq.py)

PiperOrigin-RevId: 872659797
@copybara-service copybara-service Bot merged commit e292088 into main Feb 20, 2026
@copybara-service copybara-service Bot deleted the test_872561876 branch February 20, 2026 02:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant