Skip to content

Commit ae618c9

Browse files
authored
Merge pull request #184 from SwayamInSync/cbrt
2 parents 809c6da + 7329e50 commit ae618c9

File tree

4 files changed

+119
-1
lines changed

4 files changed

+119
-1
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,40 @@ quad_sqrt(const Sleef_quad *op)
7676
return Sleef_sqrtq1_u05(*op);
7777
}
7878

79+
static inline Sleef_quad
80+
quad_cbrt(const Sleef_quad *op)
81+
{
82+
// SLEEF doesn't provide cbrt, so we implement it using pow
83+
// cbrt(x) = x^(1/3)
84+
// For negative values: cbrt(-x) = -cbrt(x)
85+
86+
// Handle special cases
87+
if (Sleef_iunordq1(*op, *op)) {
88+
return *op; // NaN
89+
}
90+
if (Sleef_icmpeqq1(*op, QUAD_ZERO)) {
91+
return *op; // ±0
92+
}
93+
// Check if op is ±inf: isinf(x) = abs(x) == inf
94+
if (Sleef_icmpeqq1(Sleef_fabsq1(*op), QUAD_POS_INF)) {
95+
return *op; // ±inf
96+
}
97+
98+
// Compute 1/3 as a quad precision constant
99+
Sleef_quad three = Sleef_cast_from_int64q1(3);
100+
Sleef_quad one_third = Sleef_divq1_u05(QUAD_ONE, three);
101+
102+
// Handle negative values: cbrt(-x) = -cbrt(x)
103+
if (Sleef_icmpltq1(*op, QUAD_ZERO)) {
104+
Sleef_quad abs_val = Sleef_fabsq1(*op);
105+
Sleef_quad result = Sleef_powq1_u10(abs_val, one_third);
106+
return Sleef_negq1(result);
107+
}
108+
109+
// Positive values
110+
return Sleef_powq1_u10(*op, one_third);
111+
}
112+
79113
static inline Sleef_quad
80114
quad_square(const Sleef_quad *op)
81115
{
@@ -266,6 +300,12 @@ ld_sqrt(const long double *op)
266300
return sqrtl(*op);
267301
}
268302

303+
static inline long double
304+
ld_cbrt(const long double *op)
305+
{
306+
return cbrtl(*op);
307+
}
308+
269309
static inline long double
270310
ld_square(const long double *op)
271311
{

quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ init_quad_unary_ops(PyObject *numpy)
179179
if (create_quad_unary_ufunc<quad_sqrt, ld_sqrt>(numpy, "sqrt") < 0) {
180180
return -1;
181181
}
182+
if (create_quad_unary_ufunc<quad_cbrt, ld_cbrt>(numpy, "cbrt") < 0) {
183+
return -1;
184+
}
182185
if (create_quad_unary_ufunc<quad_square, ld_square>(numpy, "square") < 0) {
183186
return -1;
184187
}

quaddtype/release_tracker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
| log1p |||
3939
| sqrt |||
4040
| square |||
41-
| cbrt | | |
41+
| cbrt | | |
4242
| reciprocal |||
4343
| gcd | | |
4444
| lcm | | |

quaddtype/tests/test_quaddtype.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,81 @@ def test_rint_near_halfway():
253253
assert np.rint(QuadPrecision("7.5")) == 8
254254

255255

256+
@pytest.mark.parametrize("val", [
257+
# Perfect cubes
258+
"1.0", "8.0", "27.0", "64.0", "125.0", "1000.0",
259+
# Negative perfect cubes
260+
"-1.0", "-8.0", "-27.0", "-64.0", "-125.0", "-1000.0",
261+
# Small positive values
262+
"0.001", "0.008", "0.027", "1e-9", "1e-15", "1e-100",
263+
# Small negative values
264+
"-0.001", "-0.008", "-0.027", "-1e-9", "-1e-15", "-1e-100",
265+
# Large positive values
266+
"1e10", "1e15", "1e100", "1e300",
267+
# Large negative values
268+
"-1e10", "-1e15", "-1e100", "-1e300",
269+
# Fractional values
270+
"0.5", "2.5", "3.5", "10.5", "100.5",
271+
"-0.5", "-2.5", "-3.5", "-10.5", "-100.5",
272+
# Edge cases
273+
"0.0", "-0.0",
274+
# Special values
275+
"inf", "-inf", "nan", "-nan"
276+
])
277+
def test_cbrt(val):
278+
"""Comprehensive test for cube root function"""
279+
quad_val = QuadPrecision(val)
280+
float_val = float(val)
281+
282+
quad_result = np.cbrt(quad_val)
283+
float_result = np.cbrt(float_val)
284+
285+
# Handle NaN cases
286+
if np.isnan(float_result):
287+
assert np.isnan(
288+
float(quad_result)), f"Expected NaN for cbrt({val}), got {float(quad_result)}"
289+
return
290+
291+
# Handle infinity cases
292+
if np.isinf(float_result):
293+
assert np.isinf(
294+
float(quad_result)), f"Expected inf for cbrt({val}), got {float(quad_result)}"
295+
assert np.sign(float_result) == np.sign(
296+
float(quad_result)), f"Infinity sign mismatch for cbrt({val})"
297+
return
298+
299+
# For finite results, check value and sign
300+
# Use relative tolerance for cbrt
301+
if float_result != 0.0:
302+
rtol = 1e-14 if abs(float_result) < 1e100 else 1e-10
303+
np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=1e-15,
304+
err_msg=f"Value mismatch for cbrt({val})")
305+
else:
306+
# For zero results
307+
assert float(quad_result) == 0.0, f"Expected 0 for cbrt({val}), got {float(quad_result)}"
308+
assert np.signbit(float_result) == np.signbit(
309+
quad_result), f"Zero sign mismatch for cbrt({val})"
310+
311+
312+
def test_cbrt_accuracy():
313+
"""Test that cbrt gives accurate results for perfect cubes"""
314+
# Test perfect cubes
315+
for i in [1, 2, 3, 4, 5, 10, 100]:
316+
val = QuadPrecision(i ** 3)
317+
result = np.cbrt(val)
318+
expected = QuadPrecision(i)
319+
np.testing.assert_allclose(float(result), float(expected), rtol=1e-14, atol=1e-15,
320+
err_msg=f"cbrt({i}^3) should equal {i}")
321+
322+
# Test negative perfect cubes
323+
for i in [1, 2, 3, 4, 5, 10, 100]:
324+
val = QuadPrecision(-(i ** 3))
325+
result = np.cbrt(val)
326+
expected = QuadPrecision(-i)
327+
np.testing.assert_allclose(float(result), float(expected), rtol=1e-14, atol=1e-15,
328+
err_msg=f"cbrt(-{i}^3) should equal -{i}")
329+
330+
256331
@pytest.mark.parametrize("op", ["exp", "exp2"])
257332
@pytest.mark.parametrize("val", [
258333
# Basic cases

0 commit comments

Comments
 (0)