[QwixGptqEinsum] Add einsum to GptqCalibrationProvider and support attention layer weight GPTQ.#201

Merged
copybara-service[bot] merged 1 commit into main from test_861481456 on Jan 29, 2026

Conversation

@copybara-service

[QwixGptqEinsum] Add einsum to GptqCalibrationProvider and support attention layer weight GPTQ.

The new einsum method intercepts `jax.numpy.einsum` ops, which are typically used in the attention layer (whereas `jax.lax.dot_general` ops are typically used in the dense layer). This allows Qwix users to apply GPTQ to attention layer weights in addition to dense layer weights.

When GPTQ is enabled for the current operation and there are exactly two operands, the method injects the GPTQ interceptor's `dot_general` method as the `_dot_general` argument to `jax.numpy.einsum`. This ensures that the underlying dot products within the einsum are processed by the GPTQ-aware `dot_general`.

This change improves the robustness of GptqCalibrationProvider by returning the original unquantized result instead of raising errors when encountering unsupported dimension configurations or operations without identifiable model parameters. This allows calibration to proceed successfully on models with a mix of supported and unsupported layers.
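The fallback behavior can be sketched like this. Names such as `gptq_aware_dot_general` and `weight_name` are hypothetical; the point is the pattern of returning the unquantized result rather than raising when an op is unsupported.

```python
import jax.numpy as jnp
from jax import lax

def gptq_aware_dot_general(lhs, rhs, dimension_numbers, *,
                           weight_name=None, **kwargs):
    # Hypothetical sketch of the graceful-fallback pattern: compute the
    # original unquantized result first, then only attempt GPTQ when the
    # weight parameter is identifiable and the dims are supported.
    original = lax.dot_general(lhs, rhs, dimension_numbers, **kwargs)
    (lhs_contract, _), (lhs_batch, _) = dimension_numbers
    supported = (weight_name is not None
                 and len(lhs_contract) == 1
                 and not lhs_batch)
    if not supported:
        return original  # unsupported layer: calibration proceeds anyway
    # ... GPTQ calibration on the identified weight would happen here ...
    return original
```

With this shape, a model mixing supported and unsupported layers calibrates end to end, with GPTQ applied only where the preconditions hold.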

PiperOrigin-RevId: 862474667
copybara-service[bot] merged commit da6d5c9 into main on Jan 29, 2026
copybara-service[bot] deleted the test_861481456 branch on January 29, 2026 02:02