Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions quaddtype/numpy_quaddtype/src/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,48 @@ quad_copysign(const Sleef_quad *in1, const Sleef_quad *in2)
return Sleef_copysignq1(*in1, *in2);
}

static inline Sleef_quad
quad_logaddexp(const Sleef_quad *x, const Sleef_quad *y)
{
// logaddexp(x, y) = log(exp(x) + exp(y))
// Numerically stable implementation: max(x, y) + log1p(exp(-abs(x - y)))

// Handle NaN
if (Sleef_iunordq1(*x, *y)) {
return Sleef_iunordq1(*x, *x) ? *x : *y;
}

// Handle infinities
// If both are -inf, result is -inf
Sleef_quad neg_inf = Sleef_negq1(QUAD_POS_INF);
if (Sleef_icmpeqq1(*x, neg_inf) && Sleef_icmpeqq1(*y, neg_inf)) {
return neg_inf;
}

// If either is +inf, result is +inf
if (Sleef_icmpeqq1(*x, QUAD_POS_INF) || Sleef_icmpeqq1(*y, QUAD_POS_INF)) {
return QUAD_POS_INF;
}

// If one is -inf, result is the other value
if (Sleef_icmpeqq1(*x, neg_inf)) {
return *y;
}
if (Sleef_icmpeqq1(*y, neg_inf)) {
return *x;
}

// Numerically stable computation
Sleef_quad diff = Sleef_subq1_u05(*x, *y);
Sleef_quad abs_diff = Sleef_fabsq1(diff);
Sleef_quad neg_abs_diff = Sleef_negq1(abs_diff);
Sleef_quad exp_term = Sleef_expq1_u10(neg_abs_diff);
Sleef_quad log1p_term = Sleef_log1pq1_u10(exp_term);

Sleef_quad max_val = Sleef_icmpgtq1(*x, *y) ? *x : *y;
return Sleef_addq1_u05(max_val, log1p_term);
}

// Binary long double operations
typedef long double (*binary_op_longdouble_def)(const long double *, const long double *);

Expand Down Expand Up @@ -680,6 +722,43 @@ ld_copysign(const long double *in1, const long double *in2)
return copysignl(*in1, *in2);
}

static inline long double
ld_logaddexp(const long double *x, const long double *y)
{
// logaddexp(x, y) = log(exp(x) + exp(y))
// Numerically stable implementation: max(x, y) + log1p(exp(-abs(x - y)))

// Handle NaN
if (isnan(*x) || isnan(*y)) {
return isnan(*x) ? *x : *y;
}

// Handle infinities
// If both are -inf, result is -inf
if (isinf(*x) && *x < 0 && isinf(*y) && *y < 0) {
return -INFINITY;
}

// If either is +inf, result is +inf
if ((isinf(*x) && *x > 0) || (isinf(*y) && *y > 0)) {
return INFINITY;
}

// If one is -inf, result is the other value
if (isinf(*x) && *x < 0) {
return *y;
}
if (isinf(*y) && *y < 0) {
return *x;
}

// Numerically stable computation
long double diff = *x - *y;
long double abs_diff = fabsl(diff);
long double max_val = (*x > *y) ? *x : *y;
return max_val + log1pl(expl(-abs_diff));
}

// comparison quad functions
typedef npy_bool (*cmp_quad_def)(const Sleef_quad *, const Sleef_quad *);

Expand Down
3 changes: 3 additions & 0 deletions quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,5 +240,8 @@ init_quad_binary_ops(PyObject *numpy)
if (create_quad_binary_ufunc<quad_copysign, ld_copysign>(numpy, "copysign") < 0) {
return -1;
}
if (create_quad_binary_ufunc<quad_logaddexp, ld_logaddexp>(numpy, "logaddexp") < 0) {
return -1;
}
return 0;
}
2 changes: 1 addition & 1 deletion quaddtype/release_tracker.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
| multiply | ✅ | ✅ |
| matmul | ✅ | ✅ |
| divide | ✅ | ✅ |
| logaddexp | | |
| logaddexp | | ✅ |
| logaddexp2 | | |
| true_divide | | |
| floor_divide | | |
Expand Down
103 changes: 102 additions & 1 deletion quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def test_logarithmic_functions(op, val):
# Check sign for zero results
if float_result == 0.0:
assert np.signbit(float_result) == np.signbit(
quad_result), f"Zero sign mismatch for {op}({a}, {b})"
quad_result), f"Zero sign mismatch"


