Skip to content

Commit

Permalink
functions: add wrt kwarg for rewrite() helpers, use for Abs
Browse files Browse the repository at this point in the history
  • Loading branch information
skirpichev committed Apr 26, 2022
1 parent 54b2f52 commit d320669
Show file tree
Hide file tree
Showing 11 changed files with 35 additions and 37 deletions.
2 changes: 1 addition & 1 deletion diofant/core/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ def _eval_rewrite(self, pattern, rule, **hints):

if pattern is None or isinstance(self, pattern):
if hasattr(self, rule):
rewritten = getattr(self, rule)(*args)
rewritten = getattr(self, rule)(*args, **hints)
if rewritten is not None:
return rewritten
return self.func(*args)
Expand Down
8 changes: 4 additions & 4 deletions diofant/functions/combinatorial/factorials.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _eval_rewrite_as_gamma(self, n):
from .. import gamma
return gamma(n + 1)

def _eval_rewrite_as_tractable(self, n):
def _eval_rewrite_as_tractable(self, n, **kwargs):
from .. import exp, loggamma
return exp(loggamma(n + 1))

Expand Down Expand Up @@ -350,7 +350,7 @@ def _eval_rewrite_as_gamma(self, x, k):
from .. import gamma
return gamma(x + k) / gamma(x)

def _eval_rewrite_as_tractable(self, x, k):
def _eval_rewrite_as_tractable(self, x, k, **kwargs):
return self._eval_rewrite_as_gamma(x, k).rewrite('tractable')

def _eval_is_integer(self):
Expand Down Expand Up @@ -554,14 +554,14 @@ def _eval_expand_func(self, **hints):

return self.func(*self.args)

def _eval_rewrite_as_factorial(self, n, k):
def _eval_rewrite_as_factorial(self, n, k, **kwargs):
return factorial(n)/(factorial(k)*factorial(n - k))

def _eval_rewrite_as_gamma(self, n, k):
from .. import gamma
return gamma(n + 1)/(gamma(k + 1)*gamma(n - k + 1))

def _eval_rewrite_as_tractable(self, n, k):
def _eval_rewrite_as_tractable(self, n, k, **kwargs):
return self._eval_rewrite_as_gamma(n, k).rewrite('tractable')

def _eval_is_integer(self):
Expand Down
5 changes: 2 additions & 3 deletions diofant/functions/combinatorial/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,10 @@ def eval(cls, n, sym=None):
'only for positive integer indices.')
return cls._fibpoly(n).subs({_sym: sym})

def _eval_rewrite_as_sqrt(self, n, sym=None):
def _eval_rewrite_as_sqrt(self, n, sym=None, **kwargs):
from .. import sqrt
if sym is None:
return (GoldenRatio**n - cos(pi*n)/GoldenRatio**n)/sqrt(5)

_eval_rewrite_as_tractable = _eval_rewrite_as_sqrt


Expand Down Expand Up @@ -672,7 +671,7 @@ def _eval_expand_func(self, **hints):

return self

def _eval_rewrite_as_tractable(self, n, m=1):
def _eval_rewrite_as_tractable(self, n, m=1, **kwargs):
from .. import polygamma
return self.rewrite(polygamma).rewrite('tractable', deep=True)

Expand Down
4 changes: 4 additions & 0 deletions diofant/functions/elementary/complexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,10 @@ def _eval_rewrite_as_Piecewise(self, arg):
def _eval_rewrite_as_sign(self, arg):
return arg/sign(arg)

def _eval_rewrite_as_tractable(self, arg, wrt=None, **kwargs):
if (s := sign(arg.limit(wrt, oo))) in (1, -1):
return s*arg


