Skip to content

Commit

Permalink
Fix #13045 - Comparison of Structs and Struct Exprs (#13226)
Browse files Browse the repository at this point in the history
RR: #13045
RR: #13046 
Support symmetric comparison of structs and struct expressions.
Provide better error messages when attempting to construct literals from
expressions with free variables.
  • Loading branch information
ehigham committed Jul 7, 2023
1 parent b7cc5f3 commit 07c4930
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 27 deletions.
15 changes: 14 additions & 1 deletion hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,17 @@ def literal(x: Any, dtype: Optional[Union[HailType, str]] = None):
-------
:class:`.Expression`
"""
wrapper = {'has_expr': False}
wrapper = {'has_expr': False, 'has_free_vars': False}

def typecheck_expr(t, x):
if isinstance(x, Expression):
wrapper['has_expr'] = True
wrapper['has_free_vars'] |= (
builtins.len(x._ir.free_vars) > 0 or
builtins.len(x._ir.free_agg_vars) > 0 or
builtins.len(x._ir.free_scan_vars) > 0
)

if x.dtype != t:
raise TypeError(f"'literal': type mismatch: expected '{t}', found '{x.dtype}'")
elif x._indices.source is not None:
Expand Down Expand Up @@ -265,6 +271,13 @@ def typecheck_expr(t, x):
raise TypeError("'literal': object did not match the passed type '{}'"
.format(dtype)) from e

if wrapper['has_free_vars']:
raise ValueError(
"'literal' cannot be used with hail expressions that depend "
"on other expressions. Use expression 'x' directly "
"instead of passing it to 'literal'."
)

if wrapper['has_expr']:
return literal(hl.eval(to_expr(x, dtype)), dtype)

Expand Down
6 changes: 3 additions & 3 deletions hail/python/hail/genetics/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def __repr__(self):
return 'Call(alleles=%s, phased=%s)' % (self._alleles, self._phased)

def __eq__(self, other):
return (isinstance(other, Call)
and self._phased == other._phased
and self._alleles == other._alleles)
return ( self._phased == other._phased and
self._alleles == other._alleles
) if isinstance(other, Call) else NotImplemented

def __hash__(self):
return hash(self._phased) ^ hash(tuple(self._alleles))
Expand Down
8 changes: 4 additions & 4 deletions hail/python/hail/genetics/locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def __repr__(self):
return 'Locus(contig=%s, position=%s, reference_genome=%s)' % (self.contig, self.position, self._rg)

def __eq__(self, other):
return (isinstance(other, Locus)
and self._contig == other._contig
and self._position == other._position
and self._rg == other._rg)
return ( self._contig == other._contig and
self._position == other._position and
self._rg == other._rg
) if isinstance(other, Locus) else NotImplemented

def __hash__(self):
return hash(self._contig) ^ hash(self._position) ^ hash(self._rg)
Expand Down
11 changes: 6 additions & 5 deletions hail/python/hail/utils/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ def __repr__(self):
.format(repr(self.start), repr(self.end), repr(self.includes_start), repr(self._includes_end))

def __eq__(self, other):
return (isinstance(other, Interval)
and self._start == other._start
and self._end == other._end
and self._includes_start == other._includes_start
and self._includes_end == other._includes_end)
return ( self._start == other._start and
self._end == other._end and
self._includes_start == other._includes_start and
self._includes_end == other._includes_end
) if isinstance(other, Interval) else NotImplemented


def __hash__(self):
return hash(self._start) ^ hash(self._end) ^ hash(self._includes_start) ^ hash(self._includes_end)
Expand Down
8 changes: 5 additions & 3 deletions hail/python/hail/utils/linkedlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ def __iter__(self):
return ListIterator(self.node)

def __str__(self):
return 'List({})'.format(', '.join(str(x) for x in self))
return f'''List({', '.join(str(x) for x in self)})'''

def __repr__(self):
return 'List({})'.format(', '.join(repr(x) for x in self))
return f'''List({', '.join(repr(x) for x in self)})'''

def __eq__(self, other):
return isinstance(other, LinkedList) and list(self) == list(other)
return list(self) == list(other) \
if isinstance(other, LinkedList) \
else NotImplemented

def __ne__(self, other):
return not self.__eq__(other)
Expand Down
4 changes: 3 additions & 1 deletion hail/python/hail/utils/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def __str__(self):
)

def __eq__(self, other):
return isinstance(other, Struct) and self._fields == other._fields
return self._fields == other._fields \
if isinstance(other, Struct) \
else NotImplemented

def __hash__(self):
return 37 + hash(tuple(sorted(self._fields.items())))
Expand Down
7 changes: 7 additions & 0 deletions hail/python/test/hail/expr/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ def test_array():

with pytest.raises(ValueError, match='array: only one dimensional ndarrays are supported: ndarray<float64, 2>'):
hl.eval(hl.array(hl.nd.array([[1.0], [2.0]])))


def test_literal_free_vars():
"Give better error messages in response to code written by ChatGPT"
array = hl.literal([1, 2, 3])
with pytest.raises(ValueError, match='expressions that depend on other expressions'):
array.map(hl.literal)
7 changes: 7 additions & 0 deletions hail/python/test/hail/genetics/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,10 @@ def test_zeroploid(self):
"Calls with greater than 2 alleles are not supported.",
Call,
[1, 1, 1, 1])

def test_call_rich_comparison():
val = Call([0, 0])
expr = hl.call(0, 0)

assert hl.eval(val == expr)
assert hl.eval(expr == val)
23 changes: 13 additions & 10 deletions hail/python/test/hail/genetics/test_locus.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import unittest

from hail.genetics import Locus
import hail as hl
from hail.genetics import *
from ..helpers import *

class Tests(unittest.TestCase):
def test_constructor():
l = Locus.parse('1:100')

assert l == Locus('1', 100)
assert l == Locus(1, 100)
assert l.reference_genome == hl.default_reference()


def test_constructor(self):
l = Locus.parse('1:100')
def test_call_rich_comparison():
val = Locus(1, 1)
expr = hl.locus('1', 1)

self.assertEqual(l, Locus('1', 100))
self.assertEqual(l, Locus(1, 100))
self.assertEqual(l.reference_genome, hl.default_reference())
assert hl.eval(val == expr)
assert hl.eval(expr == val)
19 changes: 19 additions & 0 deletions hail/python/test/hail/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,22 @@ def test_hadoop_ls_negated_group(glob_tests_directory):
glob_tests_directory + '/abc/ghi/?23']
actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/ghi/[!1]23')]
assert set(actual) == set(expected)


def test_struct_rich_comparison():
"""Asserts comparisons between structs and struct expressions are symmetric"""
struct = hl.Struct(
locus=hl.Locus(contig=10, position=60515, reference_genome='GRCh37'),
alleles=['C', 'T']
)

expr = hl.struct(
locus=hl.locus(contig='10', pos=60515, reference_genome='GRCh37'),
alleles=['C', 'T']
)

assert hl.eval(struct == expr) and hl.eval(expr == struct)
assert hl.eval(struct >= expr) and hl.eval(expr >= struct)
assert hl.eval(struct <= expr) and hl.eval(expr <= struct)
assert not (hl.eval(struct < expr) or hl.eval(expr < struct))
assert not (hl.eval(struct > expr) or hl.eval(expr > struct))

0 comments on commit 07c4930

Please sign in to comment.