@pytest.mark.parametrize("val", [
Expand All @@ -390,6 +390,7 @@ def test_logarithmic_functions(op, val):
])
def test_log1p(val):
"""Comprehensive test for log1p function"""
op = "log1p"
quad_val = QuadPrecision(val)
float_val = float(val)

Expand Down Expand Up @@ -427,6 +428,106 @@ def test_log1p(val):
assert np.signbit(float_result) == np.signbit(
quad_result), f"Zero sign mismatch for {op}({val})"


@pytest.mark.parametrize("x", [
# Regular values
"0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",
# Large values (test numerical stability)
"100.0", "1000.0", "-100.0", "-1000.0",
# Small values
"1e-10", "-1e-10", "1e-20", "-1e-20",
# Special values
"inf", "-inf", "nan", "-nan", "-0.0"
])
@pytest.mark.parametrize("y", [
# Regular values
"0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",
# Large values
"100.0", "1000.0", "-100.0", "-1000.0",
# Small values
"1e-10", "-1e-10", "1e-20", "-1e-20",
# Special values
"inf", "-inf", "nan", "-nan", "-0.0"
])
def test_logaddexp(x, y):
"""Comprehensive test for logaddexp function: log(exp(x) + exp(y))"""
quad_x = QuadPrecision(x)
quad_y = QuadPrecision(y)
float_x = float(x)
float_y = float(y)

quad_result = np.logaddexp(quad_x, quad_y)
float_result = np.logaddexp(float_x, float_y)

# Handle NaN cases
if np.isnan(float_result):
assert np.isnan(float(quad_result)), \
f"Expected NaN for logaddexp({x}, {y}), got {float(quad_result)}"
return

# Handle infinity cases
if np.isinf(float_result):
assert np.isinf(float(quad_result)), \
f"Expected inf for logaddexp({x}, {y}), got {float(quad_result)}"
if not np.isnan(float_result):
assert np.sign(float_result) == np.sign(float(quad_result)), \
f"Infinity sign mismatch for logaddexp({x}, {y})"
return

# For finite results, check with appropriate tolerance
# logaddexp is numerically sensitive, especially for large differences
if abs(float_x - float_y) > 50:
# When values differ greatly, result should be close to max(x, y)
rtol = 1e-10
atol = 1e-10
else:
rtol = 1e-13
atol = 1e-15

np.testing.assert_allclose(
float(quad_result), float_result,
rtol=rtol, atol=atol,
err_msg=f"Value mismatch for logaddexp({x}, {y})"
)


def test_logaddexp_special_properties():
"""Test special mathematical properties of logaddexp"""
# logaddexp(x, x) = x + log(2)
x = QuadPrecision("2.0")
result = np.logaddexp(x, x)
expected = float(x) + np.log(2.0)
np.testing.assert_allclose(float(result), expected, rtol=1e-14)

# logaddexp(x, -inf) = x
x = QuadPrecision("5.0")
result = np.logaddexp(x, QuadPrecision("-inf"))
np.testing.assert_allclose(float(result), float(x), rtol=1e-14)

# logaddexp(-inf, x) = x
result = np.logaddexp(QuadPrecision("-inf"), x)
np.testing.assert_allclose(float(result), float(x), rtol=1e-14)

# logaddexp(-inf, -inf) = -inf
result = np.logaddexp(QuadPrecision("-inf"), QuadPrecision("-inf"))
assert np.isinf(float(result)) and float(result) < 0

# logaddexp(inf, anything) = inf
result = np.logaddexp(QuadPrecision("inf"), QuadPrecision("100.0"))
assert np.isinf(float(result)) and float(result) > 0

# logaddexp(anything, inf) = inf
result = np.logaddexp(QuadPrecision("100.0"), QuadPrecision("inf"))
assert np.isinf(float(result)) and float(result) > 0

# Commutativity: logaddexp(x, y) = logaddexp(y, x)
x = QuadPrecision("3.0")
y = QuadPrecision("5.0")
result1 = np.logaddexp(x, y)
result2 = np.logaddexp(y, x)
np.testing.assert_allclose(float(result1), float(result2), rtol=1e-14)


def test_inf():
assert QuadPrecision("inf") > QuadPrecision("1e1000")
assert np.signbit(QuadPrecision("inf")) == 0
Expand Down
Loading