Skip to content

Commit

Permalink
functions: set default value for Piecewise to nan if not provided
Browse files Browse the repository at this point in the history
  • Loading branch information
skirpichev committed Mar 19, 2022
1 parent 15fe853 commit e92fc13
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 33 deletions.
9 changes: 6 additions & 3 deletions diofant/functions/elementary/piecewise.py
@@ -1,5 +1,5 @@
from ...core import (Basic, Dummy, Equality, Expr, Function, Integer, Tuple,
diff, oo)
diff, nan, oo)
from ...core.relational import Relational
from ...logic import And, Not, Or, false, to_cnf, true
from ...logic.boolalg import Boolean
Expand Down Expand Up @@ -81,6 +81,9 @@ def __new__(cls, *args, **options):
if cond == true:
break

if cond != true:
newargs.append(ExprCondPair(nan, true))

if options.pop('evaluate', True):
r = cls.eval(*newargs)
else:
Expand Down Expand Up @@ -517,9 +520,9 @@ def piecewise_fold(expr):
Examples
========
>>> p = Piecewise((x, x < 1), (1, x >= 1))
>>> p = Piecewise((x, x < 1), (1, True))
>>> piecewise_fold(x*p)
Piecewise((x**2, x < 1), (x, x >= 1))
Piecewise((x**2, x < 1), (x, true))
See Also
========
Expand Down
23 changes: 11 additions & 12 deletions diofant/tests/functions/test_piecewise.py
Expand Up @@ -3,8 +3,8 @@
from diofant import (And, Basic, Eq, Function, Gt, I, Integral, Interval, Max,
Min, Not, O, Or, Piecewise, Rational, Symbol, adjoint,
conjugate, cos, diff, exp, expand, integrate, lambdify,
log, oo, pi, piecewise_fold, sin, solve, symbols, sympify,
transpose)
log, nan, oo, pi, piecewise_fold, sin, solve, symbols,
sympify, transpose)
from diofant.abc import a, t, x, y


Expand All @@ -14,7 +14,6 @@


def test_piecewise():

# Test canonization
assert Piecewise((x, x < 1), (0, True)) == Piecewise((x, x < 1), (0, True))
assert Piecewise((x, x < 1), (0, True), (1, True)) == \
Expand All @@ -34,8 +33,8 @@ def test_piecewise():
assert Piecewise((0, Eq(z, 0, evaluate=False)), (1, True)) == 1

# Test subs
p = Piecewise((-1, x < -1), (x**2, x < 0), (log(x), x >= 0))
p_x2 = Piecewise((-1, x**2 < -1), (x**4, x**2 < 0), (log(x**2), x**2 >= 0))
p = Piecewise((-1, x < -1), (x**2, x < 0), (log(x), True))
p_x2 = Piecewise((-1, x**2 < -1), (x**4, x**2 < 0), (log(x**2), True))
assert p.subs({x: x**2}) == p_x2
assert p.subs({x: -5}) == -1
assert p.subs({x: -1}) == 1
Expand Down Expand Up @@ -81,7 +80,7 @@ def test_piecewise():
# Test differentiation
f = x
fp = x*p
dp = Piecewise((0, x < -1), (2*x, x < 0), (1/x, x >= 0))
dp = Piecewise((0, x < -1), (2*x, x < 0), (1/x, True))
fp_dx = x*dp + p
assert diff(p, x) == dp
assert diff(f*p, x) == fp_dx
Expand All @@ -94,7 +93,7 @@ def test_piecewise():
assert p - dp == -(dp - p)

# Test power
dp2 = Piecewise((0, x < -1), (4*x**2, x < 0), (1/x**2, x >= 0))
dp2 = Piecewise((0, x < -1), (4*x**2, x < 0), (1/x**2, True))
assert dp**2 == dp2

# Test _eval_interval
Expand All @@ -119,15 +118,15 @@ def test_piecewise():
assert peval4._eval_interval(x, -1, 1) == 2

# Test integration
p_int = Piecewise((-x, x < -1), (x**3/3.0, x < 0), (-x + x*log(x), x >= 0))
p_int = Piecewise((-x, x < -1), (x**3/3.0, x < 0), (-x + x*log(x), True))
assert integrate(p, x) == p_int
p = Piecewise((x, x < 1), (x**2, -1 <= x), (x, 3 < x))
assert integrate(p, (x, -2, 2)) == 5.0/6.0
assert integrate(p, (x, 2, -2)) == -5.0/6.0
p = Piecewise((0, x < 0), (1, x < 1), (0, x < 2), (1, x < 3), (0, True))
assert integrate(p, (x, -oo, oo)) == 2
p = Piecewise((x, x < -10), (x**2, x <= -1), (x, 1 < x))
pytest.raises(ValueError, lambda: integrate(p, (x, -2, 2)))
assert integrate(p, (x, -2, 2)) is nan

# Test commutativity
assert p.is_commutative is True
Expand Down Expand Up @@ -446,8 +445,8 @@ def test_piecewise_as_leading_term():


def test_piecewise_complex():
p1 = Piecewise((2, x < 0), (1, 0 <= x))
p2 = Piecewise((2*I, x < 0), (I, 0 <= x))
p1 = Piecewise((2, x < 0), (1, True))
p2 = Piecewise((2*I, x < 0), (I, True))
p3 = Piecewise((I*x, x > 1), (1 + I, True))
p4 = Piecewise((-I*conjugate(x), x > 1), (1 - I, True))

