Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qwix/_src/providers/odml.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def process_model_inputs(

def process_model_output(self, method_name: str, model_output: Any) -> Any:
"""Quantize the output of the model."""
self._initial_run_complete = True
if method_name == '__call__':
method_name = 'final' # backwards compatibility.
# Quantize the model output if needed.
Expand Down
27 changes: 27 additions & 0 deletions qwix/_src/qconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ def __init__(
rules: The quantization rules in the order of precedence.
disable_jit: Whether to disable JIT when wrapping methods.
"""
self._rule_matches = [0] * len(rules)
self._rules = [self._init_rule(rule) for rule in rules]
self._logged_ops = set()
self._initial_run_complete = False
self.disable_jit = disable_jit

def _init_rule(self, rule: QuantizationRule) -> QuantizationRule:
Expand Down Expand Up @@ -176,6 +178,7 @@ def process_model_inputs(
def process_model_output(self, method_name: str, model_output: Any) -> Any:
  """Hook invoked on the model output right before it is returned.

  This implementation passes the output through untouched and sets
  `_initial_run_complete` to True.

  Args:
    method_name: Name of the model method that produced the output. Unused.
    model_output: The value produced by the model.

  Returns:
    `model_output`, unchanged.
  """
  self._initial_run_complete = True
  del method_name  # Unused by this implementation.
  return model_output

def _get_current_rule_and_op_id(
Expand Down Expand Up @@ -208,6 +211,8 @@ def _get_current_rule_and_op_id(
rule_idx = idx
break
rule = self._rules[rule_idx] if rule_idx is not None else None
if rule_idx is not None:
self._rule_matches[rule_idx] += 1
if only_rule:
return rule, None

Expand All @@ -228,3 +233,25 @@ def _get_current_rule_and_op_id(
'[QWIX] module=%r op=%s rule=%s', module_path, op_id, rule_idx
)
return rule, op_id

def get_unused_rules(self) -> Sequence[QuantizationRule]:
  """Lists the quantization rules that never matched an operation.

  Intended to be called once model quantization (e.g., `quantize_model`) has
  run, as a check that every configured rule actually took effect. A rule
  counts as unused when its `module_path` regex matched no module path, or
  when its `op_names` matched no intercepted operation inside a matching
  module.

  Returns:
    The rules whose match counter is still zero, in rule order.

  Raises:
    ValueError: If `_initial_run_complete` has not been set yet.
  """
  if self._initial_run_complete:
    # `_rules` and `_rule_matches` are parallel lists, one entry per rule.
    return [
        rule
        for rule, hit_count in zip(self._rules, self._rule_matches)
        if hit_count == 0
    ]
  raise ValueError(
      'Quantization is not completed yet. Please call `quantize_model`'
      ' before calling `get_unused_rules`.'
  )
1 change: 1 addition & 0 deletions tests/_src/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_intercept_map(self) -> Mapping[str, Callable[..., Any]]:
return self._intercept_map

def process_model_output(self, method_name: str, model_output: Any) -> Any:
  """Sets `_initial_run_complete` and returns the output shifted by 100.

  Args:
    method_name: Name of the model method that produced the output. Unused.
    model_output: The value produced by the model.

  Returns:
    `model_output + 100`.
  """
  self._initial_run_complete = True
  shifted_output = model_output + 100
  return shifted_output


Expand Down
117 changes: 117 additions & 0 deletions tests/_src/qconfig_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from flax import nnx
from jax import numpy as jnp
from qwix._src import model as qwix_model
from qwix._src import qconfig
from qwix._src.core import qarray
from qwix._src.providers import ptq


class QconfigTest(absltest.TestCase):
  """Tests for unused-rule reporting via `get_unused_rules`."""

  def setUp(self):
    super().setUp()
    dim: int = 16

    class MyModel(nnx.Module):
      """Model with `lin1`, `lin2` and two linear layers in `layers`."""

      def __init__(self, rngs: nnx.Rngs):
        self.lin1 = nnx.Linear(dim, dim, rngs=rngs)
        self.lin2 = nnx.Linear(dim, dim, rngs=rngs)
        self.layers = nnx.List(
            [nnx.Linear(dim, dim, rngs=rngs) for _ in range(2)]
        )

      def __call__(self, x):
        # Every layer is applied to the same input, so each one is invoked
        # during a single forward pass.
        return (
            self.lin1(x)
            + self.lin2(x)
            + sum(layer(x) for layer in self.layers)
        )

    self.model = MyModel(rngs=nnx.Rngs(0))
    self.x = jnp.ones((1, dim))

  @staticmethod
  def _fp8_rule(**overrides) -> qconfig.QuantizationRule:
    """Builds the float8 rule shared by the tests.

    Args:
      **overrides: Extra keyword arguments forwarded to `QuantizationRule`,
        e.g. `module_path`.

    Returns:
      A rule with float8 weights and activations and `act_static_scale=False`.
    """
    return qconfig.QuantizationRule(
        weight_qtype="float8_e4m3fn",
        act_qtype="float8_e4m3fn",
        act_static_scale=False,
        **overrides,
    )

  def test_all_rules_used(self):
    rules = [self._fp8_rule()]
    provider = ptq.PtqProvider(rules)
    quant_model = qwix_model.quantize_model(self.model, provider, self.x)

    # Check unused rules.
    self.assertEmpty(provider.get_unused_rules())

    # Check that all layers are quantized.
    self.assertIsInstance(quant_model.lin1.kernel.array, qarray.QArray)
    self.assertIsInstance(quant_model.lin2.kernel.array, qarray.QArray)
    self.assertIsInstance(quant_model.layers[0].kernel.array, qarray.QArray)
    self.assertIsInstance(quant_model.layers[1].kernel.array, qarray.QArray)

  def test_some_rules_unused(self):
    rules = [
        self._fp8_rule(module_path=r"layers/\d+"),
        self._fp8_rule(module_path=r"LIN\d+"),  # Typo in module path.
    ]
    provider = ptq.PtqProvider(rules)
    quant_model = qwix_model.quantize_model(self.model, provider, self.x)
    unused_rules = provider.get_unused_rules()

    # Check unused rules.
    self.assertLen(unused_rules, 1)
    self.assertEqual(unused_rules[0].module_path, rules[1].module_path)

    # Check that lin1 and lin2 are not quantized.
    self.assertFalse(hasattr(quant_model.lin1.kernel, "array"))
    self.assertFalse(hasattr(quant_model.lin2.kernel, "array"))

    # Check that layers are quantized.
    self.assertIsInstance(quant_model.layers[0].kernel.array, qarray.QArray)
    self.assertIsInstance(quant_model.layers[1].kernel.array, qarray.QArray)

  def test_get_unused_rules_before_quantize_model(self):
    rules = [self._fp8_rule(module_path=r"layers/\d+")]
    provider = ptq.PtqProvider(rules)
    # The expected message is a regex, so the literal dots are escaped;
    # unescaped they would match any character.
    with self.assertRaisesRegex(
        ValueError,
        r"Quantization is not completed yet\. Please call `quantize_model`"
        r" before calling `get_unused_rules`\.",
    ):
      provider.get_unused_rules()

    qwix_model.quantize_model(self.model, provider, self.x)
    self.assertEmpty(provider.get_unused_rules())


# Run the tests through absltest when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()