Skip to content

Commit

Permalink
Improve rich comparison with unknown types
Browse files Browse the repository at this point in the history
Return NotImplemented when a rich comparison is unable to make sense of
the input types. This way, the python interpreter is able to delegate to
reflected methods as appropriate in order to maintain symmetric
relations.

This commit also removes direct invocations of rich comparison
methods, like __eq__, except in super calls, explicitly reflected calls,
and other cases where the order is essential (see, e.g.,
sympy#7951). These calls may sidestep the desired dispatching and
introduce bugs by evaluating, e.g., not NotImplemented to False.

Closes sympy#13078.
  • Loading branch information
danielwe committed Aug 5, 2017
1 parent d132081 commit 4edfd3c
Show file tree
Hide file tree
Showing 24 changed files with 203 additions and 80 deletions.
6 changes: 3 additions & 3 deletions sympy/core/basic.py
Expand Up @@ -313,7 +313,7 @@ def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # sympy != other
return NotImplemented

if type(self) != type(other):
return False
Expand All @@ -329,7 +329,7 @@ def __ne__(self, other):
but faster
"""
return not self.__eq__(other)
return not self == other

def dummy_eq(self, other, symbol=None):
"""
Expand Down Expand Up @@ -1180,7 +1180,7 @@ def _has(self, pattern):

def _has_matcher(self):
"""Helper for .has()"""
return self.__eq__
return lambda other: self == other

def replace(self, query, value, map=False, simultaneous=True, exact=False):
"""
Expand Down
8 changes: 4 additions & 4 deletions sympy/core/expr.py
Expand Up @@ -248,7 +248,7 @@ def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
return NotImplemented
for me in (self, other):
if (me.is_complex and me.is_real is False) or \
me.has(S.ComplexInfinity):
Expand All @@ -270,7 +270,7 @@ def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
return NotImplemented
for me in (self, other):
if (me.is_complex and me.is_real is False) or \
me.has(S.ComplexInfinity):
Expand All @@ -292,7 +292,7 @@ def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
return NotImplemented
for me in (self, other):
if (me.is_complex and me.is_real is False) or \
me.has(S.ComplexInfinity):
Expand All @@ -314,7 +314,7 @@ def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
return NotImplemented
for me in (self, other):
if (me.is_complex and me.is_real is False) or \
me.has(S.ComplexInfinity):
Expand Down
4 changes: 2 additions & 2 deletions sympy/core/exprtools.py
Expand Up @@ -797,7 +797,7 @@ def __eq__(self, other): # Factors
return self.factors == other.factors

def __ne__(self, other): # Factors
return not self.__eq__(other)
return not self == other


class Term(object):
Expand Down Expand Up @@ -909,7 +909,7 @@ def __eq__(self, other): # Term
self.denom == other.denom)

def __ne__(self, other): # Term
return not self.__eq__(other)
return not self == other


def _gcd_terms(terms, isprimitive=False, fraction=True):
Expand Down
70 changes: 35 additions & 35 deletions sympy/core/numbers.py
Expand Up @@ -697,30 +697,30 @@ def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
return NotImplemented
raise NotImplementedError('%s needs .__lt__() method' %
(self.__class__.__name__))

def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
return NotImplemented
raise NotImplementedError('%s needs .__le__() method' %
(self.__class__.__name__))

def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
return NotImplemented
return _sympify(other).__lt__(self)

def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
return NotImplemented
return _sympify(other).__le__(self)

def __hash__(self):
Expand Down Expand Up @@ -1258,7 +1258,7 @@ def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # sympy != other --> not ==
return NotImplemented
if isinstance(other, NumberSymbol):
if other.is_irrational:
return False
Expand All @@ -1276,13 +1276,13 @@ def __eq__(self, other):
return False # Float != non-Number

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
return NotImplemented
if isinstance(other, NumberSymbol):
return other.__le__(self)
if other.is_comparable:
Expand All @@ -1296,7 +1296,7 @@ def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
return NotImplemented
if isinstance(other, NumberSymbol):
return other.__lt__(self)
if other.is_comparable:
Expand All @@ -1310,7 +1310,7 @@ def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
return NotImplemented
if isinstance(other, NumberSymbol):
return other.__ge__(self)
if other.is_real and other.is_number:
Expand All @@ -1324,7 +1324,7 @@ def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
return NotImplemented
if isinstance(other, NumberSymbol):
return other.__gt__(self)
if other.is_real and other.is_number:
Expand Down Expand Up @@ -1719,7 +1719,7 @@ def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # sympy != other --> not ==
return NotImplemented
if isinstance(other, NumberSymbol):
if other.is_irrational:
return False
Expand All @@ -1734,13 +1734,13 @@ def __eq__(self, other):
return False

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
return NotImplemented
if isinstance(other, NumberSymbol):
return other.__le__(self)
expr = self
Expand All @@ -1758,7 +1758,7 @@ def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
return NotImplemented
if isinstance(other, NumberSymbol):
return other.__lt__(self)
expr = self
Expand All @@ -1776,7 +1776,7 @@ def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
return NotImplemented
if isinstance(other, NumberSymbol):
return other.__ge__(self)
expr = self
Expand All @@ -1794,7 +1794,7 @@ def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
return NotImplemented
expr = self
if isinstance(other, NumberSymbol):
return other.__gt__(self)
Expand Down Expand Up @@ -2112,13 +2112,13 @@ def __eq__(self, other):
return Rational.__eq__(self, other)

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
return NotImplemented
if isinstance(other, Integer):
return _sympify(self.p > other.p)
return Rational.__gt__(self, other)
Expand All @@ -2127,7 +2127,7 @@ def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
return NotImplemented
if isinstance(other, Integer):
return _sympify(self.p < other.p)
return Rational.__lt__(self, other)
Expand All @@ -2136,7 +2136,7 @@ def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
return NotImplemented
if isinstance(other, Integer):
return _sympify(self.p >= other.p)
return Rational.__ge__(self, other)
Expand All @@ -2145,7 +2145,7 @@ def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
return NotImplemented
if isinstance(other, Integer):
return _sympify(self.p <= other.p)
return Rational.__le__(self, other)
Expand Down Expand Up @@ -2840,7 +2840,7 @@ def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
return NotImplemented
if other.is_real:
return S.false
return Expr.__lt__(self, other)
Expand All @@ -2849,7 +2849,7 @@ def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
return NotImplemented
if other.is_real:
if other.is_finite or other is S.NegativeInfinity:
return S.false
Expand All @@ -2863,7 +2863,7 @@ def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
return NotImplemented
if other.is_real:
if other.is_finite or other is S.NegativeInfinity:
return S.true
Expand All @@ -2877,7 +2877,7 @@ def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
return NotImplemented
if other.is_real:
return S.true
return Expr.__ge__(self, other)
Expand Down Expand Up @@ -3061,7 +3061,7 @@ def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
return NotImplemented
if other.is_real:
if other.is_finite or other is S.Infinity:
return S.true
Expand All @@ -3075,7 +3075,7 @@ def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
return NotImplemented
if other.is_real:
return S.true
return Expr.__le__(self, other)
Expand All @@ -3084,7 +3084,7 @@ def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
return NotImplemented
if other.is_real:
return S.false
return Expr.__gt__(self, other)
Expand All @@ -3093,7 +3093,7 @@ def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
return NotImplemented
if other.is_real:
if other.is_finite or other is S.Infinity:
return S.false
Expand Down Expand Up @@ -3339,7 +3339,7 @@ def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # sympy != other --> not ==
return NotImplemented
if self is other:
return True
if isinstance(other, Number) and self.is_irrational:
Expand All @@ -3348,13 +3348,13 @@ def __eq__(self, other):
return False # NumberSymbol != non-(Number|self)

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
return NotImplemented
if self is other:
return S.false
if isinstance(other, Number):
Expand All @@ -3375,7 +3375,7 @@ def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
return NotImplemented
if self is other:
return S.true
if other.is_real and other.is_number:
Expand All @@ -3388,7 +3388,7 @@ def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
return NotImplemented
r = _sympify((-self) < (-other))
if r in (S.true, S.false):
return r
Expand All @@ -3399,7 +3399,7 @@ def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
return NotImplemented
r = _sympify((-self) <= (-other))
if r in (S.true, S.false):
return r
Expand Down

0 comments on commit 4edfd3c

Please sign in to comment.