Make Expression support comparison operators. #4962
danking merged 3 commits into hail-is:master from tongda:master
 def __lt__(self, other):
-    raise NotImplementedError("'<' comparison with expression of type {}".format(str(self._type)))
+    return self._bin_op("<", other, hl.tbool)
We need to check types -- it's possible to generate invalid computations which the backend won't like:
hl.literal('abc') < hl.struct(foo=1)
The solution is to copy and modify this bit from __eq__:
    other = to_expr(other)
    left, right, success = unify_exprs(self, other)
    if not success:
        raise TypeError(f"Invalid '==' comparison, cannot compare expressions "
                        f"of type '{self.dtype}' and '{other.dtype}'")
    return left._bin_op("==", right, tbool)
Not only does this check types, it also unifies types where possible, e.g. tuple<bool, float32> and tuple<int32, int32> become comparable by unifying both to a tuple<int32, float32>.
I see. I will submit the change. Thanks!
tpoterba
left a comment
one more round of changes
 def __lt__(self, other):
-    raise NotImplementedError("'<' comparison with expression of type {}".format(str(self._type)))
+    other = to_expr(other)
Now this stuff is all the same -- can we abstract to:
    def _comparison_op(self, op, other):
        other = to_expr(other)
        left, right, success = unify_exprs(self, other)
        if not success:
            raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions "
                            f"of type '{self.dtype}' and '{other.dtype}'")
        return left._bin_op(op, right, hl.tbool)
Then we can use this in these new methods, and in __eq__ and __ne__.
It'll just look like:
    def __lt__(self, other):
        return self._comparison_op('<', other)
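The shared-helper idea above can be tried out without Hail. The following is only a minimal sketch under stated assumptions: `ToyExpr`, `to_toy_expr`, `unify_dtypes` and the `_RANK` widening order are hypothetical stand-ins invented for illustration, not Hail's `Expression`, `to_expr` or `unify_exprs`. It shows one helper doing the conversion, the type check and the error message, with each operator method reduced to a one-line delegation.

```python
# Hypothetical widening order for this toy model; not Hail's real coercion rules.
_RANK = {"int32": 0, "int64": 1, "float32": 2, "float64": 3}


def unify_dtypes(a, b):
    """Return a common type name for a and b, or None if there is none."""
    if a == b:
        return a
    if a in _RANK and b in _RANK:
        return a if _RANK[a] >= _RANK[b] else b
    return None


class ToyExpr:
    """Stand-in for an expression: a type name plus a textual form."""

    def __init__(self, dtype, text):
        self.dtype = dtype
        self.text = text

    def _comparison_op(self, op, other):
        # Convert, type-check, then build a bool-typed result expression.
        other = to_toy_expr(other)
        if unify_dtypes(self.dtype, other.dtype) is None:
            raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions "
                            f"of type '{self.dtype}' and '{other.dtype}'")
        return ToyExpr("bool", f"({self.text} {op} {other.text})")

    def __lt__(self, other):
        return self._comparison_op("<", other)

    def __le__(self, other):
        return self._comparison_op("<=", other)

    def __gt__(self, other):
        return self._comparison_op(">", other)

    def __ge__(self, other):
        return self._comparison_op(">=", other)


def to_toy_expr(value):
    """Wrap a plain Python value; bool is checked before int on purpose."""
    if isinstance(value, ToyExpr):
        return value
    for py_type, name in ((bool, "bool"), (int, "int32"), (float, "float64"), (str, "str")):
        if isinstance(value, py_type):
            return ToyExpr(name, repr(value))
    raise TypeError(f"cannot convert {value!r} to a ToyExpr")


x = ToyExpr("int32", "x")
result = x < 1.5  # builds a bool-typed expression instead of comparing eagerly
```

Because every operator goes through the same helper, the operator symbol in the error message always matches the method that was called, which is harder to guarantee when the body is copied four times.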
Ah, yeah, totally agree.
x49=kt.f < [1.0, 2.0],
x50=kt.f > [1.0, 3.0],
x51=[1.0, 2.0, 3.0] <= kt.f,
x52=[1.0, 2.0, 3.0] >= kt.f,
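The x51 and x52 cases put a plain Python list on the left-hand side. In standard Python this still reaches the expression's methods: `list.__le__` returns `NotImplemented` for a non-list operand, so the interpreter falls back to the reflected method on the right-hand operand (`__ge__` for `<=`, `__le__` for `>=`). A small sketch of that dispatch, using a hypothetical `Recorder` class rather than a Hail expression:

```python
class Recorder:
    """Hypothetical stand-in that reports which comparison method ran."""

    def __init__(self, name):
        self.name = name

    def __lt__(self, other):
        return f"{self.name}.__lt__({other!r})"

    def __le__(self, other):
        return f"{self.name}.__le__({other!r})"

    def __gt__(self, other):
        return f"{self.name}.__gt__({other!r})"

    def __ge__(self, other):
        return f"{self.name}.__ge__({other!r})"


f = Recorder("f")

# Expression on the left: the method matching the operator runs directly.
direct = f < [1.0, 2.0]

# List on the left: list.__le__ returns NotImplemented for a Recorder,
# so Python calls the reflected method f.__ge__ with the list as argument.
reflected = [1.0, 2.0, 3.0] <= f
```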
can we add another example:
x53=hl.tuple([True, 1.0]) < (1, 0)
(should be False)
Great test case: it exposed a tricky issue. An exception is thrown:
    def _compare_op(self, op, other):
        other = to_expr(other)
        left, right, success = unify_exprs(self, other)
        if not success:
>           raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions "
                            f"of type '{self.dtype}' and '{other.dtype}'")
E       TypeError: Invalid '<' comparison, cannot compare expressions of type 'tuple(bool, float64)' and 'tuple(int32, int32)'
Since bool < int32 < float64, I expected the coerced type to be tuple(int32, float64). The TupleCoercer does not seem to work as expected.
Ah! Boolean -> int32 conversion is something we support explicitly for arithmetic, but not everywhere. I'd like to think through more cases before we open that up.
Let's make the test case
x53=hl.tuple([1, 1.0]) < (1.0, 0)
I think I can hotfix the can_coerce check with some ugly code:
    def can_coerce(self, t: HailType):
        if self.elements is None:
            return isinstance(t, ttuple)
        else:
            return (isinstance(t, ttuple)
                    and len(t.types) == len(self.elements)
                    and all(c.can_coerce(t_) or coercer_from_dtype(t_).can_coerce(hl.dtype(c.str_t))
                            for c, t_ in zip(self.elements, t.types)))
But I have not found a way to fix the _coerce method.
bool is totally in the coercer system. Let me look.
The problem is in unify_exprs:
    for t in types:
        c = expressions.coercer_from_dtype(t)
        if all(c.can_coerce(e.dtype) for e in exprs):
            return tuple([c.coerce(e) for e in exprs]) + (True,)
This is looking for one type that can coerce the rest. The correct thing is to recursively walk the structure of the types (tuple, struct, etc) and determine the coercer for each element. This is trickier and shouldn't be a part of this PR.
Instead, let's just change the check to:
x53=hl.tuple([1, 1.0]) < (1.0, 0.0)
and save the unify_exprs fix for another PR.
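For readers following along, the recursive approach described above might look roughly like the sketch below. It is not Hail code and not the eventual fix: `unify_types` and the `_RANK` widening order are hypothetical, scalar types are modelled as strings and tuple types as Python tuples. bool is deliberately left out of the widening order, since the thread notes that bool -> int32 conversion is only supported for arithmetic.

```python
# Hypothetical numeric widening order, used only for this illustration.
_RANK = {"int32": 0, "int64": 1, "float32": 2, "float64": 3}


def unify_types(a, b):
    """Return the common type of a and b, or None if they cannot be unified.

    Scalar types are strings; tuple types are Python tuples of types.
    """
    if isinstance(a, tuple) and isinstance(b, tuple):
        if len(a) != len(b):
            return None
        # Recurse element by element instead of looking for one whole
        # type that can coerce the other.
        elems = [unify_types(x, y) for x, y in zip(a, b)]
        return None if any(e is None for e in elems) else tuple(elems)
    if isinstance(a, tuple) or isinstance(b, tuple):
        return None  # a tuple never unifies with a scalar
    if a == b:
        return a
    if a in _RANK and b in _RANK:
        return a if _RANK[a] >= _RANK[b] else b
    return None


# Types of hl.tuple([1, 1.0]) and (1.0, 0) from the thread: neither whole
# type coerces the other, but the element-wise walk finds a common type.
mixed = unify_types(("int32", "float64"), ("float64", "int32"))
```

Under these assumptions `mixed` is `("float64", "float64")`, while the bool case from the failing test stays non-unifiable until the wider question about bool conversion is settled.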
Thanks for the contribution!
Cool! Thank you for your patience. Hopefully I can contribute more.
By implementing __gt__, __lt__, __ge__ and __le__, Expressions can support comparison operators. Code like the following is now valid: