Skip to content

[JAX] Fix reversible embedding#2496

Merged
bkowalskiINTEL merged 5 commits into
mainfrom
dev/bkowalsk/jax_fix_reversible_embedding
Jun 29, 2026
Merged

[JAX] Fix reversible embedding#2496
bkowalskiINTEL merged 5 commits into
mainfrom
dev/bkowalsk/jax_fix_reversible_embedding

Conversation

@bkowalskiINTEL

Copy link
Copy Markdown
Contributor

This PR adds missing registrations for ReversibleEmbedding, which caused this layer not to be quantized.
This also revealed an issue with weights loading for QStaticReversibleEmbedding, which this PR also fixes.

Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses JAX/Keras quantization gaps for ReversibleEmbedding by ensuring the layer is correctly discovered/registered for quantization across Keras/KerasHub namespaces, and adjusts the quantized reversible embedding implementation to align with the updated base-class references (including load/call path behavior).

Changes:

  • Register ReversibleEmbedding quantized wrappers for both keras.layers.* and keras_hub.layers.* entry points (static and dynamic).
  • Standardize layer references to keras.layers.* / keras_hub.layers.* module-qualified names instead of direct imports.
  • Update super(...) calls in quantized reversible embedding to match the new base-class qualification.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
neural_compressor/jax/quantization/layers_static.py Adds dual registration for static ReversibleEmbedding and updates base-class references / call path.
neural_compressor/jax/quantization/layers_dynamic.py Adds dual registration for dynamic ReversibleEmbedding and updates base-class references / call path.

Comment thread neural_compressor/jax/quantization/layers_static.py
Comment thread neural_compressor/jax/quantization/layers_static.py
@chensuyue chensuyue added this to the 3.9 milestone Jun 17, 2026
@anko-intel anko-intel removed this from the 3.9 milestone Jun 18, 2026
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>

@anko-intel anko-intel left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@bkowalskiINTEL bkowalskiINTEL merged commit debb577 into main Jun 29, 2026
14 checks passed
@bkowalskiINTEL bkowalskiINTEL deleted the dev/bkowalsk/jax_fix_reversible_embedding branch June 29, 2026 11:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants