diff --git a/quaddtype/numpy_quaddtype/src/ops.hpp b/quaddtype/numpy_quaddtype/src/ops.hpp index dd385d9..170cff3 100644 --- a/quaddtype/numpy_quaddtype/src/ops.hpp +++ b/quaddtype/numpy_quaddtype/src/ops.hpp @@ -124,6 +124,12 @@ quad_exp2(const Sleef_quad *op) return Sleef_exp2q1_u10(*op); } +static inline Sleef_quad +quad_expm1(const Sleef_quad *op) +{ + return Sleef_expm1q1_u10(*op); +} + static inline Sleef_quad quad_sin(const Sleef_quad *op) { @@ -308,6 +314,12 @@ ld_exp2(const long double *op) return exp2l(*op); } +static inline long double +ld_expm1(const long double *op) +{ + return expm1l(*op); +} + static inline long double ld_sin(const long double *op) { diff --git a/quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp b/quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp index b8a82ae..94468ad 100644 --- a/quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp +++ b/quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp @@ -203,6 +203,9 @@ init_quad_unary_ops(PyObject *numpy) if (create_quad_unary_ufunc(numpy, "exp2") < 0) { return -1; } + if (create_quad_unary_ufunc(numpy, "expm1") < 0) { + return -1; + } if (create_quad_unary_ufunc(numpy, "sin") < 0) { return -1; } diff --git a/quaddtype/release_tracker.md b/quaddtype/release_tracker.md index cb61ecb..7a5e8c6 100644 --- a/quaddtype/release_tracker.md +++ b/quaddtype/release_tracker.md @@ -34,7 +34,7 @@ | log | ✅ | ✅ | | log2 | ✅ | ✅ | | log10 | ✅ | ✅ | -| expm1 | | | +| expm1 | ✅ | ✅ | | log1p | ✅ | ✅ | | sqrt | ✅ | ✅ | | square | ✅ | ✅ | diff --git a/quaddtype/tests/test_quaddtype.py b/quaddtype/tests/test_quaddtype.py index b63daee..1e2dd51 100644 --- a/quaddtype/tests/test_quaddtype.py +++ b/quaddtype/tests/test_quaddtype.py @@ -429,6 +429,75 @@ def test_log1p(val): quad_result), f"Zero sign mismatch for {op}({val})" +@pytest.mark.parametrize("val", [ + # Cases close to 0 (where expm1 is most accurate and important) + "0.0", "-0.0", + "1e-10", "-1e-10", "1e-15", "-1e-15", "1e-20", "-1e-20", + "1e-100", "-1e-100", "1e-300", "-1e-300", + # Small values + "0.001", "-0.001", "0.01", "-0.01", "0.1", "-0.1", + # Moderate values + "0.5", "-0.5", "1.0", "-1.0", "2.0", "-2.0", + # Larger values + "5.0", "-5.0", "10.0", "-10.0", "20.0", "-20.0", + # Values that test exp behavior + "50.0", "-50.0", "100.0", "-100.0", + # Large positive values (exp(x) grows rapidly) + "200.0", "500.0", "700.0", + # Large negative values (should approach -1) + "-200.0", "-500.0", "-700.0", "-1000.0", + # Special values + "inf", # Should give inf + "-inf", # Should give -1 + "nan", "-nan" +]) +def test_expm1(val): + """Comprehensive test for expm1 function: exp(x) - 1 + + This function provides greater precision than exp(x) - 1 for small values of x. + """ + quad_val = QuadPrecision(val) + float_val = float(val) + + quad_result = np.expm1(quad_val) + float_result = np.expm1(float_val) + + # Handle NaN cases + if np.isnan(float_result): + assert np.isnan( + float(quad_result)), f"Expected NaN for expm1({val}), got {float(quad_result)}" + return + + # Handle infinity cases + if np.isinf(float_result): + assert np.isinf( + float(quad_result)), f"Expected inf for expm1({val}), got {float(quad_result)}" + assert np.sign(float_result) == np.sign( + float(quad_result)), f"Infinity sign mismatch for expm1({val})" + return + + # For finite results + # expm1 is designed for high accuracy near 0, so use tight tolerances for small inputs + if abs(float(val)) < 1e-10: + rtol = 1e-15 + atol = 1e-20 + elif abs(float_result) < 1: + rtol = 1e-14 + atol = 1e-15 + else: + # For larger results, use relative tolerance + rtol = 1e-14 + atol = 1e-15 + + np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=atol, + err_msg=f"Value mismatch for expm1({val})") + + # Check sign for zero results + if float_result == 0.0: + assert np.signbit(float_result) == np.signbit( + quad_result), f"Zero sign mismatch for expm1({val})" + + @pytest.mark.parametrize("x", [ # Regular values "0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",