class arg(Function):
"""Returns the argument (in radians) of a complex number.
Expand Down
12 changes: 6 additions & 6 deletions diofant/functions/elementary/hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _eval_expand_trig(self, **hints):
return (sinh(x)*cosh(y) + sinh(y)*cosh(x)).expand(trig=True)
return sinh(arg)

def _eval_rewrite_as_exp(self, arg):
def _eval_rewrite_as_exp(self, arg, **kwargs):
return (exp(arg) - exp(-arg)) / 2
_eval_rewrite_as_tractable = _eval_rewrite_as_exp

Expand Down Expand Up @@ -288,7 +288,7 @@ def _eval_expand_trig(self, deep=True, **hints):
return (cosh(x)*cosh(y) + sinh(x)*sinh(y)).expand(trig=True)
return cosh(arg)

def _eval_rewrite_as_exp(self, arg):
def _eval_rewrite_as_exp(self, arg, **kwargs):
return (exp(arg) + exp(-arg)) / 2
_eval_rewrite_as_tractable = _eval_rewrite_as_exp

Expand Down Expand Up @@ -421,7 +421,7 @@ def as_real_imag(self, deep=True, **hints):
denom = sinh(re)**2 + cos(im)**2
return sinh(re)*cosh(re)/denom, sin(im)*cos(im)/denom

def _eval_rewrite_as_exp(self, arg):
def _eval_rewrite_as_exp(self, arg, **kwargs):
neg_exp, pos_exp = exp(-arg), exp(arg)
return (pos_exp - neg_exp)/(pos_exp + neg_exp)
_eval_rewrite_as_tractable = _eval_rewrite_as_exp
Expand Down Expand Up @@ -546,7 +546,7 @@ def as_real_imag(self, deep=True, **hints):
denom = sinh(re)**2 + sin(im)**2
return sinh(re)*cosh(re)/denom, -sin(im)*cos(im)/denom

def _eval_rewrite_as_exp(self, arg):
def _eval_rewrite_as_exp(self, arg, **kwargs):
neg_exp, pos_exp = exp(-arg), exp(arg)
return (pos_exp + neg_exp)/(pos_exp - neg_exp)
_eval_rewrite_as_tractable = _eval_rewrite_as_exp
Expand Down Expand Up @@ -904,7 +904,7 @@ def inverse(self, argindex=1):
"""Returns the inverse of this function."""
return cosh

def _eval_rewrite_as_log(self, x):
def _eval_rewrite_as_log(self, x, **hints):
return log(x + sqrt(x - 1)*sqrt(x + 1))
_eval_rewrite_as_tractable = _eval_rewrite_as_log

Expand Down Expand Up @@ -1054,7 +1054,7 @@ def _eval_as_leading_term(self, x):
else:
return self.func(arg)

def _eval_rewrite_as_log(self, x):
def _eval_rewrite_as_log(self, x, **kwargs):
return (log((x + 1)/x) - log((x - 1)/x))/2
_eval_rewrite_as_tractable = _eval_rewrite_as_log

Expand Down
8 changes: 4 additions & 4 deletions diofant/functions/special/bessel.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ def _eval_rewrite_as_hyper(self, z):
pf2 = z / (root(3, 3)*gamma(Rational(1, 3)))
return pf1 * hyper([], [Rational(2, 3)], z**3/9) - pf2 * hyper([], [Rational(4, 3)], z**3/9)

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return exp(-Rational(2, 3)*z**Rational(3, 2))*sqrt(pi*sqrt(z))/2*_airyais(z)

def _eval_expand_func(self, **hints):
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def _eval_rewrite_as_hyper(self, z):
pf2 = z*root(3, 6) / gamma(Rational(1, 3))
return pf1 * hyper([], [Rational(2, 3)], z**3/9) + pf2 * hyper([], [Rational(4, 3)], z**3/9)

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return exp(Rational(2, 3)*z**Rational(3, 2))*sqrt(pi*sqrt(z))*_airybis(z)

def _eval_expand_func(self, **hints):
Expand Down Expand Up @@ -1065,7 +1065,7 @@ def _eval_expand_func(self, **hints):


class _airyais(Function):
def _eval_rewrite_as_intractable(self, x):
def _eval_rewrite_as_intractable(self, x, **kwargs):
return 2*airyai(x)*exp(Rational(2, 3)*x**Rational(3, 2))/sqrt(pi*sqrt(x))

def _eval_aseries(self, n, args0, x, logx):
Expand Down Expand Up @@ -1094,7 +1094,7 @@ def _eval_nseries(self, x, n, logx):


class _airybis(Function):
def _eval_rewrite_as_intractable(self, x):
def _eval_rewrite_as_intractable(self, x, **kwargs):
return airybi(x)*exp(-Rational(2, 3)*x**Rational(3, 2))/sqrt(pi*sqrt(x))

def _eval_aseries(self, n, args0, x, logx):
Expand Down
14 changes: 7 additions & 7 deletions diofant/functions/special/error_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _eval_rewrite_as_hyper(self, z):
def _eval_rewrite_as_expint(self, z):
return sqrt(z**2)/z - z*expint(Rational(1, 2), z**2)/sqrt(pi)

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return 1 - _erfs(z)*exp(-z**2)

def _eval_rewrite_as_erfc(self, z):
Expand Down Expand Up @@ -339,7 +339,7 @@ def _eval_is_extended_real(self):
elif arg.is_imaginary and arg.is_nonzero:
return False

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return self.rewrite(erf).rewrite('tractable', deep=True)

def _eval_rewrite_as_erf(self, z):
Expand Down Expand Up @@ -517,7 +517,7 @@ def _eval_is_extended_real(self):
elif arg.is_imaginary and arg.is_nonzero:
return False

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return self.rewrite(erf).rewrite('tractable', deep=True)

def _eval_rewrite_as_erf(self, z):
Expand Down Expand Up @@ -1054,7 +1054,7 @@ def _eval_rewrite_as_uppergamma(self, z):
# immediately turns into expint
return -uppergamma(0, polar_lift(-1)*z) - I*pi

def _eval_rewrite_as_expint(self, z):
def _eval_rewrite_as_expint(self, z, **kwargs):
return -expint(1, polar_lift(-1)*z) - I*pi

def _eval_rewrite_as_li(self, z):
Expand All @@ -1072,7 +1072,7 @@ def _eval_rewrite_as_Si(self, z):
_eval_rewrite_as_Chi = _eval_rewrite_as_Si
_eval_rewrite_as_Shi = _eval_rewrite_as_Si

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return exp(z) * _eis(z)

def _eval_nseries(self, x, n, logx):
Expand Down Expand Up @@ -1404,7 +1404,7 @@ def _eval_rewrite_as_meijerg(self, z):
return (-log(-log(z)) - (log(1/log(z)) - log(log(z)))/2
- meijerg(((), (1,)), ((0, 0), ()), -log(z)))

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return z * _eis(log(z))


Expand Down Expand Up @@ -1484,7 +1484,7 @@ def _eval_evalf(self, prec):
def _eval_rewrite_as_li(self, z):
return li(z) - li(2)

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return self.rewrite(li).rewrite('tractable', deep=True)

###############################################################################
Expand Down
6 changes: 3 additions & 3 deletions diofant/functions/special/gamma_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _eval_is_positive(self):
elif x.is_noninteger:
return floor(x).is_even

def _eval_rewrite_as_tractable(self, z):
def _eval_rewrite_as_tractable(self, z, **kwargs):
return exp(loggamma(z))

def _eval_rewrite_as_factorial(self, z):
Expand Down Expand Up @@ -316,7 +316,7 @@ def _eval_conjugate(self):
def _eval_rewrite_as_uppergamma(self, s, x):
return gamma(s) - uppergamma(s, x)

def _eval_rewrite_as_tractable(self, s, x):
def _eval_rewrite_as_tractable(self, s, x, **kwargs):
return self.rewrite(uppergamma)

def _eval_rewrite_as_expint(self, s, x):
Expand Down Expand Up @@ -878,7 +878,7 @@ def _eval_aseries(self, n, args0, x, logx):
# It is very inefficient to first add the order and then do the nseries
return (r + Add(*l))._eval_nseries(x, n, logx) + o

def _eval_rewrite_as_intractable(self, z):
def _eval_rewrite_as_intractable(self, z, **kwargs):
return log(gamma(z))

def _eval_is_extended_real(self):
Expand Down
4 changes: 2 additions & 2 deletions diofant/functions/special/zeta_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,13 @@ def fdiff(self, argindex=1):
else:
raise ArgumentIndexError

def _eval_rewrite_as_tractable(self, s, a=1):
def _eval_rewrite_as_tractable(self, s, a=1, **kwargs):
if len(self.args) == 1:
return _zetas(exp(s))


class _zetas(Function):
def _eval_rewrite_as_intractable(self, s):
def _eval_rewrite_as_intractable(self, s, **kwargs):
return zeta(log(s))

def _eval_aseries(self, n, args0, x, logx):
Expand Down
2 changes: 1 addition & 1 deletion diofant/series/gruntz.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def limitinf(e, x):
assert not e.has(Float)

# Rewrite e in terms of tractable functions only:
e = e.rewrite('tractable', deep=True)
e = e.rewrite('tractable', deep=True, wrt=x)

if not e.has(x):
# This is a bit of a heuristic for nice results. We always rewrite
Expand Down
7 changes: 1 addition & 6 deletions diofant/series/limits.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..core import Dummy, Expr, Float, PoleError, Rational, nan, oo, sympify
from ..core.function import UndefinedFunction
from ..functions import Abs, cos, sign, sin
from ..functions import cos, sign, sin
from ..sets import Reals
from .gruntz import limitinf
from .order import Order
Expand Down Expand Up @@ -173,10 +173,6 @@ def doit(self, **hints):
return has_oo
raise NotImplementedError

def tr_abs(f):
s = sign(limit(f.args[0], z, oo))
return s*f.args[0] if s in (1, -1) else f

def tr_Piecewise(f):
for a, c in f.args:
if not c.is_Atom:
Expand All @@ -188,7 +184,6 @@ def tr_Piecewise(f):
break
return a

e = e.replace(lambda f: isinstance(f, Abs) and f.has(z), tr_abs)
e = e.replace(lambda f: f.is_Piecewise and f.has(z), tr_Piecewise)

try:
Expand Down

0 comments on commit d320669

Please sign in to comment.