Skip to content

Commit

Permalink
Use mpmath's floor/ceil to calculate round/ceiling
Browse files Browse the repository at this point in the history
Drop get_integer_part() helper function.

Fixes sympy/sympy#10323
  • Loading branch information
skirpichev committed Jan 1, 2016
1 parent dacbfab commit cc427e3
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 81 deletions.
71 changes: 0 additions & 71 deletions sympy/core/evalf.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,75 +283,6 @@ def check_target(expr, result, prec):
"a higher maxn for evalf" % (expr))


def get_integer_part(expr, no, options, return_ints=False):
"""
With no = 1, computes ceiling(expr)
With no = -1, computes floor(expr)
Note: this function either gives the exact result or signals failure.
"""
import sympy
# The expression is likely less than 2^30 or so
assumed_size = 30
ire, iim, ire_acc, iim_acc = evalf(expr, assumed_size, options)

# We now know the size, so we can calculate how much extra precision
# (if any) is needed to get within the nearest integer
if ire and iim:
gap = max(fastlog(ire) - ire_acc, fastlog(iim) - iim_acc)
elif ire:
gap = fastlog(ire) - ire_acc
elif iim:
gap = fastlog(iim) - iim_acc
else:
# ... or maybe the expression was exactly zero
return None, None, None, None

margin = 10

if gap >= -margin:
ire, iim, ire_acc, iim_acc = \
evalf(expr, margin + assumed_size + gap, options)

# We can now easily find the nearest integer, but to find floor/ceil, we
# must also calculate whether the difference to the nearest integer is
# positive or negative (which may fail if very close).
def calc_part(expr, nexpr):
from sympy import Add
nint = int(to_int(nexpr, rnd))
n, c, p, b = nexpr
if (c != 1 and p != 0) or p < 0:
expr = Add(expr, -nint, evaluate=False)
x, _, x_acc, _ = evalf(expr, 10, options)
try:
check_target(expr, (x, None, x_acc, None), 3)
except PrecisionExhausted:
if not expr.equals(0):
raise PrecisionExhausted
x = fzero
nint += int(no*(mpf_cmp(x or fzero, fzero) == no))
nint = from_int(nint)
return nint, fastlog(nint) + 10

re, im, re_acc, im_acc = None, None, None, None

if ire:
re, re_acc = calc_part(sympy.re(expr, evaluate=False), ire)
if iim:
im, im_acc = calc_part(sympy.im(expr, evaluate=False), iim)

if return_ints:
return int(to_int(re or fzero)), int(to_int(im or fzero))
return re, im, re_acc, im_acc


def evalf_ceiling(expr, prec, options):
return get_integer_part(expr.args[0], 1, options)


def evalf_floor(expr, prec, options):
return get_integer_part(expr.args[0], -1, options)

############################################################################
# #
# Arithmetic operations #
Expand Down Expand Up @@ -1228,8 +1159,6 @@ def _create_evalf_table():

re: evalf_re,
im: evalf_im,
floor: evalf_floor,
ceiling: evalf_ceiling,

Integral: evalf_integral,
Sum: evalf_sum,
Expand Down
22 changes: 16 additions & 6 deletions sympy/core/tests/test_evalf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
integrate, log, Mul, N, oo, pi, Pow, product, Product,
Rational, S, Sum, sin, sqrt, sstr, sympify, Symbol)
from sympy.core.evalf import (complex_accuracy, PrecisionExhausted,
scaled_zero, get_integer_part, as_mpmath)
scaled_zero, as_mpmath)

from sympy.abc import n, x, y

Expand Down Expand Up @@ -254,6 +254,21 @@ def test_evalf_integer_parts():
assert ceiling(x).evalf(subs={x: 3*I}) == 3*I
assert ceiling(x).evalf(subs={x: 2 + 3*I}) == 2 + 3*I

# issue sympy/sympy#10323
l = 1206577996382235787095214
y = ceiling(sqrt(l))
assert y == 1098443442506
assert y**2 >= l

def check(x):
c, f = ceiling(sqrt(x)), floor(sqrt(x))
assert (c - 1)**2 < x and c**2 >= x
assert (f + 1)**2 > x and f**2 <= x

check(2**30 + 1)
check(2**100 + 1)
check(2**112 + 2)


def test_evalf_trig_zero_detection():
a = sin(160*pi, evaluate=False)
Expand Down Expand Up @@ -471,11 +486,6 @@ def test_issue_8853():
assert ceiling(p - S.Half).is_even
assert ceiling(p + S.Half).is_even is False

assert get_integer_part(S.Half, -1, {}, True) == (0, 0)
assert get_integer_part(S.Half, 1, {}, True) == (1, 0)
assert get_integer_part(-S.Half, -1, {}, True) == (-1, 0)
assert get_integer_part(-S.Half, 1, {}, True) == (0, 0)


def test_issue_9326():
from sympy import Dummy
Expand Down
14 changes: 11 additions & 3 deletions sympy/functions/elementary/integers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sympy.core.singleton import S
from sympy.core.function import Function
from sympy.core import Add
from sympy.core.evalf import get_integer_part, PrecisionExhausted
from sympy.core.evalf import PrecisionExhausted
from sympy.core.numbers import Integer
from sympy.core.relational import Gt, Lt, Ge, Le
from sympy.core.symbol import Symbol
Expand Down Expand Up @@ -53,8 +53,16 @@ def eval(cls, arg):
npart.is_extended_real and (spart.is_imaginary or (S.ImaginaryUnit*spart).is_extended_real) or
npart.is_imaginary and spart.is_extended_real):
try:
r, i = get_integer_part(
npart, cls._dir, {}, return_ints=True)
from sympy.core.evalf import DEFAULT_MAXPREC as target
prec = 10
while True:
r, i = cls(npart, evaluate=False).evalf(prec).as_real_imag()
if 2**prec > max(abs(int(r)), abs(int(i))) + 10:
break
else:
if prec >= target:
raise PrecisionExhausted
prec += 10
ipart += Integer(r) + Integer(i)*S.ImaginaryUnit
npart = S.Zero
except (PrecisionExhausted, NotImplementedError):
Expand Down
3 changes: 2 additions & 1 deletion sympy/utilities/lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
"Shi": "shi",
"Chi": "chi",
"Si": "si",
"Ci": "ci"
"Ci": "ci",
"ceiling": "ceil",
}

NUMPY_TRANSLATIONS = {
Expand Down

0 comments on commit cc427e3

Please sign in to comment.