diff --git a/tests/custom_op/test_floatquant.py b/tests/custom_op/test_floatquant.py index 633d494e..c0f89cde 100644 --- a/tests/custom_op/test_floatquant.py +++ b/tests/custom_op/test_floatquant.py @@ -27,8 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import pytest - import io import mock import numpy as np @@ -112,7 +110,7 @@ def brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, e max_available_float=max_val, saturating=True, ) - float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.0) + float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: torch.Tensor([1.0])) float_quant = BrevitasFloatQuant( bit_width=bit_width, float_scaling_impl=float_scaling_impl, @@ -127,33 +125,50 @@ def brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, e return expected_out -@pytest.mark.xfail(reason="Possible Brevitas version issue, needs investigation") -@given( - x=arrays( - dtype=np.float64, - shape=100, - elements=st.floats( - allow_nan=False, - allow_infinity=False, - allow_subnormal=True, - width=64, # Use 64-bit floats - ), - unique=True, - ), - exponent_bit_width=st.integers(1, 8), - mantissa_bit_width=st.integers(1, 8), - sign=st.booleans(), -) +# @pytest.mark.xfail(reason="Possible Brevitas version issue, needs investigation") +@st.composite +def inputs(draw): + # pick the torch dtype first + float_type = draw(st.sampled_from([np.float32, np.float64])) + + # build x with a matching numpy dtype + float width + x = draw( + arrays( + dtype=float_type, + shape=100, + elements=st.floats( + allow_nan=False, + allow_infinity=False, + allow_subnormal=True, + width=np.dtype(float_type).itemsize * 8, + ), + unique=True, + ) + ) + + exponent_bit_width = draw(st.integers(1, 8)) + mantissa_bit_width = draw(st.integers(1, 8)) + sign = draw(st.booleans()) + + return x, exponent_bit_width, mantissa_bit_width, sign, float_type + + +@given(data=inputs()) @settings( - max_examples=1000, verbosity=Verbosity.verbose, suppress_health_check=list(HealthCheck) + max_examples=1000, + verbosity=Verbosity.verbose, + suppress_health_check=list(HealthCheck), ) # Adjust the number of examples as needed -def test_brevitas_vs_qonnx(x, exponent_bit_width, mantissa_bit_width, sign): +def test_brevitas_vs_qonnx(data): + x, exponent_bit_width, mantissa_bit_width, sign, _ = data + x = torch.tensor(x) bit_width = exponent_bit_width + mantissa_bit_width + int(sign) assume(bit_width <= 8 and bit_width >= 4) scale = 1.0 exponent_bias = compute_default_exponent_bias(exponent_bit_width) max_val = compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias) - xq_t = brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val).numpy() - xq = qonnx_float_quant(x, scale, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val) + xq_t = brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, + exponent_bias, sign, max_val).numpy() + xq = qonnx_float_quant(x.numpy(), scale, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val) np.testing.assert_array_equal(xq, xq_t)