Skip to content

Commit

Permalink
refactor(ir): use dt.normalize() to construct literals
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and jcrist committed Jun 16, 2023
1 parent 5619ce0 commit bf72f16
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 63 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_vectorized_udf.py
Expand Up @@ -548,7 +548,7 @@ def test_elementwise_udf_named_destruct(udf_backend, udf_alltypes):
add_one_struct_udf = create_add_one_struct_udf(
result_formatter=lambda v1, v2: (v1, v2)
)
with pytest.raises(TypeError, match=r".*cannot be inferred.*"):
with pytest.raises(TypeError, match=r"Unable to infer datatype of"):
udf_alltypes.mutate(
new_struct=add_one_struct_udf(udf_alltypes['double_col']).destructure()
)
Expand Down
3 changes: 3 additions & 0 deletions ibis/expr/datatypes/cast.py
Expand Up @@ -80,6 +80,9 @@ def can_cast_to_differently_signed_integer_type(
source: dt.Integer, target: dt.Integer, value: int | None = None, **kwargs
) -> bool:
if value is not None:
# TODO(kszucs): we may not need to actually check the value since the
# literal construction now checks for bounds and doesn't use castable()
# anymore
return target.bounds.lower <= value <= target.bounds.upper
else:
return (target.bounds.upper - target.bounds.lower) >= (
Expand Down
3 changes: 3 additions & 0 deletions ibis/expr/datatypes/core.py
Expand Up @@ -397,6 +397,9 @@ class Bounds(NamedTuple):
lower: int
upper: int

def __contains__(self, value: int) -> bool:
return self.lower <= value <= self.upper


@public
class Numeric(DataType):
Expand Down
29 changes: 23 additions & 6 deletions ibis/expr/datatypes/value.py
Expand Up @@ -27,7 +27,7 @@
@lazy_singledispatch
def infer(value: Any) -> dt.DataType:
"""Infer the corresponding ibis dtype for a python object."""
raise InputTypeError(value)
raise InputTypeError(f"Unable to infer datatype of {value!r}")


# TODO(kszucs): support NamedTuples and dataclasses instead of OrderedDict
Expand Down Expand Up @@ -250,9 +250,22 @@ def normalize(typ, value):
if dtype.is_boolean():
return bool(value)
elif dtype.is_integer():
return int(value)
try:
value = int(value)
except ValueError:
raise TypeError("Unable to normalize {value!r} to {dtype!r}")
if value not in dtype.bounds:
raise TypeError(
f"Value {value} is out of bounds for type {dtype!r} "
f"(bounds: {dtype.bounds})"
)
else:
return value
elif dtype.is_floating():
return float(value)
try:
return float(value)
except ValueError:
raise TypeError("Unable to normalize {value!r} to {dtype!r}")
elif dtype.is_string() and not dtype.is_json():
return str(value)
elif dtype.is_decimal():
Expand All @@ -267,9 +280,13 @@ def normalize(typ, value):
elif dtype.is_map():
return frozendict({k: normalize(dtype.value_type, v) for k, v in value.items()})
elif dtype.is_struct():
return frozendict(
{k: normalize(dtype[k], v) for k, v in value.items() if k in dtype.fields}
)
if not isinstance(value, Mapping):
raise TypeError(f"Unable to normalize {dtype} from non-mapping {value!r}")
if missing_keys := (dtype.keys() - value.keys()):
raise TypeError(
f"Unable to normalize {value!r} to {dtype} because of missing keys {missing_keys!r}"
)
return frozendict({k: normalize(t, value[k]) for k, t in dtype.items()})
elif dtype.is_geospatial():
if isinstance(value, (tuple, list)):
if dtype.is_point():
Expand Down
1 change: 1 addition & 0 deletions ibis/expr/operations/generic.py
Expand Up @@ -223,6 +223,7 @@ def name(self):
return repr(self.value)


# TODO(kszucs): remove
@public
class NullLiteral(Literal, Singleton):
"""Typeless NULL literal."""
Expand Down
36 changes: 2 additions & 34 deletions ibis/expr/rules.py
Expand Up @@ -160,41 +160,9 @@ def literal(dtype, value, **kwargs):
if isinstance(value, ops.Literal):
return value

try:
inferred_dtype = dt.infer(value)
except com.InputTypeError:
has_inferred = False
else:
has_inferred = True

if dtype is None:
has_explicit = False
else:
has_explicit = True
# TODO(kszucs): handle class-like dtype definitions here explicitly
explicit_dtype = dt.dtype(dtype)

if has_explicit and has_inferred:
try:
# ensure type correctness: check that the inferred dtype is
# implicitly castable to the explicitly given dtype and value
dtype = dt.cast(inferred_dtype, target=explicit_dtype, value=value)
except com.IbisTypeError:
raise TypeError(f'Value {value!r} cannot be safely coerced to `{dtype}`')
elif has_explicit:
dtype = explicit_dtype
elif has_inferred:
dtype = inferred_dtype
else:
raise com.IbisTypeError(
'The datatype of value {!r} cannot be inferred, try '
'passing it explicitly with the `type` keyword.'.format(value)
)

if dtype.is_null():
return ops.NullLiteral()

dtype = dt.infer(value) if dtype is None else dt.dtype(dtype)
value = dt.normalize(dtype, value)

return ops.Literal(value, dtype=dtype)


Expand Down
8 changes: 4 additions & 4 deletions ibis/expr/tests/test_rules.py
Expand Up @@ -46,10 +46,10 @@ def test_invalid_datatype(value, expected):
assert rlz.datatype(value)


def test_invalid_literal():
msg = "Value 1 cannot be safely coerced to `string`"
with pytest.raises(TypeError, match=msg):
rlz.literal(dt.string, 1)
def test_string_literal_from_intergerinvalid_literal():
lit = rlz.literal(dt.string, 1)
assert type(lit.value) is str
assert lit.value == "1"


@pytest.mark.parametrize(
Expand Down
8 changes: 5 additions & 3 deletions ibis/tests/expr/test_literal.py
Expand Up @@ -97,6 +97,7 @@ def test_normalized_underlying_value(userinput, literal_type, expected_type):
'value',
[
dict(field1='value1', field2=3.14),
dict(field1='value1', field2='3.14'), # coerceable type
dict(field1='value1', field2=1), # coerceable type
dict(field2=2.72, field1='value1'), # wrong field order
dict(field1='value1', field2=3.14, field3='extra'), # extra field
Expand All @@ -105,21 +106,22 @@ def test_normalized_underlying_value(userinput, literal_type, expected_type):
def test_struct_literal(value):
typestr = "struct<field1: string, field2: float64>"
a = ibis.struct(value, type=typestr)
assert a.op().value == frozendict(field1=value['field1'], field2=value['field2'])
assert a.op().value == frozendict(
field1=str(value['field1']), field2=float(value['field2'])
)
assert a.type() == dt.dtype(typestr)


@pytest.mark.parametrize(
'value',
[
dict(field1='value1', field2='3.14'), # non-coerceable type
dict(field1='value1', field3=3.14), # wrong field name
dict(field1='value1'), # missing field
],
)
def test_struct_literal_non_castable(value):
typestr = "struct<field1: string, field2: float64>"
with pytest.raises((KeyError, TypeError, ibis.common.exceptions.IbisTypeError)):
with pytest.raises(TypeError, match="Unable to normalize"):
ibis.struct(value, type=typestr)


Expand Down
30 changes: 16 additions & 14 deletions ibis/tests/expr/test_value_exprs.py
Expand Up @@ -25,19 +25,18 @@
from ibis.expr import api
from ibis.tests.util import assert_equal

# def test_null():
# expr = ibis.literal(None)
# assert isinstance(expr, ir.NullScalar)
# assert isinstance(expr.op(), ops.NullLiteral)
# assert expr._arg.value is None

def test_null():
expr = ibis.literal(None)
assert isinstance(expr, ir.NullScalar)
assert isinstance(expr.op(), ops.NullLiteral)
assert expr._arg.value is None
# expr2 = ibis.null()
# assert_equal(expr, expr2)

expr2 = ibis.null()
assert_equal(expr, expr2)

assert expr is expr2
assert expr.type().equals(dt.null)
assert expr2.type().equals(dt.null)
# assert expr is expr2
# assert expr.type().equals(dt.null)
# assert expr2.type().equals(dt.null)


def test_literal_mixed_type_fails():
Expand Down Expand Up @@ -234,15 +233,18 @@ def test_simple_map_operations():
(32768, 'int16'),
(2147483647, 'int16'),
(2147483648, 'int32'),
('foo', 'double'),
],
)
def test_literal_with_non_coercible_type(value, expected_type):
expected_msg = 'Value .* cannot be safely coerced to .*'
with pytest.raises(TypeError, match=expected_msg):
with pytest.raises(TypeError, match="out of bounds"):
ibis.literal(value, type=expected_type)


def test_literal_double_from_string_fails():
with pytest.raises(TypeError, match="Unable to normalize"):
ibis.literal('foo', type='double')


def test_list_and_tuple_literals():
what = [1, 2, 1000]
expr = api.literal(what)
Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/sql/test_select_sql.py
Expand Up @@ -409,7 +409,7 @@ def test_case_in_projection(alltypes, snapshot):
expr = t[expr.name('col1'), expr2.name('col2'), t]

snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot)
assert_decompile_roundtrip(expr, snapshot, check_equality=False)


def test_identifier_quoting(snapshot):
Expand Down

0 comments on commit bf72f16

Please sign in to comment.