Skip to content

Commit

Permalink
Drop multivariate Order notion
Browse files Browse the repository at this point in the history
Closes #1228
  • Loading branch information
skirpichev committed Mar 2, 2023
1 parent 9289ae0 commit ca68bb1
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 200 deletions.
81 changes: 26 additions & 55 deletions diofant/calculus/order.py
Expand Up @@ -2,7 +2,6 @@
expand_log, expand_power_base, nan, oo, sympify)
from ..core.compatibility import is_sequence
from ..utilities import default_sort_key
from ..utilities.iterables import uniq


class Order(Expr):
Expand Down Expand Up @@ -72,13 +71,6 @@ class Order(Expr):
>>> O(cos(x), (x, pi/2))
O(x - pi/2, (x, pi/2))
>>> O(1 + x*y)
O(1, x, y)
>>> O(1 + x*y, (x, 0), (y, 0))
O(1, x, y)
>>> O(1 + x*y, (x, oo), (y, oo))
O(x*y, (x, oo), (y, oo))
References
==========
Expand Down Expand Up @@ -111,12 +103,12 @@ def __new__(cls, expr, *args, **kwargs):
variables = list(map(sympify, args))
point = [Integer(0)]*len(variables)

if len(variables) > 1:
raise TypeError

if not all(isinstance(v, (Dummy, Symbol)) for v in variables):
raise TypeError(f'Variables are not symbols, got {variables}')

if len(list(uniq(variables))) != len(variables):
raise ValueError(f'Variables are supposed to be unique symbols, got {variables}')

if expr.is_Order:
expr_vp = dict(expr.args[1:])
new_vp = dict(expr_vp)
Expand Down Expand Up @@ -163,14 +155,6 @@ def __new__(cls, expr, *args, **kwargs):
else:
args = tuple(variables)

if len(variables) > 1:
# XXX: better way? We need this expand() to
# workaround e.g: expr = x*(x + y).
# (x*(x + y)).as_leading_term(x, y) currently returns
# x*y (wrong order term!). That's why we want to deal with
# expand()'ed expr (handled in "if expr.is_Add" branch below).
expr = expr.expand()

if expr.is_Add:
lst = expr.extract_leading_order(args)
expr = Add(*[f.expr for (e, f) in lst])
Expand All @@ -182,42 +166,38 @@ def __new__(cls, expr, *args, **kwargs):
expr = expand_power_base(expr)
expr = expand_log(expr)

if len(args) == 1:
# The definition of O(f(x)) symbol explicitly stated that
# the argument of f(x) is irrelevant. That's why we can
# combine some power exponents (only "on top" of the
# expression tree for f(x)), e.g.:
# x**p * (-x)**q -> x**(p+q) for real p, q.
x = args[0]
margs = list(Mul.make_args(
expr.as_independent(x, as_Add=False)[1]))

for i, t in enumerate(margs):
if t.is_Pow:
b, q = t.base, t.exp
if b in (x, -x) and q.is_extended_real and not q.has(x):
margs[i] = x**q
elif b.is_Pow and not b.exp.has(x):
# The definition of O(f(x)) symbol explicitly stated that
# the argument of f(x) is irrelevant. That's why we can
# combine some power exponents (only "on top" of the
# expression tree for f(x)), e.g.:
# x**p * (-x)**q -> x**(p+q) for real p, q.
x = args[0]
margs = list(Mul.make_args(
expr.as_independent(x, as_Add=False)[1]))

for i, t in enumerate(margs):
if t.is_Pow:
b, q = t.base, t.exp
if b in (x, -x) and q.is_extended_real and not q.has(x):
margs[i] = x**q
elif b.is_Pow and not b.exp.has(x):
b, r = b.base, b.exp
if b in (x, -x) and r.is_extended_real:
margs[i] = x**(r*q)
elif b.is_Mul and b.args[0] == -1:
b = -b
if b.is_Pow and not b.exp.has(x):
b, r = b.base, b.exp
if b in (x, -x) and r.is_extended_real:
margs[i] = x**(r*q)
elif b.is_Mul and b.args[0] == -1:
b = -b
if b.is_Pow and not b.exp.has(x):
b, r = b.base, b.exp
if b in (x, -x) and r.is_extended_real:
margs[i] = x**(r*q)

