Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #13045 - Comparison of Structs and Struct Exprs #13226

Merged
merged 6 commits into from Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 14 additions & 1 deletion hail/python/hail/expr/functions.py
Expand Up @@ -233,11 +233,17 @@ def literal(x: Any, dtype: Optional[Union[HailType, str]] = None):
-------
:class:`.Expression`
"""
wrapper = {'has_expr': False}
wrapper = {'has_expr': False, 'has_free_vars': False}

def typecheck_expr(t, x):
if isinstance(x, Expression):
wrapper['has_expr'] = True
wrapper['has_free_vars'] |= (
builtins.len(x._ir.free_vars) > 0 or
builtins.len(x._ir.free_agg_vars) > 0 or
builtins.len(x._ir.free_scan_vars) > 0
)

if x.dtype != t:
raise TypeError(f"'literal': type mismatch: expected '{t}', found '{x.dtype}'")
elif x._indices.source is not None:
Expand Down Expand Up @@ -265,6 +271,13 @@ def typecheck_expr(t, x):
raise TypeError("'literal': object did not match the passed type '{}'"
.format(dtype)) from e

if wrapper['has_free_vars']:
raise ValueError(
"'literal' cannot be used with hail expressions that depend "
"on other expressions. Use expression 'x' directly "
"instead of passing it to 'literal'."
)

if wrapper['has_expr']:
return literal(hl.eval(to_expr(x, dtype)), dtype)

Expand Down
6 changes: 3 additions & 3 deletions hail/python/hail/genetics/call.py
Expand Up @@ -70,9 +70,9 @@ def __repr__(self):
return 'Call(alleles=%s, phased=%s)' % (self._alleles, self._phased)

def __eq__(self, other):
    """Compare this call with ``other`` for equality.

    Two calls are equal when their phasing flag and their alleles match.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not a :class:`Call`, so that
        Python falls back to the other operand's reflected ``__eq__``
        instead of short-circuiting to ``False``.
    """
    # Returning False here would stop Python from ever consulting the other
    # operand's comparison, so defer with NotImplemented.
    if not isinstance(other, Call):
        return NotImplemented
    return self._phased == other._phased and self._alleles == other._alleles

def __hash__(self):
    """Return a hash built from the same fields that ``__eq__`` compares.

    The phasing flag and the alleles are packed into one tuple and hashed
    together rather than XOR-ing their individual hashes. The alleles are
    converted to a tuple first so that a list of alleles is hashable.
    """
    return hash((self._phased, tuple(self._alleles)))
Expand Down
8 changes: 4 additions & 4 deletions hail/python/hail/genetics/locus.py
Expand Up @@ -51,10 +51,10 @@ def __repr__(self):
return 'Locus(contig=%s, position=%s, reference_genome=%s)' % (self.contig, self.position, self._rg)

def __eq__(self, other):
    """Compare this locus with ``other`` for equality.

    Two loci are equal when contig, position and reference genome all match.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not a :class:`Locus`, so that
        Python falls back to the other operand's reflected ``__eq__``
        instead of short-circuiting to ``False``.
    """
    # Defer to the other operand for foreign types rather than answering False.
    if not isinstance(other, Locus):
        return NotImplemented
    return (
        self._contig == other._contig
        and self._position == other._position
        and self._rg == other._rg
    )

def __hash__(self):
    """Return a hash built from the same fields that ``__eq__`` compares.

    Contig, position and reference genome are packed into one tuple and
    hashed together rather than XOR-ing their individual hashes.
    """
    return hash((self._contig, self._position, self._rg))
Expand Down
11 changes: 6 additions & 5 deletions hail/python/hail/utils/interval.py
Expand Up @@ -67,11 +67,12 @@ def __repr__(self):
.format(repr(self.start), repr(self.end), repr(self.includes_start), repr(self._includes_end))

def __eq__(self, other):
    """Compare this interval with ``other`` for equality.

    Two intervals are equal when both endpoints and both inclusion flags
    match.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not an :class:`Interval`, so
        that Python falls back to the other operand's reflected ``__eq__``
        instead of short-circuiting to ``False``.
    """
    # Defer to the other operand for foreign types rather than answering False.
    if not isinstance(other, Interval):
        return NotImplemented
    return (
        self._start == other._start
        and self._end == other._end
        and self._includes_start == other._includes_start
        and self._includes_end == other._includes_end
    )


def __hash__(self):
    """Return a hash built from the same fields that ``__eq__`` compares.

    The four fields are hashed as one tuple. XOR-ing the individual hashes
    cancels to 0 whenever ``start == end`` and both inclusion flags agree,
    and it cannot tell ``(a, b)`` from ``(b, a)``; a tuple hash has neither
    problem.
    """
    return hash((self._start, self._end, self._includes_start, self._includes_end))
Expand Down
8 changes: 5 additions & 3 deletions hail/python/hail/utils/linkedlist.py
Expand Up @@ -42,13 +42,15 @@ def __iter__(self):
return ListIterator(self.node)

def __str__(self):
    """Return ``List(...)`` with each element rendered by ``str``, comma-separated."""
    rendered = ', '.join(str(item) for item in self)
    return f'List({rendered})'

def __repr__(self):
    """Return ``List(...)`` with each element rendered by ``repr``, comma-separated."""
    rendered = ', '.join(repr(item) for item in self)
    return f'List({rendered})'

def __eq__(self, other):
    """Compare this list with ``other`` element by element.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not a :class:`LinkedList`, so
        that Python falls back to the other operand's reflected ``__eq__``
        instead of short-circuiting to ``False``.
    """
    # Defer to the other operand for foreign types rather than answering False.
    if not isinstance(other, LinkedList):
        return NotImplemented
    return list(self) == list(other)

def __ne__(self, other):
    """Return the negation of ``__eq__``, passing ``NotImplemented`` through.

    ``NotImplemented`` is truthy, so a bare ``not self.__eq__(other)`` would
    report ``False`` ("not unequal") for an operand ``__eq__`` declined to
    handle. The sentinel is returned unchanged so Python can try the other
    operand's comparison.
    """
    outcome = self.__eq__(other)
    if outcome is NotImplemented:
        return NotImplemented
    return not outcome
Expand Down
4 changes: 3 additions & 1 deletion hail/python/hail/utils/struct.py
Expand Up @@ -94,7 +94,9 @@ def __str__(self):
)

def __eq__(self, other):
    """Compare this struct with ``other`` by their field mappings.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not a :class:`Struct`, so that
        Python falls back to the other operand's reflected ``__eq__``
        instead of short-circuiting to ``False``.
    """
    # Defer to the other operand for foreign types rather than answering False.
    if not isinstance(other, Struct):
        return NotImplemented
    return self._fields == other._fields
Comment on lines +97 to +99
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


def __hash__(self):
    """Return a hash of the struct's fields that ignores insertion order.

    The ``(name, value)`` pairs are sorted before hashing, so two structs
    holding the same fields in a different order hash the same.
    """
    ordered_items = tuple(sorted(self._fields.items()))
    return 37 + hash(ordered_items)
Expand Down
7 changes: 7 additions & 0 deletions hail/python/test/hail/expr/test_functions.py
Expand Up @@ -85,3 +85,10 @@ def test_array():

with pytest.raises(ValueError, match='array: only one dimensional ndarrays are supported: ndarray<float64, 2>'):
hl.eval(hl.array(hl.nd.array([[1.0], [2.0]])))


def test_literal_free_vars():
    """Check that ``hl.literal`` rejects an expression that depends on others.

    Mapping ``hl.literal`` over an array expression hands it the bound
    element expression; a ``ValueError`` with an explanatory message is
    expected.
    """
    numbers = hl.literal([1, 2, 3])
    expected_message = 'expressions that depend on other expressions'
    with pytest.raises(ValueError, match=expected_message):
        numbers.map(hl.literal)
7 changes: 7 additions & 0 deletions hail/python/test/hail/genetics/test_call.py
Expand Up @@ -91,3 +91,10 @@ def test_zeroploid(self):
"Calls with greater than 2 alleles are not supported.",
Call,
[1, 1, 1, 1])

def test_call_rich_comparison():
    """Check that ``==`` between a ``Call`` value and a call expression works in either operand order."""
    py_call = Call([0, 0])
    call_expr = hl.call(0, 0)

    value_first = py_call == call_expr
    assert hl.eval(value_first)

    expr_first = call_expr == py_call
    assert hl.eval(expr_first)
23 changes: 13 additions & 10 deletions hail/python/test/hail/genetics/test_locus.py
@@ -1,14 +1,17 @@
import unittest

from hail.genetics import Locus
import hail as hl
from hail.genetics import *
from ..helpers import *

class Tests(unittest.TestCase):
def test_constructor():
    """Check that a parsed locus equals loci built directly and uses the default reference genome."""
    parsed = Locus.parse('1:100')

    # The contig is accepted both as a string and as an integer.
    assert parsed == Locus('1', 100)
    assert parsed == Locus(1, 100)
    assert parsed.reference_genome == hl.default_reference()


def test_call_rich_comparison():
    """Check that ``==`` between a ``Locus`` value and a locus expression works in either operand order.

    NOTE(review): despite its name this test exercises ``Locus``, not
    ``Call``; the name is kept unchanged.
    """
    val = Locus(1, 1)
    expr = hl.locus('1', 1)

    assert hl.eval(val == expr)
    assert hl.eval(expr == val)
19 changes: 19 additions & 0 deletions hail/python/test/hail/utils/test_utils.py
Expand Up @@ -421,3 +421,22 @@ def test_hadoop_ls_negated_group(glob_tests_directory):
glob_tests_directory + '/abc/ghi/?23']
actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/ghi/[!1]23')]
assert set(actual) == set(expected)


def test_struct_rich_comparison():
"""Asserts comparisons between structs and struct expressions are symmetric"""
struct = hl.Struct(
locus=hl.Locus(contig=10, position=60515, reference_genome='GRCh37'),
alleles=['C', 'T']
)

expr = hl.struct(
locus=hl.locus(contig='10', pos=60515, reference_genome='GRCh37'),
alleles=['C', 'T']
)

assert hl.eval(struct == expr) and hl.eval(expr == struct)
assert hl.eval(struct >= expr) and hl.eval(expr >= struct)
assert hl.eval(struct <= expr) and hl.eval(expr <= struct)
assert not (hl.eval(struct < expr) or hl.eval(expr < struct))
assert not (hl.eval(struct > expr) or hl.eval(expr > struct))