diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index ffa37438f70..938a4a6ec6c 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -233,11 +233,17 @@ def literal(x: Any, dtype: Optional[Union[HailType, str]] = None): ------- :class:`.Expression` """ - wrapper = {'has_expr': False} + wrapper = {'has_expr': False, 'has_free_vars': False} def typecheck_expr(t, x): if isinstance(x, Expression): wrapper['has_expr'] = True + wrapper['has_free_vars'] |= ( + builtins.len(x._ir.free_vars) > 0 or + builtins.len(x._ir.free_agg_vars) > 0 or + builtins.len(x._ir.free_scan_vars) > 0 + ) + if x.dtype != t: raise TypeError(f"'literal': type mismatch: expected '{t}', found '{x.dtype}'") elif x._indices.source is not None: @@ -265,6 +271,13 @@ def typecheck_expr(t, x): raise TypeError("'literal': object did not match the passed type '{}'" .format(dtype)) from e + if wrapper['has_free_vars']: + raise ValueError( + "'literal' cannot be used with hail expressions that depend " + "on other expressions. Use expression 'x' directly " + "instead of passing it to 'literal'." + ) + if wrapper['has_expr']: return literal(hl.eval(to_expr(x, dtype)), dtype) diff --git a/hail/python/hail/genetics/call.py b/hail/python/hail/genetics/call.py index 36c9efd4742..e31d67263e9 100644 --- a/hail/python/hail/genetics/call.py +++ b/hail/python/hail/genetics/call.py @@ -70,9 +70,9 @@ def __repr__(self): return 'Call(alleles=%s, phased=%s)' % (self._alleles, self._phased) def __eq__(self, other): - return (isinstance(other, Call) - and self._phased == other._phased - and self._alleles == other._alleles) + return ( self._phased == other._phased and + self._alleles == other._alleles + ) if isinstance(other, Call) else NotImplemented def __hash__(self): return hash(self._phased) ^ hash(tuple(self._alleles)) diff --git a/hail/python/hail/genetics/locus.py b/hail/python/hail/genetics/locus.py index 68145dc8ccf..97029f60a6a 100644 --- a/hail/python/hail/genetics/locus.py +++ b/hail/python/hail/genetics/locus.py @@ -51,10 +51,10 @@ def __repr__(self): return 'Locus(contig=%s, position=%s, reference_genome=%s)' % (self.contig, self.position, self._rg) def __eq__(self, other): - return (isinstance(other, Locus) - and self._contig == other._contig - and self._position == other._position - and self._rg == other._rg) + return ( self._contig == other._contig and + self._position == other._position and + self._rg == other._rg + ) if isinstance(other, Locus) else NotImplemented def __hash__(self): return hash(self._contig) ^ hash(self._position) ^ hash(self._rg) diff --git a/hail/python/hail/utils/interval.py b/hail/python/hail/utils/interval.py index 0f1733805b3..0e08041e5c8 100644 --- a/hail/python/hail/utils/interval.py +++ b/hail/python/hail/utils/interval.py @@ -67,11 +67,12 @@ def __repr__(self): .format(repr(self.start), repr(self.end), repr(self.includes_start), repr(self._includes_end)) def __eq__(self, other): - return (isinstance(other, Interval) - and self._start == other._start - and self._end == other._end - and self._includes_start == other._includes_start - and self._includes_end == other._includes_end) + return ( self._start == other._start and + self._end == other._end and + self._includes_start == other._includes_start and + self._includes_end == other._includes_end + ) if isinstance(other, Interval) else NotImplemented + def __hash__(self): return hash(self._start) ^ hash(self._end) ^ hash(self._includes_start) ^ hash(self._includes_end) diff --git a/hail/python/hail/utils/linkedlist.py b/hail/python/hail/utils/linkedlist.py index 508278dc130..d1a30ab4aab 100644 --- a/hail/python/hail/utils/linkedlist.py +++ b/hail/python/hail/utils/linkedlist.py @@ -42,13 +42,15 @@ def __iter__(self): return ListIterator(self.node) def __str__(self): - return 'List({})'.format(', '.join(str(x) for x in self)) + return f'''List({', '.join(str(x) for x in self)})''' def __repr__(self): - return 'List({})'.format(', '.join(repr(x) for x in self)) + return f'''List({', '.join(repr(x) for x in self)})''' def __eq__(self, other): - return isinstance(other, LinkedList) and list(self) == list(other) + return list(self) == list(other) \ + if isinstance(other, LinkedList) \ + else NotImplemented def __ne__(self, other): return not self.__eq__(other) diff --git a/hail/python/hail/utils/struct.py b/hail/python/hail/utils/struct.py index 356edd025e6..a2ff7df7bec 100644 --- a/hail/python/hail/utils/struct.py +++ b/hail/python/hail/utils/struct.py @@ -94,7 +94,9 @@ def __str__(self): ) def __eq__(self, other): - return isinstance(other, Struct) and self._fields == other._fields + return self._fields == other._fields \ + if isinstance(other, Struct) \ + else NotImplemented def __hash__(self): return 37 + hash(tuple(sorted(self._fields.items()))) diff --git a/hail/python/test/hail/expr/test_functions.py b/hail/python/test/hail/expr/test_functions.py index c7e1a41fa2f..5019de46db7 100644 --- a/hail/python/test/hail/expr/test_functions.py +++ b/hail/python/test/hail/expr/test_functions.py @@ -85,3 +85,10 @@ def test_array(): with pytest.raises(ValueError, match='array: only one dimensional ndarrays are supported: ndarray'): hl.eval(hl.array(hl.nd.array([[1.0], [2.0]]))) + + +def test_literal_free_vars(): + "Give better error messages in response to code written by ChatGPT" + array = hl.literal([1, 2, 3]) + with pytest.raises(ValueError, match='expressions that depend on other expressions'): + array.map(hl.literal) diff --git a/hail/python/test/hail/genetics/test_call.py b/hail/python/test/hail/genetics/test_call.py index 63a0712a545..6ab0d82caa3 100644 --- a/hail/python/test/hail/genetics/test_call.py +++ b/hail/python/test/hail/genetics/test_call.py @@ -91,3 +91,10 @@ def test_zeroploid(self): "Calls with greater than 2 alleles are not supported.", Call, [1, 1, 1, 1]) + +def test_call_rich_comparison(): + val = Call([0, 0]) + expr = hl.call(0, 0) + + assert hl.eval(val == expr) + assert hl.eval(expr == val) diff --git a/hail/python/test/hail/genetics/test_locus.py b/hail/python/test/hail/genetics/test_locus.py index ebf0491aeb3..94488bd143e 100644 --- a/hail/python/test/hail/genetics/test_locus.py +++ b/hail/python/test/hail/genetics/test_locus.py @@ -1,14 +1,17 @@ -import unittest - +from hail.genetics import Locus import hail as hl -from hail.genetics import * -from ..helpers import * -class Tests(unittest.TestCase): +def test_constructor(): + l = Locus.parse('1:100') + + assert l == Locus('1', 100) + assert l == Locus(1, 100) + assert l.reference_genome == hl.default_reference() + - def test_constructor(self): - l = Locus.parse('1:100') +def test_call_rich_comparison(): + val = Locus(1, 1) + expr = hl.locus('1', 1) - self.assertEqual(l, Locus('1', 100)) - self.assertEqual(l, Locus(1, 100)) - self.assertEqual(l.reference_genome, hl.default_reference()) + assert hl.eval(val == expr) + assert hl.eval(expr == val) diff --git a/hail/python/test/hail/utils/test_utils.py b/hail/python/test/hail/utils/test_utils.py index 4cd614dec8b..61019be84f7 100644 --- a/hail/python/test/hail/utils/test_utils.py +++ b/hail/python/test/hail/utils/test_utils.py @@ -421,3 +421,22 @@ def test_hadoop_ls_negated_group(glob_tests_directory): glob_tests_directory + '/abc/ghi/?23'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/ghi/[!1]23')] assert set(actual) == set(expected) + + +def test_struct_rich_comparison(): + """Asserts comparisons between structs and struct expressions are symmetric""" + struct = hl.Struct( + locus=hl.Locus(contig=10, position=60515, reference_genome='GRCh37'), + alleles=['C', 'T'] + ) + + expr = hl.struct( + locus=hl.locus(contig='10', pos=60515, reference_genome='GRCh37'), + alleles=['C', 'T'] + ) + + assert hl.eval(struct == expr) and hl.eval(expr == struct) + assert hl.eval(struct >= expr) and hl.eval(expr >= struct) + assert hl.eval(struct <= expr) and hl.eval(expr <= struct) + assert not (hl.eval(struct < expr) or hl.eval(expr < struct)) + assert not (hl.eval(struct > expr) or hl.eval(expr > struct))