expr = Mul(*margs)
expr = Mul(*margs)

expr = expr.subs(rs)

if expr == 0:
return expr

if expr.is_Order:
expr = expr.expr

if not expr.has(*variables):
expr = Integer(1)

Expand Down Expand Up @@ -321,9 +301,6 @@ def contains(self, expr):
return all(x in self.args[1:] for x in expr.args[1:])
if expr.expr.is_Add:
return all(self.contains(x) for x in expr.expr.args)
if self.expr.is_Add and point == 0:
return any(self.func(x, *self.args[1:]).contains(expr)
for x in self.expr.args)
if self.variables and expr.variables:
common_symbols = tuple(s for s in self.variables if s in expr.variables)
elif self.variables:
Expand All @@ -332,7 +309,6 @@ def contains(self, expr):
common_symbols = expr.variables
if not common_symbols:
return
r = None
ratio = self.expr/expr.expr
ratio = powsimp(ratio, deep=True, combine='exp')
for s in common_symbols:
Expand All @@ -341,12 +317,7 @@ def contains(self, expr):
l = l != 0
else:
l = None
if r is None:
r = l
else:
if r != l:
return
return r
return l
obj = self.func(expr, *self.args[1:])
return self.contains(obj)

Expand Down
11 changes: 2 additions & 9 deletions diofant/core/expr.py
Expand Up @@ -2702,7 +2702,7 @@ def compute_leading_term(self, x, logx=None):
return t.as_leading_term(x)

@cacheit
def as_leading_term(self, *symbols):
def as_leading_term(self, x):
"""Returns the leading (nonzero) term of the series expansion of self.
The _eval_as_leading_term routines are used to do this, and they must
Expand All @@ -2718,14 +2718,7 @@ def as_leading_term(self, *symbols):
"""
from ..simplify import powsimp
if len(symbols) > 1:
c = self
for x in symbols:
c = c.as_leading_term(x)
return c
if not symbols:
return self
x = sympify(symbols[0])
x = sympify(x)
if not x.is_Symbol:
raise ValueError(f'expecting a Symbol but got {x}')
if x not in self.free_symbols:
Expand Down
13 changes: 3 additions & 10 deletions diofant/printing/latex.py
Expand Up @@ -1105,18 +1105,11 @@ def _print_Rational(self, expr):

def _print_Order(self, expr):
s = self._print(expr.expr)
if expr.point and any(p != 0 for p in expr.point) or \
len(expr.variables) > 1:
if expr.point and any(p != 0 for p in expr.point):
s += '; '
if len(expr.variables) > 1:
s += self._print(expr.variables)
else:
s += self._print(expr.variables[0])
s += self._print(expr.variables[0])
s += r'\rightarrow{}'
if len(expr.point) > 1:
s += self._print(expr.point)
else:
s += self._print(expr.point[0])
s += self._print(expr.point[0])
return r'\mathcal{O}\left(%s\right)' % s

def _print_Symbol(self, expr):
Expand Down
13 changes: 3 additions & 10 deletions diofant/printing/pretty/pretty.py
Expand Up @@ -935,21 +935,14 @@ def _print_Lambda(self, e):

def _print_Order(self, expr):
pform = self._print(expr.expr)
if ((expr.point and any(p != 0 for p in expr.point)) or
len(expr.variables) > 1):
if expr.point and any(p != 0 for p in expr.point):
pform = prettyForm(*pform.right('; '))
if len(expr.variables) > 1:
pform = prettyForm(*pform.right(self._print(expr.variables)))
else:
pform = prettyForm(*pform.right(self._print(expr.variables[0])))
pform = prettyForm(*pform.right(self._print(expr.variables[0])))
if self._use_unicode:
pform = prettyForm(*pform.right(' \N{RIGHTWARDS ARROW} '))
else:
pform = prettyForm(*pform.right(' -> '))
if len(expr.point) > 1:
pform = prettyForm(*pform.right(self._print(expr.point)))
else:
pform = prettyForm(*pform.right(self._print(expr.point[0])))
pform = prettyForm(*pform.right(self._print(expr.point[0])))
pform = prettyForm(*pform.parens())
pform = prettyForm(*pform.left('O'))
return pform
Expand Down
4 changes: 1 addition & 3 deletions diofant/printing/str.py
Expand Up @@ -273,9 +273,7 @@ def _print_NegativeInfinity(self, expr):

