From e71a66a6b96ad908b07b2924d672558e43e6af76 Mon Sep 17 00:00:00 2001 From: swayaminsync Date: Fri, 17 Oct 2025 14:38:01 +0530 Subject: [PATCH] heaviside impl --- quaddtype/numpy_quaddtype/src/ops.hpp | 40 +++++++ .../numpy_quaddtype/src/umath/binary_ops.cpp | 3 + quaddtype/release_tracker.md | 2 +- quaddtype/tests/test_quaddtype.py | 102 ++++++++++++++++++ 4 files changed, 146 insertions(+), 1 deletion(-) diff --git a/quaddtype/numpy_quaddtype/src/ops.hpp b/quaddtype/numpy_quaddtype/src/ops.hpp index dd385d92..efa33f9b 100644 --- a/quaddtype/numpy_quaddtype/src/ops.hpp +++ b/quaddtype/numpy_quaddtype/src/ops.hpp @@ -752,6 +752,26 @@ quad_logaddexp2(const Sleef_quad *x, const Sleef_quad *y) return Sleef_addq1_u05(max_val, log2_term); } +static inline Sleef_quad +quad_heaviside(const Sleef_quad *x1, const Sleef_quad *x2) +{ + // heaviside(x1, x2) = 0 if x1 < 0, x2 if x1 == 0, 1 if x1 > 0 + // NaN propagation: only propagate NaN from x1, not from x2 (unless x1 == 0) + if (Sleef_iunordq1(*x1, *x1)) { + return *x1; // x1 is NaN, return NaN + } + + if (Sleef_icmpltq1(*x1, QUAD_ZERO)) { + return QUAD_ZERO; + } + else if (Sleef_icmpeqq1(*x1, QUAD_ZERO)) { + return *x2; // When x1 == 0, return x2 (even if x2 is NaN) + } + else { + return QUAD_ONE; + } +} + // Binary long double operations typedef long double (*binary_op_longdouble_def)(const long double *, const long double *); // Binary long double operations with 2 outputs (for divmod, modf, frexp) @@ -1002,6 +1022,26 @@ ld_logaddexp2(const long double *x, const long double *y) return max_val + log2l(1.0L + exp2l(-abs_diff)); } +static inline long double +ld_heaviside(const long double *x1, const long double *x2) +{ + // heaviside(x1, x2) = 0 if x1 < 0, x2 if x1 == 0, 1 if x1 > 0 + // NaN propagation: only propagate NaN from x1, not from x2 (unless x1 == 0) + if (isnan(*x1)) { + return *x1; // x1 is NaN, return NaN + } + + if (*x1 < 0.0L) { + return 0.0L; + } + else if (*x1 == 0.0L) { + return *x2; // When x1 == 0, return x2 (even if x2 is NaN) + } + else { + return 1.0L; + } +} + // comparison quad functions typedef npy_bool (*cmp_quad_def)(const Sleef_quad *, const Sleef_quad *); diff --git a/quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp b/quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp index 7775b706..b1417747 100644 --- a/quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp +++ b/quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp @@ -454,6 +454,9 @@ init_quad_binary_ops(PyObject *numpy) if (create_quad_binary_ufunc(numpy, "logaddexp2") < 0) { return -1; } + if (create_quad_binary_ufunc(numpy, "heaviside") < 0) { + return -1; + } if (create_quad_binary_2out_ufunc(numpy, "divmod") < 0) { return -1; } diff --git a/quaddtype/release_tracker.md b/quaddtype/release_tracker.md index cb61ecbe..29e6cebb 100644 --- a/quaddtype/release_tracker.md +++ b/quaddtype/release_tracker.md @@ -26,7 +26,7 @@ | fabs | ✅ | ✅ | | rint | ✅ | ✅ | | sign | ✅ | ✅ | -| heaviside | | | +| heaviside | ✅ | ✅ | | conj | | | | conjugate | | | | exp | ✅ | ✅ | diff --git a/quaddtype/tests/test_quaddtype.py b/quaddtype/tests/test_quaddtype.py index b63daeef..b1fbb0bc 100644 --- a/quaddtype/tests/test_quaddtype.py +++ b/quaddtype/tests/test_quaddtype.py @@ -1591,3 +1591,105 @@ def test_fabs(val): if float_result == 0.0: assert not np.signbit(quad_result), f"fabs({val}) should not have negative sign" assert not np.signbit(quad_arr_result[0]), f"fabs({val}) should not have negative sign" + + +@pytest.mark.parametrize("x1,x2", [ + # Basic cases: x1 < 0 -> 0 + ("-1.0", "0.5"), ("-5.0", "0.5"), ("-100.0", "0.5"), + ("-1e10", "0.5"), ("-0.1", "0.5"), + + # Basic cases: x1 == 0 -> x2 + ("0.0", "0.5"), ("0.0", "0.0"), ("0.0", "1.0"), + ("-0.0", "0.5"), ("-0.0", "0.0"), ("-0.0", "1.0"), + + # Basic cases: x1 > 0 -> 1 + ("1.0", "0.5"), ("5.0", "0.5"), ("100.0", "0.5"), + ("1e10", "0.5"), ("0.1", "0.5"), + + # Edge cases with different x2 values + ("0.0", "-1.0"), ("0.0", "2.0"), ("0.0", "100.0"), + + # Special values: infinity + ("inf", "0.5"), ("-inf", "0.5"), + ("inf", "0.0"), ("-inf", "0.0"), + + # Special values: NaN (should propagate) + ("nan", "0.5"), ("0.5", "nan"), ("nan", "nan"), + ("-nan", "0.5"), ("0.5", "-nan"), + + # Edge case: zero x1 with special x2 + ("0.0", "inf"), ("0.0", "-inf"), ("0.0", "nan"), + ("-0.0", "inf"), ("-0.0", "-inf"), ("-0.0", "nan"), +]) +def test_heaviside(x1, x2): + """ + Test np.heaviside ufunc for QuadPrecision dtype. + + heaviside(x1, x2) = 0 if x1 < 0 + x2 if x1 == 0 + 1 if x1 > 0 + + This is the Heaviside step function where x2 determines the value at x1=0. + """ + quad_x1 = QuadPrecision(x1) + quad_x2 = QuadPrecision(x2) + float_x1 = float(x1) + float_x2 = float(x2) + + # Test scalar inputs + quad_result = np.heaviside(quad_x1, quad_x2) + float_result = np.heaviside(float_x1, float_x2) + + # Test array inputs + quad_arr_x1 = np.array([quad_x1], dtype=QuadPrecDType()) + quad_arr_x2 = np.array([quad_x2], dtype=QuadPrecDType()) + quad_arr_result = np.heaviside(quad_arr_x1, quad_arr_x2) + + # Check results match + np.testing.assert_array_equal( + np.array(quad_result).astype(float), + float_result, + err_msg=f"Scalar heaviside({x1}, {x2}) mismatch" + ) + + np.testing.assert_array_equal( + quad_arr_result.astype(float)[0], + float_result, + err_msg=f"Array heaviside({x1}, {x2}) mismatch" + ) + + # Additional checks for non-NaN results + if not np.isnan(float_result): + # Verify the expected value based on x1 + if float_x1 < 0: + assert float(quad_result) == 0.0, f"Expected 0 for heaviside({x1}, {x2})" + elif float_x1 == 0.0: + np.testing.assert_array_equal( + float(quad_result), float_x2, + err_msg=f"Expected {x2} for heaviside(0, {x2})" + ) + else: # float_x1 > 0 + assert float(quad_result) == 1.0, f"Expected 1 for heaviside({x1}, {x2})" + + +def test_heaviside_broadcast(): + """Test that heaviside works with broadcasting""" + x1 = np.array([-1.0, 0.0, 1.0], dtype=QuadPrecDType()) + x2 = QuadPrecision("0.5") + + result = np.heaviside(x1, x2) + expected = np.array([0.0, 0.5, 1.0], dtype=np.float64) + + assert result.dtype.name == "QuadPrecDType128" + np.testing.assert_array_equal(result.astype(float), expected) + + # Test with array for both arguments + x1_arr = np.array([-2.0, -0.0, 0.0, 5.0], dtype=QuadPrecDType()) + x2_arr = np.array([0.5, 0.5, 1.0, 0.5], dtype=QuadPrecDType()) + + result = np.heaviside(x1_arr, x2_arr) + expected = np.array([0.0, 0.5, 1.0, 1.0], dtype=np.float64) + + assert result.dtype.name == "QuadPrecDType128" + np.testing.assert_array_equal(result.astype(float), expected) +