Skip to content

Commit

Permalink
Fix comparisons that don't return Booleanable objects (+ add 3.7 to t…
Browse files Browse the repository at this point in the history
…ravis)

Should fix #48.
  • Loading branch information
danthedeckie committed Oct 15, 2018
1 parent 355c60b commit 9176a78
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 10 deletions.
19 changes: 9 additions & 10 deletions simpleeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
- charlax (Charles-Axel Dein charlax) Makefile and cleanups
- mommothazaz123 (Andrew Zhu) f"string" support
- lubieowoce (Uryga) various potential vulnerabilities
- JCavallo (Jean Cavallo) names dict shouldn't be modified
- JCavallo (Jean Cavallo) names dict shouldn't be modified
-------------------------------------
Expand Down Expand Up @@ -339,14 +339,15 @@ def _eval_boolop(self, node):
return False

def _eval_compare(self, node):
left = self._eval(node.left)
right = self._eval(node.left)
to_return = True
for operation, comp in zip(node.ops, node.comparators):
if not to_return:
break
left = right
right = self._eval(comp)
if self.operators[type(operation)](left, right):
left = right # Hi Dr. Seuss...
else:
return False
return True
to_return = self.operators[type(operation)](left, right)
return to_return

def _eval_ifexp(self, node):
return self._eval(node.body) if self._eval(node.test) \
Expand Down Expand Up @@ -488,7 +489,6 @@ def eval(self, expr):
self._max_count = 0
return super(EvalWithCompoundTypes, self).eval(expr)


def _eval_dict(self, node):
return {self._eval(k): self._eval(v)
for (k, v) in zip(node.keys, node.values)}
Expand Down Expand Up @@ -539,7 +539,7 @@ def do_generator(gi=0):
raise IterableTooLong('Comprehension generates too many elements')
recurse_targets(g.target, i)
if all(self._eval(iff) for iff in g.ifs):
if len(node.generators) > gi + 1 :
if len(node.generators) > gi + 1:
do_generator(gi+1)
else:
to_return.append(self._eval(node.elt))
Expand All @@ -551,7 +551,6 @@ def do_generator(gi=0):
return to_return



def simple_eval(expr, operators=None, functions=None, names=None):
""" Simply evaluate an expresssion """
s = SimpleEval(operators=operators,
Expand Down
63 changes: 63 additions & 0 deletions test_simpleeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,5 +898,68 @@ def test_attributedoesnotexist(self):
assert hasattr(e, 'expression')
assert getattr(e, 'expression') == 'foo in bar'

class TestUnusualComparisons(DRYTest):
def test_custom_comparison_returner(self):
class Blah(object):
def __gt__(self, other):
return self

b = Blah()
self.s.names = {'b': b}
self.t('b > 2', b)

def test_custom_comparison_doesnt_return_boolable(self):
"""
SqlAlchemy, bless it's cotton socks, returns BinaryExpression objects
when asking for comparisons between things. These BinaryExpressions
raise a TypeError if you try and check for Truthyiness.
"""
class BinaryExpression(object):
def __init__(self, value):
self.value = value
def __eq__(self, other):
return self.value == getattr(other, 'value', other)
def __repr__(self):
return '<BinaryExpression:{}>'.format(self.value)
def __bool__(self):
# This is the only important part, to match SqlAlchemy - the rest
# of the methods are just to make testing a bit easier...
raise TypeError("Boolean value of this clause is not defined")

class Blah(object):
def __gt__(self, other):
return BinaryExpression('GT')
def __lt__(self, other):
return BinaryExpression('LT')

b = Blah()
self.s.names = {'b': b}
e = eval('b > 2', self.s.names)

self.t('b > 2', BinaryExpression('GT'))
self.t('1 < 5 > b', BinaryExpression('LT'))

class TestShortCircuiting(DRYTest):
def test_shortcircuit_if(self):
x = []
self.s.functions = {'foo':lambda y:x.append(y)}
self.t('foo(1) if foo(2) else foo(3)', None)
self.assertListEqual(x, [2, 3])

x = []
self.t('42 if True else foo(99)', 42)
self.assertListEqual(x, [])

def test_shortcircuit_comparison(self):
x = []
self.s.functions = {'foo': lambda y:x.append(y)}
with self.assertRaises(TypeError):
self.t('foo(11) < 12', False)
self.assertListEqual(x, [11])
x = []

self.t('1 > 2 < foo(22)', False)
self.assertListEqual(x, [])

if __name__ == '__main__': # pragma: no cover
unittest.main()

0 comments on commit 9176a78

Please sign in to comment.