def _print_Order(self, expr):
if all(p == 0 for p in expr.point) or not expr.variables:
if len(expr.variables) <= 1:
return f'O({self._print(expr.expr)})'
return f"O({self.stringify((expr.expr,) + expr.variables, ', ', 0)})"
return f'O({self._print(expr.expr)})'
return f"O({self.stringify(expr.args, ', ', 0)})"

def _print_Cycle(self, expr):
Expand Down
74 changes: 2 additions & 72 deletions diofant/tests/calculus/test_order.py
Expand Up @@ -2,8 +2,7 @@

from diofant import (Add, Derivative, Function, I, Integer, Integral, O,
Rational, Symbol, conjugate, cos, digamma, exp, expand,
factorial, ln, log, nan, oo, pi, sin, sqrt, symbols,
transpose)
factorial, ln, log, nan, oo, pi, sin, sqrt, transpose)
from diofant.abc import w, x, y, z


Expand All @@ -22,8 +21,6 @@ def test_free_symbols():
assert O(1).free_symbols == set()
assert O(x).free_symbols == {x}
assert O(1, x).free_symbols == {x}
assert O(x*y).free_symbols == {x, y}
assert O(x, x, y).free_symbols == {x, y}


def test_simple_1():
Expand All @@ -32,7 +29,6 @@ def test_simple_1():
assert O(x)*3 == O(x)
assert -28*O(x) == O(x)
assert O(O(x)) == O(x)
assert O(O(x), y) == O(O(x), x, y)
assert O(-23) == O(1)
assert O(exp(x)) == O(1, x)
assert O(exp(1/x)).expr == exp(1/x)
Expand All @@ -41,7 +37,6 @@ def test_simple_1():
assert O(x**(5*o/3)).expr == x**(5*o/3)
assert O(x**2 + x + y, x) == O(1, x)
assert O(x**2 + x + y, y) == O(1, y)
pytest.raises(ValueError, lambda: O(exp(x), x, x))
pytest.raises(TypeError, lambda: O(x, 2 - x))
pytest.raises(ValueError, lambda: O(x, (x, x**2)))

Expand Down Expand Up @@ -131,13 +126,6 @@ def test_contains():
assert not O(exp(1/x)).contains(O(exp(2/x)))

assert O(x).contains(O(y)) is None
assert O(x).contains(O(y*x))
assert O(y*x).contains(O(x))
assert O(y).contains(O(x*y))
assert O(x).contains(O(y**2*x))

assert O(x*y**2).contains(O(x**2*y)) is None
assert O(x**2*y).contains(O(x*y**2)) is None

assert O(sin(1/x**2)).contains(O(cos(1/x**2))) is None
assert O(cos(1/x**2)).contains(O(sin(1/x**2))) is None
Expand All @@ -149,6 +137,7 @@ def test_contains():
assert O(1, x) not in O(1)
assert O(1) in O(1, x)
pytest.raises(TypeError, lambda: O(x*y**2) in O(x**2*y))
pytest.raises(TypeError, lambda: O(x**y, x) in O(x**z, x))


def test_add_1():
Expand All @@ -164,50 +153,6 @@ def test_add_1():
def test_ln_args():
assert O(log(x)) + O(log(2*x)) == O(log(x))
assert O(log(x)) + O(log(x**3)) == O(log(x))
assert O(log(x*y)) + O(log(x) + log(y)) == O(log(x*y))


def test_multivar_0():
assert O(x*y).expr == x*y
assert O(x*y**2).expr == x*y**2
assert O(x*y, x).expr == x
assert O(x*y**2, y).expr == y**2
assert O(x*y*z).expr == x*y*z
assert O(x/y).expr == x/y
assert O(x*exp(1/y)).expr == x*exp(1/y)
assert O(exp(x)*exp(1/y)).expr == exp(1/y)


