Skip to content

[QwixInterception] Selective disable_jit for XLA and ODML providers.#230

Merged
copybara-service[bot] merged 1 commit intomainfrom
test_873910133
Apr 6, 2026
Merged

[QwixInterception] Selective disable_jit for XLA and ODML providers.#230
copybara-service[bot] merged 1 commit intomainfrom
test_873910133

Conversation

@copybara-service
Copy link
Copy Markdown

@copybara-service copybara-service Bot commented Mar 24, 2026

[QwixInterception] Selective disable_jit for XLA and ODML providers.

This CL refines the interception logic to only use disable_jit when absolutely necessary, distinguishing between high-level function interception (XLA) and primitive-level interception (ODML).

XLA Quantization (QtProvider, PtqProvider):

  • Doesn't disable_jit by default.
  • Intercepts high-level Functions and PjitFunctions (e.g., jnp.dot).
  • Qwix uses code-object patching to hook these functions during the initial JIT trace.
  • PjitFuncitons will use simple attribute patching because code object is bypassed in jitted environments.
  • Removing disable_jit for XLA models(large) significantly reduces compile times.

ODML Quantization (ODML Providers):

  • Disable_jit by default.
  • Intercepts high-level Functions, PjitFunctions (e.g., jnp.dot), and low-level JAX primitives via Primitive.bind for fine-grained metadata propagation.
  • Visibility: When JIT is enabled, JAX dispatches primitives through a C++ path that bypasses Python-level bind overrides. disable_jit=True is required to force the Eager path where these primitive calls are visible to the Qwix interceptor.
  • Recursion Avoidance: For PjitFunctions like jnp.dot, the interceptor calls the globally patched version. disable_jit ensures the recursion-guarded 'original' call runs eagerly, preventing nested JIT traces that might lead to inconsistent metadata states.
  • Name Stability: Forcing Eager mode ensures consistent op ID generation (e.g., 'dot') between QAT and conversion, avoiding the shift to primitive names (e.g., 'dot_general') that would occur if the high-level Python wrappers were bypassed during JIT tracing.

@copybara-service copybara-service Bot force-pushed the test_873910133 branch 2 times, most recently from a565991 to 7524a00 Compare April 3, 2026 00:03
@copybara-service copybara-service Bot changed the title [QwixInterception] Selective disable_jit for XLA and ODML providers. [QwixInterception] Selective disable_jit for XLA and ODML providers. Apr 3, 2026
@copybara-service copybara-service Bot force-pushed the test_873910133 branch 3 times, most recently from 1db619e to 2fe6e2e Compare April 6, 2026 20:41
This CL refines the interception logic to only use disable_jit when absolutely necessary, distinguishing between high-level function interception (XLA) and primitive-level interception (ODML).

XLA Quantization (QtProvider, PtqProvider):
- Doesn't disable_jit by default.
- Intercepts high-level Functions and PjitFunctions (e.g., jnp.dot).
- Qwix uses code-object patching to hook these functions during the initial JIT trace.
- PjitFuncitons will use simple attribute patching because code object is bypassed in jitted environments.
- Removing disable_jit for XLA models(large) significantly reduces compile times.

ODML Quantization (ODML Providers):
- Disable_jit by default.
- Intercepts high-level Functions, PjitFunctions (e.g., jnp.dot), and low-level JAX primitives via Primitive.bind for fine-grained metadata propagation.
- **Visibility**: When JIT is enabled, JAX dispatches primitives through a C++ path that bypasses Python-level bind overrides. disable_jit=True is required to force the Eager path where these primitive calls are visible to the Qwix interceptor.
- **Recursion Avoidance**: For PjitFunctions like jnp.dot, the interceptor calls the globally patched version. disable_jit ensures the recursion-guarded 'original' call runs eagerly, preventing nested JIT traces that might lead to inconsistent metadata states.
- **Name Stability**: Forcing Eager mode ensures consistent op ID generation (e.g., 'dot') between QAT and conversion, avoiding the shift to primitive names (e.g., 'dot_general') that would occur if the high-level Python wrappers were bypassed during JIT tracing.

PiperOrigin-RevId: 895482861
@copybara-service copybara-service Bot merged commit 88ace68 into main Apr 6, 2026
@copybara-service copybara-service Bot deleted the test_873910133 branch April 6, 2026 20:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant