Make Expression support comparison operators. #4962

Merged
merged 3 commits into from Dec 14, 2018
Changes from 2 commits

@@ -339,16 +339,36 @@ def describe(self, handler=print):
handler(s)

     def __lt__(self, other):
-        raise NotImplementedError("'<' comparison with expression of type {}".format(str(self._type)))
+        other = to_expr(other)

@tpoterba

tpoterba Dec 13, 2018

Collaborator

Now this stuff is all the same -- can we abstract to:

def _comparison_op(self, op, other):
    other = to_expr(other)
    left, right, success = unify_exprs(self, other)
    if not success:
        raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions "
                        f"of type '{self.dtype}' and '{other.dtype}'")
    return left._bin_op(op, right, hl.tbool)

Then we can use this in these new methods, and in __eq__ and __ne__.

It'll just look like:

    def __lt__(self, other):
        return self._comparison_op('<', other)
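
A minimal, self-contained sketch of the refactor suggested above. The `Expr` class, the `_OPS` table, and the same-Python-type check are hypothetical stand-ins for Hail's `Expression`, `_bin_op`, and `unify_exprs`; only the shape of the shared helper is taken from the comment.

```python
import operator


class Expr:
    """Toy stand-in for an expression class; not Hail's real Expression."""

    def __init__(self, value):
        self.value = value

    def _comparison_op(self, op, other):
        # Wrap plain Python values so both operands are Expr instances.
        if not isinstance(other, Expr):
            other = Expr(other)
        # Hypothetical check standing in for unify_exprs: require equal Python types.
        if type(self.value) is not type(other.value):
            raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions "
                            f"of type '{type(self.value).__name__}' and "
                            f"'{type(other.value).__name__}'")
        return _OPS[op](self.value, other.value)

    # Each dunder method is a one-line delegation to the shared helper.
    def __lt__(self, other):
        return self._comparison_op('<', other)

    def __le__(self, other):
        return self._comparison_op('<=', other)

    def __gt__(self, other):
        return self._comparison_op('>', other)

    def __ge__(self, other):
        return self._comparison_op('>=', other)


# Maps the operator string to the function that evaluates it.
_OPS = {'<': operator.lt, '<=': operator.le, '>': operator.gt, '>=': operator.ge}
```

Because Python falls back to the reflected method when the left operand is a plain value, `"helln" > Expr("hello")` ends up calling `Expr.__lt__`, so the reversed forms need no extra code in this sketch.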

@tongda

tongda Dec 13, 2018

Author Contributor

Ah, yeah, totally agree.

+        left, right, success = unify_exprs(self, other)
+        if not success:
+            raise TypeError(f"Invalid '<' comparison, cannot compare expressions "
+                            f"of type '{self.dtype}' and '{other.dtype}'")
+        return left._bin_op("<", right, hl.tbool)

     def __le__(self, other):
-        raise NotImplementedError("'<=' comparison with expression of type {}".format(str(self._type)))
+        other = to_expr(other)
+        left, right, success = unify_exprs(self, other)
+        if not success:
+            raise TypeError(f"Invalid '<=' comparison, cannot compare expressions "
+                            f"of type '{self.dtype}' and '{other.dtype}'")
+        return left._bin_op("<=", right, hl.tbool)

     def __gt__(self, other):
-        raise NotImplementedError("'>' comparison with expression of type {}".format(str(self._type)))
+        other = to_expr(other)
+        left, right, success = unify_exprs(self, other)
+        if not success:
+            raise TypeError(f"Invalid '>' comparison, cannot compare expressions "
+                            f"of type '{self.dtype}' and '{other.dtype}'")
+        return left._bin_op(">", right, hl.tbool)

     def __ge__(self, other):
-        raise NotImplementedError("'>=' comparison with expression of type {}".format(str(self._type)))
+        other = to_expr(other)
+        left, right, success = unify_exprs(self, other)
+        if not success:
+            raise TypeError(f"Invalid '>=' comparison, cannot compare expressions "
+                            f"of type '{self.dtype}' and '{other.dtype}'")
+        return left._bin_op(">=", right, hl.tbool)

     def __nonzero__(self):
         raise ExpressionException(
@@ -101,7 +101,23 @@ def test_operators(self):
 x33=(kt.a == 0) & (kt.b == 5),
 x34=(kt.a == 0) | (kt.b == 5),
 x35=False,
-x36=True
+x36=True,
+x37=kt.e > "helln",
+x38=kt.e < "hellp",
+x39=kt.e <= "hello",
+x40=kt.e >= "hello",
+x41="helln" > kt.e,
+x42="hellp" < kt.e,
+x43="hello" >= kt.e,
+x44="hello" <= kt.e,
+x45=kt.f > [1, 2],
+x46=kt.f < [1, 3],
+x47=kt.f >= [1, 2, 3],
+x48=kt.f <= [1, 2, 3],
+x49=kt.f < [1.0, 2.0],
+x50=kt.f > [1.0, 3.0],
+x51=[1.0, 2.0, 3.0] <= kt.f,
+x52=[1.0, 2.0, 3.0] >= kt.f,

@tpoterba

tpoterba Dec 13, 2018

Collaborator

can we add another example:

x53=hl.tuple([True, 1.0]) < (1, 0)

(should be False)
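
The expected value can be sanity-checked in plain Python, under the assumption (not confirmed in this thread) that tuple comparison on expressions follows the same lexicographic rule once the elements are coerced to common types:

```python
# Plain Python tuples compare element by element, left to right.
left = (True, 1.0)
right = (1, 0)

# First elements tie because True == 1 in Python,
# so the second elements decide the result: 1.0 < 0 is False.
first_elements_tie = left[0] == right[0]
result = left < right
```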

@tongda

tongda Dec 13, 2018

Author Contributor

Great case: it triggered something tricky. An exception is thrown:

    def _compare_op(self, op, other):
        other = to_expr(other)
        left, right, success = unify_exprs(self, other)
        if not success:
>           raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions "
                            f"of type '{self.dtype}' and '{other.dtype}'")
E           TypeError: Invalid '<' comparison, cannot compare expressions of type 'tuple(bool, float64)' and 'tuple(int32, int32)'

Since bool < int32 < float64, I expected the coerced type to be tuple(int32, float64). The TupleCoercer does not seem to work as expected.

@tpoterba

tpoterba Dec 13, 2018

Collaborator

Ah! Boolean -> int32 conversion is something we support explicitly for arithmetic, but not everywhere. I'd like to think through more cases before we open that up.

Let's make the test case

x53=hl.tuple([1, 1.0]) < (1.0, 0)

@tongda

tongda Dec 13, 2018

Author Contributor

I think I can hotfix the can_coerce check with some ugly code:

    def can_coerce(self, t: HailType):
        if self.elements is None:
            return isinstance(t, ttuple)
        else:
            return (isinstance(t, ttuple)
                    and len(t.types) == len(self.elements)
                    and all(c.can_coerce(t_) or coercer_from_dtype(t_).can_coerce(hl.dtype(c.str_t)) for c, t_ in zip(self.elements, t.types)))

But I have not found a way to fix the _coerce method.

@tpoterba

tpoterba Dec 13, 2018

Collaborator

ah wait, hmm...

@tpoterba

tpoterba Dec 13, 2018

Collaborator

bool is totally in the coercer system. let me look.

@tpoterba

tpoterba Dec 13, 2018

Collaborator

the problem is in unify_exprs:

    for t in types:
        c = expressions.coercer_from_dtype(t)
        if all(c.can_coerce(e.dtype) for e in exprs):
            return tuple([c.coerce(e) for e in exprs]) + (True,)

This is looking for one type that can coerce the rest. The correct thing is to recursively walk the structure of the types (tuple, struct, etc) and determine the coercer for each element. This is trickier and shouldn't be a part of this PR.

Instead, let's just change the check to:

x53=hl.tuple([1, 1.0]) < (1.0, 0.0)

and save the unify_exprs fix for another PR.
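
For reference, a rough sketch of the recursive approach described above, using a toy type model rather than Hail's type classes. The names `unify_types` and `_RANK`, the string encoding of scalar types, and the widening order are all hypothetical. In particular, letting bool widen to int32 is an assumption made only to reproduce the tuple(int32, float64) result mentioned earlier in the thread; whether that coercion should apply outside arithmetic is left open here.

```python
# Toy type model: scalar types are strings, tuple types are Python tuples of types.
# Hypothetical widening order for scalars (higher rank wins).
_RANK = {"bool": 0, "int32": 1, "int64": 2, "float32": 3, "float64": 4}


def unify_types(a, b):
    """Return a common type for a and b, or None if they cannot be unified."""
    if isinstance(a, tuple) and isinstance(b, tuple):
        if len(a) != len(b):
            return None
        # Recurse element by element instead of looking for one whole
        # type that can coerce the other.
        elements = [unify_types(x, y) for x, y in zip(a, b)]
        if any(e is None for e in elements):
            return None
        return tuple(elements)
    if isinstance(a, str) and isinstance(b, str):
        if a == b:
            return a
        if a in _RANK and b in _RANK:
            return a if _RANK[a] >= _RANK[b] else b
    # Mismatched shapes (scalar vs tuple) or unrelated scalars.
    return None
```

With this model, `unify_types(("bool", "float64"), ("int32", "int32"))` picks a type per position, so neither input has to be able to coerce the other as a whole.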

@tongda

tongda Dec 14, 2018

Author Contributor

sure. np.

).take(1)[0])

expected = {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3],
@@ -115,7 +131,13 @@ def test_operators(self):
 'x24': True, 'x25': False, 'x26': True,
 'x27': False, 'x28': True, 'x29': False,
 'x30': False, 'x31': True, 'x32': False,
-'x33': False, 'x34': False, 'x35': False, 'x36': True}
+'x33': False, 'x34': False, 'x35': False,
+'x36': True, 'x37': True, 'x38': True,
+'x39': True, 'x40': True, 'x41': False,
+'x42': False, 'x43': True, 'x44': True,
+'x45': True, 'x46': True, 'x47': True,
+'x48': True, 'x49': False, 'x50': False,
+'x51': True, 'x52': True}

for k, v in expected.items():
if isinstance(v, float):