Skip to content

Commit

Permalink
[ysh semantics] Comparison operators take both Int and Float
Browse files Browse the repository at this point in the history
Consistent with + - * /

Part of #1710.

Rename method
  • Loading branch information
Andy C committed Aug 25, 2023
1 parent 7fb6667 commit 1487623
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 73 deletions.
16 changes: 7 additions & 9 deletions spec/ysh-expr-compare.test.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## oils_failures_allowed: 2
## oils_failures_allowed: 1

#### Exact equality with === and !==
shopt -s oil:all
Expand Down Expand Up @@ -157,17 +157,15 @@ si i true
sf f true
f sf true
---
i f 6.0
si f 6.0
i sf 6.0
i f false
si f false
i sf false
---
f i 2.5
sf i 2.5
f si 2.5
f i true
sf i true
f si true
## END



#### Comparison of Int
shopt -s oil:upgrade

Expand Down
98 changes: 34 additions & 64 deletions ysh/expr_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def EvalExpr(self, node, blame_loc):

# Note: IndexError and KeyError are handled in more specific places

def _ValueToInteger(self, val):
def _ConvertToInt(self, val):
# type: (value_t) -> int
UP_val = val
with tagswitch(val) as case:
Expand All @@ -274,24 +274,6 @@ def _ValueToInteger(self, val):

raise error.InvalidType2(val, 'Expected Int', loc.Missing)

def _ValueToNumber(self, val):
# type: (value_t) -> value_t
"""If val looks like Int or Float, convert it to that type.
Otherwise return it untouched.
"""
UP_val = val
with tagswitch(val) as case:
if case(value_e.Str):
val = cast(value.Str, UP_val)
if match.LooksLikeInteger(val.s):
return value.Int(int(val.s))

if match.LooksLikeFloat(val.s):
return value.Float(float(val.s))

return val

def _EvalConst(self, node):
# type: (expr.Const) -> value_t

Expand Down Expand Up @@ -471,8 +453,8 @@ def _ArithNumeric(self, left, right, op_id):

def _ArithBitwise(self, left, right, op):
# type: (value_t, value_t, Id_t) -> value.Int
left_i = self._ValueToInteger(left)
right_i = self._ValueToInteger(right)
left_i = self._ConvertToInt(left)
right_i = self._ConvertToInt(right)

if op == Id.Arith_Amp:
return value.Int(left_i & right_i)
Expand Down Expand Up @@ -546,22 +528,22 @@ def _EvalBinary(self, node):

# Everything below has 2 integer operands
if op_id == Id.Expr_DSlash: # a // b
left_i = self._ValueToInteger(left)
right_i = self._ValueToInteger(right)
left_i = self._ConvertToInt(left)
right_i = self._ConvertToInt(right)
if right_i == 0:
raise error.Expr('Divide by zero', node.op)
return value.Int(left_i // right_i)

if op_id == Id.Arith_Percent: # a % b
left_i = self._ValueToInteger(left)
right_i = self._ValueToInteger(right)
left_i = self._ConvertToInt(left)
right_i = self._ConvertToInt(right)
if right_i == 0:
raise error.Expr('Divide by zero', node.op)
return value.Int(left_i % right_i)

if op_id == Id.Arith_DStar: # a ** b
left_i = self._ValueToInteger(left)
right_i = self._ValueToInteger(right)
left_i = self._ConvertToInt(left)
right_i = self._ConvertToInt(right)

# Same as sh_expr_eval.py
if right_i < 0:
Expand Down Expand Up @@ -629,50 +611,39 @@ def _EvalRange(self, node):

def _CompareNumeric(self, left, right, op):
# type: (value_t, value_t, Token) -> bool
left = self._ValueToNumber(left)
right = self._ValueToNumber(right)
UP_left = left
UP_right = right

if left.tag() != right.tag():
raise error.InvalidType3(
left, right, 'Comparison expected the same type', op)
c, i1, i2, f1, f2 = self._ConvertForBinaryOp(left, right)

op_id = op.id
with tagswitch(left) as case:
if case(value_e.Int):
left = cast(value.Int, UP_left)
right = cast(value.Int, UP_right)
if op_id == Id.Arith_Less:
return left.i < right.i
elif op_id == Id.Arith_Great:
return left.i > right.i
elif op_id == Id.Arith_LessEqual:
return left.i <= right.i
elif op_id == Id.Arith_GreatEqual:
return left.i >= right.i
if c == coerced_e.Int:
with switch(op_id) as case:
if case(Id.Arith_Less):
return i1 < i2
elif case(Id.Arith_Great):
return i1 > i2
elif case(Id.Arith_LessEqual):
return i1 <= i2
elif case(Id.Arith_GreatEqual):
return i1 >= i2
else:
raise AssertionError()

elif case(value_e.Float):
left = cast(value.Float, UP_left)
right = cast(value.Float, UP_right)
if op_id == Id.Arith_Less:
return left.f < right.f
elif op_id == Id.Arith_Great:
return left.f > right.f
elif op_id == Id.Arith_LessEqual:
return left.f <= right.f
elif op_id == Id.Arith_GreatEqual:
return left.f >= right.f
elif c == coerced_e.Float:
with switch(op_id) as case:
if case(Id.Arith_Less):
return f1 < f2
elif case(Id.Arith_Great):
return f1 > f2
elif case(Id.Arith_LessEqual):
return f1 <= f2
elif case(Id.Arith_GreatEqual):
return f1 >= f2
else:
raise AssertionError()

else:
raise error.InvalidType2(
left, 'Comparison expected Int or Float', op)

raise AssertionError() # silence C++ compiler
else:
raise error.InvalidType(
'Comparison operator expected numbers, got %s and %s' %
(ui.ValType(left), ui.ValType(right)), op)

def _EvalCompare(self, node):
# type: (expr.Compare) -> value_t
Expand Down Expand Up @@ -1414,5 +1385,4 @@ def EvalRegex(self, node):
print()
return new_node


# vim: sw=4

0 comments on commit 1487623

Please sign in to comment.