Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions quaddtype/numpy_quaddtype/src/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,26 @@ quad_logaddexp2(const Sleef_quad *x, const Sleef_quad *y)
return Sleef_addq1_u05(max_val, log2_term);
}

static inline Sleef_quad
quad_heaviside(const Sleef_quad *x1, const Sleef_quad *x2)
{
// heaviside(x1, x2) = 0 if x1 < 0, x2 if x1 == 0, 1 if x1 > 0
// NaN propagation: only propagate NaN from x1, not from x2 (unless x1 == 0)
if (Sleef_iunordq1(*x1, *x1)) {
return *x1; // x1 is NaN, return NaN
}

if (Sleef_icmpltq1(*x1, QUAD_ZERO)) {
return QUAD_ZERO;
}
else if (Sleef_icmpeqq1(*x1, QUAD_ZERO)) {
return *x2; // When x1 == 0, return x2 (even if x2 is NaN)
}
else {
return QUAD_ONE;
}
}

// Binary long double operations
typedef long double (*binary_op_longdouble_def)(const long double *, const long double *);
// Binary long double operations with 2 outputs (for divmod, modf, frexp)
Expand Down Expand Up @@ -1002,6 +1022,26 @@ ld_logaddexp2(const long double *x, const long double *y)
return max_val + log2l(1.0L + exp2l(-abs_diff));
}

static inline long double
ld_heaviside(const long double *x1, const long double *x2)
{
// heaviside(x1, x2) = 0 if x1 < 0, x2 if x1 == 0, 1 if x1 > 0
// NaN propagation: only propagate NaN from x1, not from x2 (unless x1 == 0)
if (isnan(*x1)) {
return *x1; // x1 is NaN, return NaN
}

if (*x1 < 0.0L) {
return 0.0L;
}
else if (*x1 == 0.0L) {
return *x2; // When x1 == 0, return x2 (even if x2 is NaN)
}
else {
return 1.0L;
}
}

// comparison quad functions
typedef npy_bool (*cmp_quad_def)(const Sleef_quad *, const Sleef_quad *);

Expand Down
3 changes: 3 additions & 0 deletions quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,9 @@ init_quad_binary_ops(PyObject *numpy)
if (create_quad_binary_ufunc<quad_logaddexp2, ld_logaddexp2>(numpy, "logaddexp2") < 0) {
return -1;
}
if (create_quad_binary_ufunc<quad_heaviside, ld_heaviside>(numpy, "heaviside") < 0) {
return -1;
}
if (create_quad_binary_2out_ufunc<quad_divmod, ld_divmod>(numpy, "divmod") < 0) {
return -1;
}
Expand Down
2 changes: 1 addition & 1 deletion quaddtype/release_tracker.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
| fabs | ✅ | ✅ |
| rint | ✅ | ✅ |
| sign | ✅ | ✅ |
| heaviside | | |
| heaviside | | ✅ |
| conj | | |
| conjugate | | |
| exp | ✅ | ✅ |
Expand Down
102 changes: 102 additions & 0 deletions quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,3 +1591,105 @@ def test_fabs(val):
if float_result == 0.0:
assert not np.signbit(quad_result), f"fabs({val}) should not have negative sign"
assert not np.signbit(quad_arr_result[0]), f"fabs({val}) should not have negative sign"


@pytest.mark.parametrize("x1,x2", [
# Basic cases: x1 < 0 -> 0
("-1.0", "0.5"), ("-5.0", "0.5"), ("-100.0", "0.5"),
("-1e10", "0.5"), ("-0.1", "0.5"),

# Basic cases: x1 == 0 -> x2
("0.0", "0.5"), ("0.0", "0.0"), ("0.0", "1.0"),
("-0.0", "0.5"), ("-0.0", "0.0"), ("-0.0", "1.0"),

# Basic cases: x1 > 0 -> 1
("1.0", "0.5"), ("5.0", "0.5"), ("100.0", "0.5"),
("1e10", "0.5"), ("0.1", "0.5"),

# Edge cases with different x2 values
("0.0", "-1.0"), ("0.0", "2.0"), ("0.0", "100.0"),

# Special values: infinity
("inf", "0.5"), ("-inf", "0.5"),
("inf", "0.0"), ("-inf", "0.0"),

# Special values: NaN (should propagate)
("nan", "0.5"), ("0.5", "nan"), ("nan", "nan"),
("-nan", "0.5"), ("0.5", "-nan"),

# Edge case: zero x1 with special x2
("0.0", "inf"), ("0.0", "-inf"), ("0.0", "nan"),
("-0.0", "inf"), ("-0.0", "-inf"), ("-0.0", "nan"),
])
def test_heaviside(x1, x2):
"""
Test np.heaviside ufunc for QuadPrecision dtype.

heaviside(x1, x2) = 0 if x1 < 0
x2 if x1 == 0
1 if x1 > 0

This is the Heaviside step function where x2 determines the value at x1=0.
"""
quad_x1 = QuadPrecision(x1)
quad_x2 = QuadPrecision(x2)
float_x1 = float(x1)
float_x2 = float(x2)

# Test scalar inputs
quad_result = np.heaviside(quad_x1, quad_x2)
float_result = np.heaviside(float_x1, float_x2)

# Test array inputs
quad_arr_x1 = np.array([quad_x1], dtype=QuadPrecDType())
quad_arr_x2 = np.array([quad_x2], dtype=QuadPrecDType())
quad_arr_result = np.heaviside(quad_arr_x1, quad_arr_x2)

# Check results match
np.testing.assert_array_equal(
np.array(quad_result).astype(float),
float_result,
err_msg=f"Scalar heaviside({x1}, {x2}) mismatch"
)

np.testing.assert_array_equal(
quad_arr_result.astype(float)[0],
float_result,
err_msg=f"Array heaviside({x1}, {x2}) mismatch"
)

# Additional checks for non-NaN results
if not np.isnan(float_result):
# Verify the expected value based on x1
if float_x1 < 0:
assert float(quad_result) == 0.0, f"Expected 0 for heaviside({x1}, {x2})"
elif float_x1 == 0.0:
np.testing.assert_array_equal(
float(quad_result), float_x2,
err_msg=f"Expected {x2} for heaviside(0, {x2})"
)
else: # float_x1 > 0
assert float(quad_result) == 1.0, f"Expected 1 for heaviside({x1}, {x2})"


def test_heaviside_broadcast():
"""Test that heaviside works with broadcasting"""
x1 = np.array([-1.0, 0.0, 1.0], dtype=QuadPrecDType())
x2 = QuadPrecision("0.5")

result = np.heaviside(x1, x2)
expected = np.array([0.0, 0.5, 1.0], dtype=np.float64)

assert result.dtype.name == "QuadPrecDType128"
np.testing.assert_array_equal(result.astype(float), expected)

# Test with array for both arguments
x1_arr = np.array([-2.0, -0.0, 0.0, 5.0], dtype=QuadPrecDType())
x2_arr = np.array([0.5, 0.5, 1.0, 0.5], dtype=QuadPrecDType())

result = np.heaviside(x1_arr, x2_arr)
expected = np.array([0.0, 0.5, 1.0, 1.0], dtype=np.float64)

assert result.dtype.name == "QuadPrecDType128"
np.testing.assert_array_equal(result.astype(float), expected)

Loading