Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions quaddtype/numpy_quaddtype/src/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ quad_exp2(const Sleef_quad *op)
return Sleef_exp2q1_u10(*op);
}

static inline Sleef_quad
quad_expm1(const Sleef_quad *op)
{
return Sleef_expm1q1_u10(*op);
}

static inline Sleef_quad
quad_sin(const Sleef_quad *op)
{
Expand Down Expand Up @@ -308,6 +314,12 @@ ld_exp2(const long double *op)
return exp2l(*op);
}

static inline long double
ld_expm1(const long double *op)
{
return expm1l(*op);
}

static inline long double
ld_sin(const long double *op)
{
Expand Down
3 changes: 3 additions & 0 deletions quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ init_quad_unary_ops(PyObject *numpy)
if (create_quad_unary_ufunc<quad_exp2, ld_exp2>(numpy, "exp2") < 0) {
return -1;
}
if (create_quad_unary_ufunc<quad_expm1, ld_expm1>(numpy, "expm1") < 0) {
return -1;
}
if (create_quad_unary_ufunc<quad_sin, ld_sin>(numpy, "sin") < 0) {
return -1;
}
Expand Down
2 changes: 1 addition & 1 deletion quaddtype/release_tracker.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
| log | ✅ | ✅ |
| log2 | ✅ | ✅ |
| log10 | ✅ | ✅ |
| expm1 | | |
| expm1 | | ✅ |
| log1p | ✅ | ✅ |
| sqrt | ✅ | ✅ |
| square | ✅ | ✅ |
Expand Down
69 changes: 69 additions & 0 deletions quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,75 @@ def test_log1p(val):
quad_result), f"Zero sign mismatch for {op}({val})"


@pytest.mark.parametrize("val", [
# Cases close to 0 (where expm1 is most accurate and important)
"0.0", "-0.0",
"1e-10", "-1e-10", "1e-15", "-1e-15", "1e-20", "-1e-20",
"1e-100", "-1e-100", "1e-300", "-1e-300",
# Small values
"0.001", "-0.001", "0.01", "-0.01", "0.1", "-0.1",
# Moderate values
"0.5", "-0.5", "1.0", "-1.0", "2.0", "-2.0",
# Larger values
"5.0", "-5.0", "10.0", "-10.0", "20.0", "-20.0",
# Values that test exp behavior
"50.0", "-50.0", "100.0", "-100.0",
# Large positive values (exp(x) grows rapidly)
"200.0", "500.0", "700.0",
# Large negative values (should approach -1)
"-200.0", "-500.0", "-700.0", "-1000.0",
# Special values
"inf", # Should give inf
"-inf", # Should give -1
"nan", "-nan"
])
def test_expm1(val):
"""Comprehensive test for expm1 function: exp(x) - 1

This function provides greater precision than exp(x) - 1 for small values of x.
"""
quad_val = QuadPrecision(val)
float_val = float(val)

quad_result = np.expm1(quad_val)
float_result = np.expm1(float_val)

# Handle NaN cases
if np.isnan(float_result):
assert np.isnan(
float(quad_result)), f"Expected NaN for expm1({val}), got {float(quad_result)}"
return

# Handle infinity cases
if np.isinf(float_result):
assert np.isinf(
float(quad_result)), f"Expected inf for expm1({val}), got {float(quad_result)}"
assert np.sign(float_result) == np.sign(
float(quad_result)), f"Infinity sign mismatch for expm1({val})"
return

# For finite results
# expm1 is designed for high accuracy near 0, so use tight tolerances for small inputs
if abs(float(val)) < 1e-10:
rtol = 1e-15
atol = 1e-20
elif abs(float_result) < 1:
rtol = 1e-14
atol = 1e-15
else:
# For larger results, use relative tolerance
rtol = 1e-14
atol = 1e-15

np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=atol,
err_msg=f"Value mismatch for expm1({val})")

# Check sign for zero results
if float_result == 0.0:
assert np.signbit(float_result) == np.signbit(
quad_result), f"Zero sign mismatch for expm1({val})"


@pytest.mark.parametrize("x", [
# Regular values
"0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",
Expand Down
Loading