Expand Down Expand Up @@ -487,7 +486,7 @@ def test_piecewise_evaluate():


def test_as_expr_set_pairs():
assert Piecewise((x, x > 0), (-x, x <= 0)).as_expr_set_pairs() == \
assert Piecewise((x, x > 0), (-x, True)).as_expr_set_pairs() == \
[(x, Interval(0, oo, True)), (-x, Interval(-oo, 0))]

assert Piecewise(((x - 2)**2, x >= 0), (0, True)).as_expr_set_pairs() == \
Expand Down
3 changes: 0 additions & 3 deletions diofant/tests/printing/test_ccode.py
Expand Up @@ -181,9 +181,6 @@ def test_ccode_Piecewise():
'else {\n'
' c = pow(x, 2);\n'
'}')
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
pytest.raises(ValueError, lambda: ccode(expr))


def test_ccode_Piecewise_deep():
Expand Down
3 changes: 0 additions & 3 deletions diofant/tests/printing/test_fcode.py
Expand Up @@ -379,9 +379,6 @@ def test_fcode_Piecewise():
code = fcode(Piecewise((x, x < 1), (x**2, x > 1), (sin(x), True)), standard=95)
expected = ' merge(x, merge(x**2, sin(x), x > 1), x < 1)'
assert code == expected
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
pytest.raises(ValueError, lambda: fcode(expr))

assert (fcode(Piecewise((0, x < -1), (1, And(x >= -1, x < 0)),
(-1, True)), assign_to='var') ==
Expand Down
8 changes: 4 additions & 4 deletions diofant/tests/printing/test_lambdarepr.py
Expand Up @@ -45,7 +45,7 @@ def test_piecewise():
p = Piecewise((x, x < 0))
l = lambdarepr(p)
eval(h + l) # pylint: disable=eval-used
assert l == '((x) if (x < 0) else None)'
assert l == '((x) if (x < 0) else (((nan) if (True) else None)))'

p = Piecewise(
(1, x < 1),
Expand All @@ -63,7 +63,7 @@ def test_piecewise():
)
l = lambdarepr(p)
eval(h + l) # pylint: disable=eval-used
assert l == '((1) if (x < 1) else (((2) if (x < 2) else None)))'
assert l == '((1) if (x < 1) else (((2) if (x < 2) else (((nan) if (True) else None)))))'

p = Piecewise(
(x, x < 1),
Expand Down Expand Up @@ -93,8 +93,8 @@ def test_piecewise():
)
l = lambdarepr(p)
eval(h + l) # pylint: disable=eval-used
assert l == '((x**2) if (x < 0) else (((x) if (((x >= 0) and ' \
'(x < 1))) else (((-x + 2) if (x >= 1) else None)))))'
assert l == ('((x**2) if (x < 0) else (((x) if (((x >= 0) and (x < 1))) '
'else (((-x + 2) if (x >= 1) else (((nan) if (True) else None)))))))')

p = Piecewise(
(1, x >= 1),
Expand Down
4 changes: 2 additions & 2 deletions diofant/tests/printing/test_latex.py
Expand Up @@ -813,8 +813,8 @@ def test_latex_Piecewise():
assert latex(p, itex=True) == r'\begin{cases} x & \text{for}\: x \lt 1 \\x^{2} &' \
r' \text{otherwise} \end{cases}'
p = Piecewise((x, x < 0), (0, x >= 0))
assert latex(p) == r'\begin{cases} x & \text{for}\: x < 0 \\0 &' \
r' \text{for}\: x \geq 0 \end{cases}'
assert latex(p) == r'\begin{cases} x & \text{for}\: x < 0 \\0 & '\
r'\text{for}\: x \geq 0 \\\mathrm{NaN} & \text{otherwise} \end{cases}'
A, B = symbols('A B', commutative=False)
p = Piecewise((A**2, Eq(A, B)), (A*B, True))
s = r'\begin{cases} A^{2} & \text{for}\: A = B \\A B & \text{otherwise} \end{cases}'
Expand Down
3 changes: 0 additions & 3 deletions diofant/tests/printing/test_octave.py
Expand Up @@ -232,9 +232,6 @@ def test_octave_piecewise():
'else\n'
' r = x.^5;\n'
'end')
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
pytest.raises(ValueError, lambda: octave_code(expr))


def test_octave_piecewise_times_const():
Expand Down
2 changes: 1 addition & 1 deletion diofant/tests/solvers/test_inequalities.py
Expand Up @@ -253,7 +253,7 @@ def test_reduce_piecewise_inequalities():
# sympy/sympy#10255
assert reduce_inequalities(Piecewise((1, x < 1), (3, True)) > 1) == \
Le(1, x)
assert reduce_inequalities(Piecewise((x**2, x < 0), (2*x, x >= 0)) < 1) == \
assert reduce_inequalities(Piecewise((x**2, x < 0), (2*x, True)) < 1) == \
And(Lt(-1, x), x < Rational(1, 2))


Expand Down
4 changes: 2 additions & 2 deletions diofant/tests/test_wester.py
Expand Up @@ -1716,8 +1716,8 @@ def test_U1():


def test_U2():
f = Lambda(x, Piecewise((-x, x < 0), (x, x >= 0)))
assert diff(f(x), x) == Piecewise((-1, x < 0), (1, x >= 0))
f = Lambda(x, Piecewise((-x, x < 0), (x, True)))
assert diff(f(x), x) == Piecewise((-1, x < 0), (1, True))


def test_U3():
Expand Down

0 comments on commit e92fc13

Please sign in to comment.