Skip to content

Commit

Permalink
core: add a different workaround for Piecewise in is_constant()
Browse files Browse the repository at this point in the history
  • Loading branch information
skirpichev committed Mar 20, 2022
1 parent 0aef791 commit d73cbe0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
11 changes: 7 additions & 4 deletions diofant/core/expr.py
Expand Up @@ -437,6 +437,8 @@ def is_constant(self, *wrt, **flags):
True
"""
from ..functions import Piecewise

simplify = flags.get('simplify', True)

# Except for expressions that contain units, only one of these should
Expand Down Expand Up @@ -485,7 +487,7 @@ def is_constant(self, *wrt, **flags):
if a is None or a is nan:
# try random real
a = expr._random(None, -1, 0, 1, 0)
except ZeroDivisionError:
except (ZeroDivisionError, TypeError):
a = None
if a is not None and a is not nan:
try:
Expand All @@ -494,7 +496,7 @@ def is_constant(self, *wrt, **flags):
if b is nan:
# evaluation may succeed when substitution fails
b = expr._random(None, 1, 0, 1, 0)
except ZeroDivisionError:
except (ZeroDivisionError, TypeError):
b = None
if b is not None and b is not nan and b.equals(a) is False:
return False
Expand All @@ -512,15 +514,16 @@ def is_constant(self, *wrt, **flags):
deriv = expr.diff(w)
if simplify:
deriv = deriv.simplify()
if deriv != 0:
if deriv:
if not (deriv.is_Number or pure_complex(deriv)):
if flags.get('failing_number', False):
return failing_number
else:
assert deriv.free_symbols
return # dead line provided _random returns None in such cases
return False
return True
if not expr.has(Piecewise):
return True

def equals(self, other, failing_expression=False):
"""Return True if self == other, False if it doesn't, or None. If
Expand Down
4 changes: 4 additions & 0 deletions diofant/tests/core/test_expr.py
Expand Up @@ -1457,6 +1457,10 @@ def test_is_constant():

assert Integer(2).is_constant() is True

for _ in range(5):
assert Piecewise((x, (x < 0)), (0, True)).is_constant() is not True
assert Piecewise((1, (x < 0)), (0, True)).is_constant() is not True


def test_equals():
assert (-3 - sqrt(5) + (-sqrt(10)/2 - sqrt(2)/2)**2).equals(0)
Expand Down

0 comments on commit d73cbe0

Please sign in to comment.