def test_multivar_0a():
assert O(exp(1/x)*exp(1/y)).expr == exp(1/x + 1/y)


def test_multivar_1():
assert O(x + y).expr == x + y
assert O(x + 2*y).expr == x + y
assert (O(x + y) + x).expr == (x + y)
assert (O(x + y) + x**2) == O(x + y)
assert (O(x + y) + 1/x) == 1/x + O(x + y)
assert O(x**2 + y*x).expr == x**2 + y*x


def test_multivar_2():
assert O(x**2*y + y**2*x, x, y).expr == x**2*y + y**2*x


def test_multivar_mul_1():
assert O(x + y)*x == O(x**2 + y*x, x, y)


def test_multivar_3():
assert (O(x) + O(y)).args in [
(O(x), O(y)),
(O(y), O(x))]
assert O(x) + O(y) + O(x + y) == O(x + y)
assert (O(x**2*y) + O(y**2*x)).args in [
(O(x*y**2), O(y*x**2)),
(O(y*x**2), O(x*y**2))]
assert (O(x**2*y) + O(y*x)) == O(x*y)


def test_sympyissue_3468():
Expand Down Expand Up @@ -246,7 +191,6 @@ def test_order_leadterm():

def test_order_symbols():
e = x*y*sin(x)*Integral(x, (x, 1, 2))
assert O(e) == O(x**2*y)
assert O(e, x) == O(x**2)


Expand Down Expand Up @@ -302,16 +246,6 @@ def test_eval():
assert (O(1)**x).is_Pow


def test_sympyissue_4279():
a, b = symbols('a b')
assert O(a, a, b) + O(1, a, b) == O(1, a, b)
assert O(b, a, b) + O(1, a, b) == O(1, a, b)
assert O(a + b) + O(1, a, b) == O(1, a, b)
assert O(1, a, b) + O(a, a, b) == O(1, a, b)
assert O(1, a, b) + O(b, a, b) == O(1, a, b)
assert O(1, a, b) + O(a + b) == O(1, a, b)


def test_sympyissue_4855():
assert 1/O(1) != O(1)
assert 1/O(x) != O(1/x)
Expand Down Expand Up @@ -359,8 +293,6 @@ def test_order_at_infinity():
assert O(3*x, (x, oo)) == O(x, (x, oo))
assert O(x, (x, oo))*3 == O(x, (x, oo))
assert -28*O(x, (x, oo)) == O(x, (x, oo))
assert O(O(x, (x, oo)), (x, oo)) == O(x, (x, oo))
assert O(O(x, (x, oo)), (y, oo)) == O(x, (x, oo), (y, oo))
assert O(3, (x, oo)) == O(1, (x, oo))
assert O(x**2 + x + y, (x, oo)) == O(x**2, (x, oo))
assert O(x**2 + x + y, (y, oo)) == O(y, (y, oo))
Expand Down Expand Up @@ -432,8 +364,6 @@ def test_order_subs_limits():
assert O(x**2).subs({x: y - 1}) == O((y - 1)**2, (y, 1))
assert O(10*x**2, (x, 2)).subs({x: y - 1}) == O(1, (y, 3))

assert O(x).subs({x: y*z}) == O(y*z, y, z)

assert O(1/x, (x, oo)).subs({x: +I*x}) == O(1/x, (x, -I*oo))
assert O(1/x, (x, oo)).subs({x: -I*x}) == O(1/x, (x, +I*oo))

Expand Down
2 changes: 1 addition & 1 deletion diofant/tests/core/test_args.py
Expand Up @@ -1432,7 +1432,7 @@ def test_diofant__calculus__limits__Limit():


def test_diofant__calculus__order__Order():
assert _test_args(Order(1, x, y))
assert _test_args(Order(1, x))


def test_diofant__simplify__hyperexpand__Hyper_Function():
Expand Down

0 comments on commit ca68bb1

Please sign in to comment.