Skip to content

[QwixStructuralInterception] Enable structural binding for PrimitiveBindOp to support transpose op removal and allow for robust ODML metadata propagation.#211

Merged
copybara-service[bot] merged 1 commit intomainfrom
test_867809631
Feb 20, 2026

Conversation

@copybara-service
Copy link
Copy Markdown

@copybara-service copybara-service Bot commented Feb 10, 2026

[QwixStructuralInterception] Enable structural binding for PrimitiveBindOp to support transpose op removal and allow for robust ODML metadata propagation.

This CL refactors the interception logic to support multiple interceptors (e.g., structural + numerical) and disables JIT for primitive binding. This enables catching all transpose operations at the primitive level, eliminating the maintenance burden of manually mapping every transpose op variant.

Key changes:

  • interception.py:
    Added support for interceptor stacking.
    Replaced hardcoded JIT disabling with a generic PRIMITIVE_BIND_KEY check.
  • model.py:
    Unified Linen/NNX interception in _apply_interceptors, applying structural then numerical interceptors.
  • odml.py & odml_ops.py:
    Configured ODML providers to return multiple interceptors, including PrimitiveBindOp.

@copybara-service copybara-service Bot force-pushed the test_867809631 branch 4 times, most recently from cc4b174 to b3c4eb8 Compare February 13, 2026 06:06
@copybara-service copybara-service Bot changed the title [QwixOdmlBind] Implement robust metadata propagation for Qwix ODML by intercepting jax.core.Primitive.bind. [QwixStructuralInterception] Enable structural binding for PrimitiveBindOp to support transpose op removal and allow for robust ODML metadata propagation. Feb 13, 2026
@copybara-service copybara-service Bot force-pushed the test_867809631 branch 3 times, most recently from 4100c6a to 899157f Compare February 20, 2026 06:32
…indOp to support transpose op removal and allow for robust ODML metadata propagation.

This CL refactors the interception logic to support multiple interceptors (e.g., structural + numerical) and disables JIT for primitive binding. This enables catching all transpose operations at the primitive level, eliminating the maintenance burden of manually mapping every transpose op variant.

Key changes:
- interception.py:
Added support for interceptor stacking.
Replaced hardcoded JIT disabling with a generic PRIMITIVE_BIND_KEY check.
- model.py:
Unified Linen/NNX interception in _apply_interceptors, applying structural then numerical interceptors.
- odml.py & odml_ops.py:
Configured ODML providers to return multiple interceptors, including PrimitiveBindOp.

PiperOrigin-RevId: 872728854
@copybara-service copybara-service Bot merged commit 6b0fa93 into main Feb 20, 2026
@copybara-service copybara-service Bot deleted the test_867809631 branch February 20, 2026 06:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant