[QwixInterception] Selective disable_jit for XLA and ODML providers. #230
Merged
copybara-service[bot] merged 1 commit into main on Apr 6, 2026
This CL refines the interception logic to use disable_jit only when it is required, distinguishing between high-level function interception (XLA) and primitive-level interception (ODML).

XLA Quantization (QtProvider, PtqProvider):
- Does not use disable_jit by default.
- Intercepts high-level Functions and PjitFunctions (e.g., jnp.dot).
- Qwix uses code-object patching to hook these functions during the initial JIT trace.
- PjitFunctions use simple attribute patching because the code object is bypassed in jitted environments.
- Removing disable_jit for large XLA models significantly reduces compile times.

ODML Quantization (ODML Providers):
- Uses disable_jit by default.
- Intercepts high-level Functions, PjitFunctions (e.g., jnp.dot), and low-level JAX primitives via Primitive.bind for fine-grained metadata propagation.
- **Visibility**: When JIT is enabled, JAX dispatches primitives through a C++ path that bypasses Python-level bind overrides. disable_jit=True is required to force the eager path, where these primitive calls are visible to the Qwix interceptor.
- **Recursion Avoidance**: For PjitFunctions like jnp.dot, the interceptor calls the globally patched version. disable_jit ensures that the recursion-guarded 'original' call runs eagerly, preventing nested JIT traces that might lead to inconsistent metadata states.
- **Name Stability**: Forcing eager mode keeps op ID generation consistent (e.g., 'dot') between QAT and conversion. It avoids the shift to primitive names (e.g., 'dot_general') that would occur if the high-level Python wrappers were bypassed during JIT tracing.

PiperOrigin-RevId: 895482861
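The description names two patching mechanisms, code-object patching and attribute patching. The sketch below contrasts them in plain Python; it is not Qwix's implementation. Every name in it (`record`, `_trampoline`, `patch_code_object`, `patch_attribute`, `ops`) is hypothetical, and `dot` and `add` are ordinary Python stand-ins rather than JAX functions.

```python
import types

seen = []  # names of intercepted functions, in call order


def record(original, *args, **kwargs):
    """Hypothetical hook: log the call, then defer to the original function."""
    seen.append(original.__name__)
    return original(*args, **kwargs)


def _trampoline(*args, _hook=None, _original=None, **kwargs):
    # No closure variables, so this code object can replace that of any
    # closure-free function. The hook and original arrive via __kwdefaults__.
    return _hook(_original, *args, **kwargs)


def patch_code_object(fn, hook):
    """Swap fn.__code__ in place so every existing reference to fn is hooked."""
    original = types.FunctionType(
        fn.__code__, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
    )
    original.__kwdefaults__ = fn.__kwdefaults__
    fn.__code__ = _trampoline.__code__
    fn.__defaults__ = None
    fn.__kwdefaults__ = {"_hook": hook, "_original": original}
    return original


def patch_attribute(namespace, name, hook):
    """Rebind namespace.<name>; references captured earlier are not affected."""
    original = getattr(namespace, name)

    def wrapper(*args, **kwargs):
        return hook(original, *args, **kwargs)

    setattr(namespace, name, wrapper)
    return original


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def add(a, b):
    return a + b


ops = types.SimpleNamespace(dot=dot, add=add)  # stand-in for a module
early_dot, early_add = ops.dot, ops.add        # references taken before patching

patch_code_object(dot, record)
patch_attribute(ops, "add", record)

early_dot([1, 2], [3, 4])  # hooked: the function object itself was changed
early_add(1, 2)            # not hooked: bypasses the rebound attribute
ops.add(1, 2)              # hooked: looked up through the namespace
```

In this sketch, code-object patching reaches callers that already hold a reference to the function, while attribute patching only affects lookups made through the patched namespace. Code-object patching needs a Python function whose code object actually runs, which matches the description's note that PjitFunctions fall back to attribute patching.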
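The Visibility point relies on a general JAX behavior: with JIT enabled, Python-level code runs only while a function is being traced, whereas under `jax.disable_jit()` it runs on every call. The minimal sketch below shows that behavior with a plain Python side effect. It does not patch `Primitive.bind` or use any Qwix code, and `f`, `python_calls`, and the two counters are hypothetical names.

```python
import jax
import jax.numpy as jnp

python_calls = []  # records each time the Python body of `f` actually runs


@jax.jit
def f(x):
    # Python side effect: only happens when the body is executed in Python.
    python_calls.append("f")
    return jnp.dot(x, x)


x = jnp.ones(3)

# JIT enabled: the first call traces and compiles, the second reuses the
# compiled executable, so the Python body runs once.
f(x)
f(x)
calls_with_jit = len(python_calls)

# JIT disabled: each call executes the Python body eagerly.
with jax.disable_jit():
    f(x)
    f(x)
calls_without_jit = len(python_calls) - calls_with_jit

print(calls_with_jit, calls_without_jit)
```

An interceptor that depends on Python-level hooks firing for each operation is in the same position as the `python_calls.append` line here: it is guaranteed to run on every call only on the eager path.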