diff --git a/quaddtype/numpy_quaddtype/src/ops.hpp b/quaddtype/numpy_quaddtype/src/ops.hpp index fa351d2..9fbf4b4 100644 --- a/quaddtype/numpy_quaddtype/src/ops.hpp +++ b/quaddtype/numpy_quaddtype/src/ops.hpp @@ -584,6 +584,48 @@ quad_copysign(const Sleef_quad *in1, const Sleef_quad *in2) return Sleef_copysignq1(*in1, *in2); } +static inline Sleef_quad +quad_logaddexp(const Sleef_quad *x, const Sleef_quad *y) +{ + // logaddexp(x, y) = log(exp(x) + exp(y)) + // Numerically stable implementation: max(x, y) + log1p(exp(-abs(x - y))) + + // Handle NaN + if (Sleef_iunordq1(*x, *y)) { + return Sleef_iunordq1(*x, *x) ? *x : *y; + } + + // Handle infinities + // If both are -inf, result is -inf + Sleef_quad neg_inf = Sleef_negq1(QUAD_POS_INF); + if (Sleef_icmpeqq1(*x, neg_inf) && Sleef_icmpeqq1(*y, neg_inf)) { + return neg_inf; + } + + // If either is +inf, result is +inf + if (Sleef_icmpeqq1(*x, QUAD_POS_INF) || Sleef_icmpeqq1(*y, QUAD_POS_INF)) { + return QUAD_POS_INF; + } + + // If one is -inf, result is the other value + if (Sleef_icmpeqq1(*x, neg_inf)) { + return *y; + } + if (Sleef_icmpeqq1(*y, neg_inf)) { + return *x; + } + + // Numerically stable computation + Sleef_quad diff = Sleef_subq1_u05(*x, *y); + Sleef_quad abs_diff = Sleef_fabsq1(diff); + Sleef_quad neg_abs_diff = Sleef_negq1(abs_diff); + Sleef_quad exp_term = Sleef_expq1_u10(neg_abs_diff); + Sleef_quad log1p_term = Sleef_log1pq1_u10(exp_term); + + Sleef_quad max_val = Sleef_icmpgtq1(*x, *y) ? *x : *y; + return Sleef_addq1_u05(max_val, log1p_term); +} + // Binary long double operations typedef long double (*binary_op_longdouble_def)(const long double *, const long double *); @@ -680,6 +722,43 @@ ld_copysign(const long double *in1, const long double *in2) return copysignl(*in1, *in2); } +static inline long double +ld_logaddexp(const long double *x, const long double *y) +{ + // logaddexp(x, y) = log(exp(x) + exp(y)) + // Numerically stable implementation: max(x, y) + log1p(exp(-abs(x - y))) + + // Handle NaN + if (isnan(*x) || isnan(*y)) { + return isnan(*x) ? *x : *y; + } + + // Handle infinities + // If both are -inf, result is -inf + if (isinf(*x) && *x < 0 && isinf(*y) && *y < 0) { + return -INFINITY; + } + + // If either is +inf, result is +inf + if ((isinf(*x) && *x > 0) || (isinf(*y) && *y > 0)) { + return INFINITY; + } + + // If one is -inf, result is the other value + if (isinf(*x) && *x < 0) { + return *y; + } + if (isinf(*y) && *y < 0) { + return *x; + } + + // Numerically stable computation + long double diff = *x - *y; + long double abs_diff = fabsl(diff); + long double max_val = (*x > *y) ? *x : *y; + return max_val + log1pl(expl(-abs_diff)); +} + // comparison quad functions typedef npy_bool (*cmp_quad_def)(const Sleef_quad *, const Sleef_quad *); diff --git a/quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp b/quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp index 8adfe4d..5a24e67 100644 --- a/quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp +++ b/quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp @@ -240,5 +240,8 @@ init_quad_binary_ops(PyObject *numpy) if (create_quad_binary_ufunc(numpy, "copysign") < 0) { return -1; } + if (create_quad_binary_ufunc(numpy, "logaddexp") < 0) { + return -1; + } return 0; } \ No newline at end of file diff --git a/quaddtype/release_tracker.md b/quaddtype/release_tracker.md index 76186f6..1dc4de3 100644 --- a/quaddtype/release_tracker.md +++ b/quaddtype/release_tracker.md @@ -10,7 +10,7 @@ | multiply | ✅ | ✅ | | matmul | ✅ | ✅ | | divide | ✅ | ✅ | -| logaddexp | | | +| logaddexp | ✅ | ✅ | | logaddexp2 | | | | true_divide | | | | floor_divide | | | diff --git a/quaddtype/tests/test_quaddtype.py b/quaddtype/tests/test_quaddtype.py index a0b5335..0094c11 100644 --- a/quaddtype/tests/test_quaddtype.py +++ b/quaddtype/tests/test_quaddtype.py @@ -367,7 +367,7 @@ def test_logarithmic_functions(op, val): # Check sign for zero results if float_result == 0.0: assert np.signbit(float_result) == np.signbit( - quad_result), f"Zero sign mismatch for {op}({a}, {b})" + quad_result), f"Zero sign mismatch" @pytest.mark.parametrize("val", [ @@ -390,6 +390,7 @@ def test_logarithmic_functions(op, val): ]) def test_log1p(val): """Comprehensive test for log1p function""" + op = "log1p" quad_val = QuadPrecision(val) float_val = float(val) @@ -427,6 +428,106 @@ def test_log1p(val): assert np.signbit(float_result) == np.signbit( quad_result), f"Zero sign mismatch for {op}({val})" + +@pytest.mark.parametrize("x", [ + # Regular values + "0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5", + # Large values (test numerical stability) + "100.0", "1000.0", "-100.0", "-1000.0", + # Small values + "1e-10", "-1e-10", "1e-20", "-1e-20", + # Special values + "inf", "-inf", "nan", "-nan", "-0.0" +]) +@pytest.mark.parametrize("y", [ + # Regular values + "0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5", + # Large values + "100.0", "1000.0", "-100.0", "-1000.0", + # Small values + "1e-10", "-1e-10", "1e-20", "-1e-20", + # Special values + "inf", "-inf", "nan", "-nan", "-0.0" +]) +def test_logaddexp(x, y): + """Comprehensive test for logaddexp function: log(exp(x) + exp(y))""" + quad_x = QuadPrecision(x) + quad_y = QuadPrecision(y) + float_x = float(x) + float_y = float(y) + + quad_result = np.logaddexp(quad_x, quad_y) + float_result = np.logaddexp(float_x, float_y) + + # Handle NaN cases + if np.isnan(float_result): + assert np.isnan(float(quad_result)), \ + f"Expected NaN for logaddexp({x}, {y}), got {float(quad_result)}" + return + + # Handle infinity cases + if np.isinf(float_result): + assert np.isinf(float(quad_result)), \ + f"Expected inf for logaddexp({x}, {y}), got {float(quad_result)}" + if not np.isnan(float_result): + assert np.sign(float_result) == np.sign(float(quad_result)), \ + f"Infinity sign mismatch for logaddexp({x}, {y})" + return + + # For finite results, check with appropriate tolerance + # logaddexp is numerically sensitive, especially for large differences + if abs(float_x - float_y) > 50: + # When values differ greatly, result should be close to max(x, y) + rtol = 1e-10 + atol = 1e-10 + else: + rtol = 1e-13 + atol = 1e-15 + + np.testing.assert_allclose( + float(quad_result), float_result, + rtol=rtol, atol=atol, + err_msg=f"Value mismatch for logaddexp({x}, {y})" + ) + + +def test_logaddexp_special_properties(): + """Test special mathematical properties of logaddexp""" + # logaddexp(x, x) = x + log(2) + x = QuadPrecision("2.0") + result = np.logaddexp(x, x) + expected = float(x) + np.log(2.0) + np.testing.assert_allclose(float(result), expected, rtol=1e-14) + + # logaddexp(x, -inf) = x + x = QuadPrecision("5.0") + result = np.logaddexp(x, QuadPrecision("-inf")) + np.testing.assert_allclose(float(result), float(x), rtol=1e-14) + + # logaddexp(-inf, x) = x + result = np.logaddexp(QuadPrecision("-inf"), x) + np.testing.assert_allclose(float(result), float(x), rtol=1e-14) + + # logaddexp(-inf, -inf) = -inf + result = np.logaddexp(QuadPrecision("-inf"), QuadPrecision("-inf")) + assert np.isinf(float(result)) and float(result) < 0 + + # logaddexp(inf, anything) = inf + result = np.logaddexp(QuadPrecision("inf"), QuadPrecision("100.0")) + assert np.isinf(float(result)) and float(result) > 0 + + # logaddexp(anything, inf) = inf + result = np.logaddexp(QuadPrecision("100.0"), QuadPrecision("inf")) + assert np.isinf(float(result)) and float(result) > 0 + + # Commutativity: logaddexp(x, y) = logaddexp(y, x) + x = QuadPrecision("3.0") + y = QuadPrecision("5.0") + result1 = np.logaddexp(x, y) + result2 = np.logaddexp(y, x) + np.testing.assert_allclose(float(result1), float(result2), rtol=1e-14) + + def test_inf(): assert QuadPrecision("inf") > QuadPrecision("1e1000") assert np.signbit(QuadPrecision("inf")) == 0