Skip to content

Commit

Permalink
support non-inferable types in ibis.literal
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and xmnlab committed Feb 21, 2019
1 parent 688caee commit 45ddbf8
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 49 deletions.
3 changes: 3 additions & 0 deletions ibis/expr/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,9 @@ def test_time_valid():
])
def test_infer_dtype(value, expected_dtype):
assert dt.infer(value) == expected_dtype
# test literal creation
value = ibis.literal(value, type=expected_dtype)
assert value.type() == expected_dtype


@pytest.mark.parametrize(('source', 'target'), [
Expand Down
110 changes: 69 additions & 41 deletions ibis/expr/tests/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,24 @@ def test_unicode():
assert False


@pytest.mark.parametrize(
['value', 'expected_type'],
[
(5, 'int8'),
(127, 'int8'),
(128, 'int16'),
(32767, 'int16'),
(32768, 'int32'),
(2147483647, 'int32'),
(2147483648, 'int64'),
(-5, 'int8'),
(-128, 'int8'),
(-129, 'int16'),
(-32769, 'int32'),
(-2147483649, 'int64'),
(1.5, 'double'),
('foo', 'string'),
]
)
def test_literal_cases(value, expected_type):
@pytest.mark.parametrize(['value', 'expected_type'], [
(5, 'int8'),
(127, 'int8'),
(128, 'int16'),
(32767, 'int16'),
(32768, 'int32'),
(2147483647, 'int32'),
(2147483648, 'int64'),
(-5, 'int8'),
(-128, 'int8'),
(-129, 'int16'),
(-32769, 'int32'),
(-2147483649, 'int64'),
(1.5, 'double'),
('foo', 'string'),
([1, 2, 3], 'array<int8>')
])
def test_literal_with_implicit_type(value, expected_type):
expr = ibis.literal(value)

assert isinstance(expr, ir.ScalarExpr)
Expand All @@ -79,25 +77,41 @@ def test_literal_cases(value, expected_type):
assert expr.op().value is value


@pytest.mark.parametrize(
['value', 'expected_type'],
[
(5, 'int16'),
(127, 'double'),
(128, 'int64'),
(32767, 'double'),
(32768, 'float'),
(2147483647, 'int64'),
(-5, 'int16'),
(-128, 'int32'),
(-129, 'int64'),
(-32769, 'float'),
(-2147483649, 'double'),
(1.5, 'double'),
('foo', 'string'),
]
)
def test_literal_with_different_type(value, expected_type):
pointA = (1, 2)
pointB = (-3, 4)
pointC = (5, 19)
lineAB = [pointA, pointB]
lineBC = [pointB, pointC]
lineCA = [pointC, pointA]
polygon1 = [lineAB, lineBC, lineCA]
polygon2 = [lineAB, lineBC, lineCA]
multipolygon1 = [polygon1, polygon2]


@pytest.mark.parametrize(['value', 'expected_type'], [
(5, 'int16'),
(127, 'double'),
(128, 'int64'),
(32767, 'double'),
(32768, 'float'),
(2147483647, 'int64'),
(-5, 'int16'),
(-128, 'int32'),
(-129, 'int64'),
(-32769, 'float'),
(-2147483649, 'double'),
(1.5, 'double'),
('foo', 'string'),
(list(pointA), 'point'),
(tuple(pointA), 'point'),
(list(lineAB), 'line'),
(tuple(lineAB), 'line'),
(list(polygon1), 'polygon'),
(tuple(polygon1), 'polygon'),
(list(multipolygon1), 'multipolygon'),
(tuple(multipolygon1), 'multipolygon')
])
def test_literal_with_explicit_type(value, expected_type):
expr = ibis.literal(value, type=expected_type)
assert expr.type().equals(dt.validate_type(expected_type))

Expand Down Expand Up @@ -181,11 +195,25 @@ def test_simple_map_operations():
('foo', 'double'),
]
)
def test_literal_with_different_type_failure(value, expected_type):
with pytest.raises(TypeError):
def test_literal_with_non_coercible_type(value, expected_type):
expected_msg = 'Value .* cannot be safely coerced to .*'
with pytest.raises(TypeError, match=expected_msg):
ibis.literal(value, type=expected_type)


def test_non_inferrable_literal():
expected_msg = ('The datatype of value .* cannot be inferred, try '
'passing it explicitly with the `type` keyword.')

value = tuple(pointA)

with pytest.raises(TypeError, match=expected_msg):
ibis.literal(value)

point = ibis.literal(value, type='point')
assert point.type() == dt.point


def test_literal_list():
what = [1, 2, 1000]
expr = api.literal(what)
Expand Down
33 changes: 25 additions & 8 deletions ibis/expr/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,18 +886,35 @@ def literal(value, type=None):
if hasattr(value, 'op') and isinstance(value.op(), ops.Literal):
return value

if value is null:
dtype = dt.null
try:
inferred_dtype = dt.infer(value)
except com.InputTypeError:
has_inferred = False
else:
dtype = dt.infer(value)
has_inferred = True

if type is not None:
if type is None:
has_explicit = False
else:
has_explicit = True
explicit_dtype = dt.dtype(type)

if has_explicit and has_inferred:
try:
# check that dtype is implicitly castable to explicitly given dtype
dtype = dtype.cast(type, value=value)
# ensure type correctness: check that the inferred dtype is
# implicitly castable to the explicitly given dtype and value
dtype = inferred_dtype.cast(explicit_dtype, value=value)
except com.IbisTypeError:
raise TypeError('Value {!r} cannot be safely coerced '
'to {}'.format(value, type))
raise TypeError('Value {!r} cannot be safely coerced to {}'
.format(value, type))
elif has_explicit:
dtype = explicit_dtype
elif has_inferred:
dtype = inferred_dtype
else:
raise TypeError('The datatype of value {!r} cannot be inferred, try '
'passing it explicitly with the `type` keyword.'
.format(value))

if dtype is dt.null:
return null().cast(dtype)
Expand Down

0 comments on commit 45ddbf8

Please sign in to comment.