From e02e8c097e39002cd5156674614613e7252fba90 Mon Sep 17 00:00:00 2001 From: Sergey B Kirpichev Date: Sat, 19 Mar 2022 06:20:57 +0300 Subject: [PATCH 1/3] core: drop unused class attribute for WildFunction --- diofant/core/function.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/diofant/core/function.py b/diofant/core/function.py index ee0d2448ec8..62bb91721d7 100644 --- a/diofant/core/function.py +++ b/diofant/core/function.py @@ -29,7 +29,6 @@ import collections import inspect -import typing import mpmath import mpmath.libmp as mlib @@ -714,8 +713,6 @@ class WildFunction(Function, AtomicExpr): """ - include: set[typing.Any] = set() - def __init__(self, name, **assumptions): from ..sets.sets import FiniteSet, Set self.name = name From 0aef791fce9a65eb6e5353fc90e9ef99ba5f738d Mon Sep 17 00:00:00 2001 From: Sergey B Kirpichev Date: Fri, 18 Mar 2022 15:30:44 +0300 Subject: [PATCH 2/3] functions: set default value for Piecewise to nan (if not provided) --- diofant/functions/elementary/piecewise.py | 21 +++++++++----------- diofant/printing/ccode.py | 8 -------- diofant/printing/fcode.py | 8 -------- diofant/printing/latex.py | 9 +-------- diofant/printing/octave.py | 8 -------- diofant/tests/functions/test_piecewise.py | 23 +++++++++++----------- diofant/tests/printing/test_ccode.py | 3 --- diofant/tests/printing/test_fcode.py | 3 --- diofant/tests/printing/test_lambdarepr.py | 8 ++++---- diofant/tests/printing/test_latex.py | 4 ++-- diofant/tests/printing/test_octave.py | 3 --- diofant/tests/solvers/test_inequalities.py | 2 +- diofant/tests/test_wester.py | 4 ++-- 13 files changed, 30 insertions(+), 74 deletions(-) diff --git a/diofant/functions/elementary/piecewise.py b/diofant/functions/elementary/piecewise.py index 15abbbd0585..834586ccce1 100644 --- a/diofant/functions/elementary/piecewise.py +++ b/diofant/functions/elementary/piecewise.py @@ -1,5 +1,5 @@ from ...core import (Basic, Dummy, Equality, Expr, Function, Integer, Tuple, - diff, oo) + diff, nan, oo) from ...core.relational import Relational from ...logic import And, Not, Or, false, to_cnf, true from ...logic.boolalg import Boolean @@ -81,6 +81,9 @@ def __new__(cls, *args, **options): if cond == true: break + if cond != true: + newargs.append(ExprCondPair(nan, true)) + if options.pop('evaluate', True): r = cls.eval(*newargs) else: @@ -282,7 +285,7 @@ def _sort_expr_cond(self, sym, a, b, targetcond=None): independent_expr_cond = [] if isinstance(targetcond, Relational) and targetcond.has(sym): targetcond = solve_univariate_inequality(targetcond, sym) - for expr, cond in self.args: + for expr, cond in self.args: # pragma: no branch if isinstance(cond, Relational) and cond.has(sym): cond = solve_univariate_inequality(cond, sym) if isinstance(cond, Or): @@ -292,7 +295,7 @@ def _sort_expr_cond(self, sym, a, b, targetcond=None): expr_cond.append((expr, cond)) if cond == true: break - for expr, cond in expr_cond: + for expr, cond in expr_cond: # pragma: no branch if cond == true: independent_expr_cond.append((expr, cond)) default = self.func(*independent_expr_cond) @@ -404,10 +407,6 @@ def _sort_expr_cond(self, sym, a, b, targetcond=None): int_expr.extend(holes) if targetcond == true: return [(h[0], h[1], None) for h in holes] - elif holes and default is None: - raise ValueError('Called interval evaluation over piecewise ' # noqa: SFS101 - 'function on undefined intervals %s' % - ', '.join([str((h[0], h[1])) for h in holes])) return int_expr @@ -422,7 +421,7 @@ def _eval_power(self, other): def _eval_subs(self, old, new): """Piecewise conditions may contain bool which are not of Basic type.""" args = list(self.args) - for i, (e, c) in enumerate(args): + for i, (e, c) in enumerate(args): # pragma: no branch c = c._subs(old, new) if c != false: e = e._subs(old, new) @@ -430,8 +429,6 @@ def _eval_subs(self, old, new): if c == true: return self.func(*args) - return self.func(*args) - def _eval_transpose(self): return self.func(*[(e.transpose(), c) for e, c in self.args]) @@ -517,9 +514,9 @@ def piecewise_fold(expr): Examples ======== - >>> p = Piecewise((x, x < 1), (1, x >= 1)) + >>> p = Piecewise((x, x < 1), (1, True)) >>> piecewise_fold(x*p) - Piecewise((x**2, x < 1), (x, x >= 1)) + Piecewise((x**2, x < 1), (x, true)) See Also ======== diff --git a/diofant/printing/ccode.py b/diofant/printing/ccode.py index c485c183979..4e84e4fa994 100644 --- a/diofant/printing/ccode.py +++ b/diofant/printing/ccode.py @@ -181,14 +181,6 @@ def _print_NegativeInfinity(self, expr): return '-HUGE_VAL' def _print_Piecewise(self, expr): - if expr.args[-1].cond != true: - # We need the last conditional to be a True, otherwise the resulting - # function may not return a result. - raise ValueError('All Piecewise expressions must contain an ' - '(expr, True) statement to be used as a default ' - 'condition. Without one, the generated ' - 'expression may not evaluate to anything under ' - 'some condition.') lines = [] if expr.has(Assignment): for i, (e, c) in enumerate(expr.args): diff --git a/diofant/printing/fcode.py b/diofant/printing/fcode.py index 9e8824f7ccf..66c02550c75 100644 --- a/diofant/printing/fcode.py +++ b/diofant/printing/fcode.py @@ -130,14 +130,6 @@ def _get_loop_opening_ending(self, indices): return open_lines, close_lines def _print_Piecewise(self, expr): - if expr.args[-1].cond != true: - # We need the last conditional to be a True, otherwise the resulting - # function may not return a result. - raise ValueError('All Piecewise expressions must contain an ' - '(expr, True) statement to be used as a default ' - 'condition. Without one, the generated ' - 'expression may not evaluate to anything under ' - 'some condition.') lines = [] if expr.has(Assignment): for i, (e, c) in enumerate(expr.args): diff --git a/diofant/printing/latex.py b/diofant/printing/latex.py index da880e40ec1..02358fd467f 100644 --- a/diofant/printing/latex.py +++ b/diofant/printing/latex.py @@ -16,7 +16,6 @@ from ..core.function import _coeff_isneg from ..core.operations import AssocOp from ..core.relational import Relational -from ..logic import true from ..utilities import default_sort_key, has_variety from .conventions import requires_partial, split_super_sub from .precedence import PRECEDENCE, precedence @@ -1236,13 +1235,7 @@ def _print_Relational(self, expr): def _print_Piecewise(self, expr): ecpairs = [r'%s & \text{for}\: %s' % (self._print(e), self._print(c)) for e, c in expr.args[:-1]] - if expr.args[-1].cond == true: - ecpairs.append(r'%s & \text{otherwise}' % - self._print(expr.args[-1].expr)) - else: - ecpairs.append(r'%s & \text{for}\: %s' % - (self._print(expr.args[-1].expr), - self._print(expr.args[-1].cond))) + ecpairs.append(r'%s & \text{otherwise}' % self._print(expr.args[-1].expr)) tex = r'\begin{cases} %s \end{cases}' return tex % r' \\'.join(ecpairs) diff --git a/diofant/printing/octave.py b/diofant/printing/octave.py index 94bae07c546..45e67db4073 100644 --- a/diofant/printing/octave.py +++ b/diofant/printing/octave.py @@ -379,14 +379,6 @@ def _print_zeta(self, expr): return self._print_not_supported(expr) def _print_Piecewise(self, expr): - if expr.args[-1].cond != true: - # We need the last conditional to be a True, otherwise the resulting - # function may not return a result. - raise ValueError('All Piecewise expressions must contain an ' - '(expr, True) statement to be used as a default ' - 'condition. Without one, the generated ' - 'expression may not evaluate to anything under ' - 'some condition.') lines = [] if self._settings['inline']: # Express each (cond, expr) pair in a nested Horner form: diff --git a/diofant/tests/functions/test_piecewise.py b/diofant/tests/functions/test_piecewise.py index c4c2bbd1db5..f73ef702d71 100644 --- a/diofant/tests/functions/test_piecewise.py +++ b/diofant/tests/functions/test_piecewise.py @@ -3,8 +3,8 @@ from diofant import (And, Basic, Eq, Function, Gt, I, Integral, Interval, Max, Min, Not, O, Or, Piecewise, Rational, Symbol, adjoint, conjugate, cos, diff, exp, expand, integrate, lambdify, - log, oo, pi, piecewise_fold, sin, solve, symbols, sympify, - transpose) + log, nan, oo, pi, piecewise_fold, sin, solve, symbols, + sympify, transpose) from diofant.abc import a, t, x, y @@ -14,7 +14,6 @@ def test_piecewise(): - # Test canonization assert Piecewise((x, x < 1), (0, True)) == Piecewise((x, x < 1), (0, True)) assert Piecewise((x, x < 1), (0, True), (1, True)) == \ @@ -34,8 +33,8 @@ def test_piecewise(): assert Piecewise((0, Eq(z, 0, evaluate=False)), (1, True)) == 1 # Test subs - p = Piecewise((-1, x < -1), (x**2, x < 0), (log(x), x >= 0)) - p_x2 = Piecewise((-1, x**2 < -1), (x**4, x**2 < 0), (log(x**2), x**2 >= 0)) + p = Piecewise((-1, x < -1), (x**2, x < 0), (log(x), True)) + p_x2 = Piecewise((-1, x**2 < -1), (x**4, x**2 < 0), (log(x**2), True)) assert p.subs({x: x**2}) == p_x2 assert p.subs({x: -5}) == -1 assert p.subs({x: -1}) == 1 @@ -81,7 +80,7 @@ def test_piecewise(): # Test differentiation f = x fp = x*p - dp = Piecewise((0, x < -1), (2*x, x < 0), (1/x, x >= 0)) + dp = Piecewise((0, x < -1), (2*x, x < 0), (1/x, True)) fp_dx = x*dp + p assert diff(p, x) == dp assert diff(f*p, x) == fp_dx @@ -94,7 +93,7 @@ def test_piecewise(): assert p - dp == -(dp - p) # Test power - dp2 = Piecewise((0, x < -1), (4*x**2, x < 0), (1/x**2, x >= 0)) + dp2 = Piecewise((0, x < -1), (4*x**2, x < 0), (1/x**2, True)) assert dp**2 == dp2 # Test _eval_interval @@ -119,7 +118,7 @@ def test_piecewise(): assert peval4._eval_interval(x, -1, 1) == 2 # Test integration - p_int = Piecewise((-x, x < -1), (x**3/3.0, x < 0), (-x + x*log(x), x >= 0)) + p_int = Piecewise((-x, x < -1), (x**3/3.0, x < 0), (-x + x*log(x), True)) assert integrate(p, x) == p_int p = Piecewise((x, x < 1), (x**2, -1 <= x), (x, 3 < x)) assert integrate(p, (x, -2, 2)) == 5.0/6.0 @@ -127,7 +126,7 @@ def test_piecewise(): p = Piecewise((0, x < 0), (1, x < 1), (0, x < 2), (1, x < 3), (0, True)) assert integrate(p, (x, -oo, oo)) == 2 p = Piecewise((x, x < -10), (x**2, x <= -1), (x, 1 < x)) - pytest.raises(ValueError, lambda: integrate(p, (x, -2, 2))) + assert integrate(p, (x, -2, 2)) is nan # Test commutativity assert p.is_commutative is True @@ -446,8 +445,8 @@ def test_piecewise_as_leading_term(): def test_piecewise_complex(): - p1 = Piecewise((2, x < 0), (1, 0 <= x)) - p2 = Piecewise((2*I, x < 0), (I, 0 <= x)) + p1 = Piecewise((2, x < 0), (1, True)) + p2 = Piecewise((2*I, x < 0), (I, True)) p3 = Piecewise((I*x, x > 1), (1 + I, True)) p4 = Piecewise((-I*conjugate(x), x > 1), (1 - I, True)) @@ -487,7 +486,7 @@ def test_piecewise_evaluate(): def test_as_expr_set_pairs(): - assert Piecewise((x, x > 0), (-x, x <= 0)).as_expr_set_pairs() == \ + assert Piecewise((x, x > 0), (-x, True)).as_expr_set_pairs() == \ [(x, Interval(0, oo, True)), (-x, Interval(-oo, 0))] assert Piecewise(((x - 2)**2, x >= 0), (0, True)).as_expr_set_pairs() == \ diff --git a/diofant/tests/printing/test_ccode.py b/diofant/tests/printing/test_ccode.py index f0753b87280..311a5cd4080 100644 --- a/diofant/tests/printing/test_ccode.py +++ b/diofant/tests/printing/test_ccode.py @@ -181,9 +181,6 @@ def test_ccode_Piecewise(): 'else {\n' ' c = pow(x, 2);\n' '}') - # Check that Piecewise without a True (default) condition error - expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) - pytest.raises(ValueError, lambda: ccode(expr)) def test_ccode_Piecewise_deep(): diff --git a/diofant/tests/printing/test_fcode.py b/diofant/tests/printing/test_fcode.py index 483b5872aea..8ef0231c73f 100644 --- a/diofant/tests/printing/test_fcode.py +++ b/diofant/tests/printing/test_fcode.py @@ -379,9 +379,6 @@ def test_fcode_Piecewise(): code = fcode(Piecewise((x, x < 1), (x**2, x > 1), (sin(x), True)), standard=95) expected = ' merge(x, merge(x**2, sin(x), x > 1), x < 1)' assert code == expected - # Check that Piecewise without a True (default) condition error - expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) - pytest.raises(ValueError, lambda: fcode(expr)) assert (fcode(Piecewise((0, x < -1), (1, And(x >= -1, x < 0)), (-1, True)), assign_to='var') == diff --git a/diofant/tests/printing/test_lambdarepr.py b/diofant/tests/printing/test_lambdarepr.py index 47de10eb416..5184d0174fa 100644 --- a/diofant/tests/printing/test_lambdarepr.py +++ b/diofant/tests/printing/test_lambdarepr.py @@ -45,7 +45,7 @@ def test_piecewise(): p = Piecewise((x, x < 0)) l = lambdarepr(p) eval(h + l) # pylint: disable=eval-used - assert l == '((x) if (x < 0) else None)' + assert l == '((x) if (x < 0) else (((nan) if (True) else None)))' p = Piecewise( (1, x < 1), @@ -63,7 +63,7 @@ def test_piecewise(): ) l = lambdarepr(p) eval(h + l) # pylint: disable=eval-used - assert l == '((1) if (x < 1) else (((2) if (x < 2) else None)))' + assert l == '((1) if (x < 1) else (((2) if (x < 2) else (((nan) if (True) else None)))))' p = Piecewise( (x, x < 1), @@ -93,8 +93,8 @@ def test_piecewise(): ) l = lambdarepr(p) eval(h + l) # pylint: disable=eval-used - assert l == '((x**2) if (x < 0) else (((x) if (((x >= 0) and ' \ - '(x < 1))) else (((-x + 2) if (x >= 1) else None)))))' + assert l == ('((x**2) if (x < 0) else (((x) if (((x >= 0) and (x < 1))) ' + 'else (((-x + 2) if (x >= 1) else (((nan) if (True) else None)))))))') p = Piecewise( (1, x >= 1), diff --git a/diofant/tests/printing/test_latex.py b/diofant/tests/printing/test_latex.py index acd940a68f8..fbfb2ee55a7 100644 --- a/diofant/tests/printing/test_latex.py +++ b/diofant/tests/printing/test_latex.py @@ -813,8 +813,8 @@ def test_latex_Piecewise(): assert latex(p, itex=True) == r'\begin{cases} x & \text{for}\: x \lt 1 \\x^{2} &' \ r' \text{otherwise} \end{cases}' p = Piecewise((x, x < 0), (0, x >= 0)) - assert latex(p) == r'\begin{cases} x & \text{for}\: x < 0 \\0 &' \ - r' \text{for}\: x \geq 0 \end{cases}' + assert latex(p) == r'\begin{cases} x & \text{for}\: x < 0 \\0 & '\ + r'\text{for}\: x \geq 0 \\\mathrm{NaN} & \text{otherwise} \end{cases}' A, B = symbols('A B', commutative=False) p = Piecewise((A**2, Eq(A, B)), (A*B, True)) s = r'\begin{cases} A^{2} & \text{for}\: A = B \\A B & \text{otherwise} \end{cases}' diff --git a/diofant/tests/printing/test_octave.py b/diofant/tests/printing/test_octave.py index d68c69f8c8d..78ee520d30c 100644 --- a/diofant/tests/printing/test_octave.py +++ b/diofant/tests/printing/test_octave.py @@ -232,9 +232,6 @@ def test_octave_piecewise(): 'else\n' ' r = x.^5;\n' 'end') - # Check that Piecewise without a True (default) condition error - expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) - pytest.raises(ValueError, lambda: octave_code(expr)) def test_octave_piecewise_times_const(): diff --git a/diofant/tests/solvers/test_inequalities.py b/diofant/tests/solvers/test_inequalities.py index c3a269ae6f3..a816943c81e 100644 --- a/diofant/tests/solvers/test_inequalities.py +++ b/diofant/tests/solvers/test_inequalities.py @@ -253,7 +253,7 @@ def test_reduce_piecewise_inequalities(): # sympy/sympy#10255 assert reduce_inequalities(Piecewise((1, x < 1), (3, True)) > 1) == \ Le(1, x) - assert reduce_inequalities(Piecewise((x**2, x < 0), (2*x, x >= 0)) < 1) == \ + assert reduce_inequalities(Piecewise((x**2, x < 0), (2*x, True)) < 1) == \ And(Lt(-1, x), x < Rational(1, 2)) diff --git a/diofant/tests/test_wester.py b/diofant/tests/test_wester.py index 5e4ba83f9e4..09b7bdb2508 100644 --- a/diofant/tests/test_wester.py +++ b/diofant/tests/test_wester.py @@ -1716,8 +1716,8 @@ def test_U1(): def test_U2(): - f = Lambda(x, Piecewise((-x, x < 0), (x, x >= 0))) - assert diff(f(x), x) == Piecewise((-1, x < 0), (1, x >= 0)) + f = Lambda(x, Piecewise((-x, x < 0), (x, True))) + assert diff(f(x), x) == Piecewise((-1, x < 0), (1, True)) def test_U3(): From d73cbe09b0cb96fa7f4eb975ada158b258040043 Mon Sep 17 00:00:00 2001 From: Sergey B Kirpichev Date: Sat, 19 Mar 2022 08:20:06 +0300 Subject: [PATCH 3/3] core: add a different workaround for Piecewise in is_constant() --- diofant/core/expr.py | 11 +++++++---- diofant/tests/core/test_expr.py | 4 ++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/diofant/core/expr.py b/diofant/core/expr.py index 5871995aa4f..9230a625618 100644 --- a/diofant/core/expr.py +++ b/diofant/core/expr.py @@ -437,6 +437,8 @@ def is_constant(self, *wrt, **flags): True """ + from ..functions import Piecewise + simplify = flags.get('simplify', True) # Except for expressions that contain units, only one of these should @@ -485,7 +487,7 @@ def is_constant(self, *wrt, **flags): if a is None or a is nan: # try random real a = expr._random(None, -1, 0, 1, 0) - except ZeroDivisionError: + except (ZeroDivisionError, TypeError): a = None if a is not None and a is not nan: try: @@ -494,7 +496,7 @@ def is_constant(self, *wrt, **flags): if b is nan: # evaluation may succeed when substitution fails b = expr._random(None, 1, 0, 1, 0) - except ZeroDivisionError: + except (ZeroDivisionError, TypeError): b = None if b is not None and b is not nan and b.equals(a) is False: return False @@ -512,7 +514,7 @@ def is_constant(self, *wrt, **flags): deriv = expr.diff(w) if simplify: deriv = deriv.simplify() - if deriv != 0: + if deriv: if not (deriv.is_Number or pure_complex(deriv)): if flags.get('failing_number', False): return failing_number @@ -520,7 +522,8 @@ def is_constant(self, *wrt, **flags): assert deriv.free_symbols return # dead line provided _random returns None in such cases return False - return True + if not expr.has(Piecewise): + return True def equals(self, other, failing_expression=False): """Return True if self == other, False if it doesn't, or None. If diff --git a/diofant/tests/core/test_expr.py b/diofant/tests/core/test_expr.py index 675160807ee..9cb10411a26 100644 --- a/diofant/tests/core/test_expr.py +++ b/diofant/tests/core/test_expr.py @@ -1457,6 +1457,10 @@ def test_is_constant(): assert Integer(2).is_constant() is True + for _ in range(5): + assert Piecewise((x, (x < 0)), (0, True)).is_constant() is not True + assert Piecewise((1, (x < 0)), (0, True)).is_constant() is not True + def test_equals(): assert (-3 - sqrt(5) + (-sqrt(10)/2 - sqrt(2)/2)**2).equals(0)