Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Expression support comparison operators. #4962

Merged
merged 3 commits into from Dec 14, 2018
Merged

Make Expression support comparison operators. #4962

merged 3 commits into from Dec 14, 2018

Conversation

tongda
Copy link
Contributor

@tongda tongda commented Dec 13, 2018

By implement __gt__, __lt__, __ge__ and __le__, Expressions can support comparison operators. Code like below is valid now:

>>> hl.eval(hl.literal("abc") > "a")
True

@tpoterba tpoterba self-assigned this Dec 13, 2018
@@ -339,16 +339,16 @@ def describe(self, handler=print):
handler(s)

def __lt__(self, other):
raise NotImplementedError("'<' comparison with expression of type {}".format(str(self._type)))
return self._bin_op("<", other, hl.tbool)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to check types -- it's possible to generate invalid computations which the backend won't like:

hl.literal('abc') < hl.struct(foo=1)

The solution is to copy and modify this bit from __eq__:

        other = to_expr(other)
        left, right, success = unify_exprs(self, other)
        if not success:
            raise TypeError(f"Invalid '==' comparison, cannot compare expressions "
                            f"of type '{self.dtype}' and '{other.dtype}'")
        return left._bin_op("==", right, tbool)

not only does this check types, it also unifies types where possible, e.g. tuple<bool, float32> and tuple<int32, int32> become comparable by unifying both to a tuple<int32, float32>.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I will submit the change. Thanks!

Copy link
Contributor

@tpoterba tpoterba left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one more round of changes

@@ -339,16 +339,36 @@ def describe(self, handler=print):
handler(s)

def __lt__(self, other):
raise NotImplementedError("'<' comparison with expression of type {}".format(str(self._type)))
other = to_expr(other)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this stuff is all the same -- can we abstract to:

def _comparison_op(self, op, other):
	other = to_expr(other)
    left, right, success = unify_exprs(self, other)
    if not success:
        raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions "
                        f"of type '{self.dtype}' and '{other.dtype}'")
    return left._bin_op(op, right, hl.tbool)

Then we can use this in these new methods, and in in __eq__ and __ne__.

It'll just look like:

    def __lt__(self, other):
        return self._comparison_op('!=', other)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yeah, totally agree.

x49=kt.f < [1.0, 2.0],
x50=kt.f > [1.0, 3.0],
x51=[1.0, 2.0, 3.0] <= kt.f,
x52=[1.0, 2.0, 3.0] >= kt.f,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add another example:

x53=hl.tuple([True, 1.0]) < (1, 0)

(should be False)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, great case. Because it triggered some tricky thing. Exception thrown.

    def _compare_op(self, op, other):
        other = to_expr(other)
        left, right, success = unify_exprs(self, other)
        if not success:
>           raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions "
                            f"of type '{self.dtype}' and '{other.dtype}'")
E           TypeError: Invalid '<' comparison, cannot compare expressions of type 'tuple(bool, float64)' and 'tuple(int32, int32)'

Since bool < int32 < float64, I expected the coerced type should be tuple(int32, float64). The TupleCoercer seems not work as expected.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! Boolean -> int32 conversion is something we support explicitly for arithmetic, but not everywhere. I'd like to think through more cases before we open that up.

Let's make the test case

x53=hl.tuple([1, 1.0]) < (1.0, 0)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I can hotfix the can_coerce check with some ugly code:

    def can_coerce(self, t: HailType):
        import pdb; pdb.set_trace()
        if self.elements is None:
            return isinstance(t, ttuple)
        else:
            return (isinstance(t, ttuple)
                    and len(t.types) == len(self.elements)
                    and all(c.can_coerce(t_) or coercer_from_dtype(t_).can_coerce(hl.dtype(c.str_t)) for c, t_ in zip(self.elements, t.types)))

but i have not found a way to fix the _coerce method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah wait, hmm...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bool is totally in the coercer system. let me look.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is in unify_exprs:

    for t in types:
        c = expressions.coercer_from_dtype(t)
        if all(c.can_coerce(e.dtype) for e in exprs):
            return tuple([c.coerce(e) for e in exprs]) + (True,)

This is looking for one type that can coerce the rest. The correct thing is to recursively walk the structure of the types (tuple, struct, etc) and determine the coercer for each element. This is trickier and shouldn't be a part of this PR.

Instead, let's just change the check to:

x53=hl.tuple([1, 1.0]) < (1.0, 0.0)

and save the unify_exprs fix for another PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. np.

@danking danking merged commit 542d6bd into hail-is:master Dec 14, 2018
@tpoterba
Copy link
Contributor

Thanks for the contribution!

@tongda
Copy link
Contributor Author

tongda commented Dec 14, 2018

Cool! Thank you for your patience. Hopefully I can contribute more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants