Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expressions for Map/Struct types and columns #1166

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 61 additions & 2 deletions ibis/expr/api.py
Expand Up @@ -40,6 +40,8 @@
TimestampValue, TimestampScalar, TimestampColumn,
DateValue, TimeValue,
ArrayValue, ArrayScalar, ArrayColumn,
MapValue, MapScalar, MapColumn,
StructValue, StructScalar, StructColumn,
CategoryValue, unnamed, as_value_expr, literal,
param, null, sequence)

Expand Down Expand Up @@ -1667,11 +1669,11 @@ def _string_contains(arg, substr):

Parameters
----------
substr
substr : str or ibis.expr.types.StringValue

Returns
-------
contains : boolean
contains : ibis.expr.types.BooleanValue
"""
return arg.find(substr) >= 0

Expand Down Expand Up @@ -1745,6 +1747,19 @@ def _string_getitem(self, key):


def _array_slice(array, index):
"""Slice or index `array` at `index`.

Parameters
----------
index : int or ibis.expr.types.IntegerValue or slice

Returns
-------
sliced_array : ibis.expr.types.ValueExpr
If `index` is an ``int`` or :class:`~ibis.expr.types.IntegerValue` then
the return type is the element type of `array`. If `index` is a
``slice`` then the return type is the same type as the input.
"""
if isinstance(index, slice):
start = index.start
stop = index.stop
Expand Down Expand Up @@ -1778,6 +1793,50 @@ def _array_slice(array, index):

_add_methods(ArrayValue, _array_column_methods)

# ---------------------------------------------------------------------
# Map API


# Methods attached to map expressions. Each entry pairs the Python attribute
# name with the callable produced by ``_unary_op`` / ``_binop_expr`` for the
# named operation class.
_map_column_methods = {
    'length': _unary_op('length', _ops.MapLength),
    '__getitem__': _binop_expr('__getitem__', _ops.MapValueForKey),
    'keys': _unary_op('keys', _ops.MapKeys),
    'values': _unary_op('values', _ops.MapValues),
    '__add__': _binop_expr('__add__', _ops.MapConcat),
    # Reflected addition: ``toolz.flip`` swaps the two positional arguments
    # before they reach the binary-op wrapper.
    '__radd__': toolz.flip(_binop_expr('__radd__', _ops.MapConcat)),
}

_add_methods(MapValue, _map_column_methods)

# ---------------------------------------------------------------------
# Struct API


def _struct_get_field(expr, field_name):
    """Return an expression for the `field_name` field of a ``Struct``.

    Parameters
    ----------
    expr : ibis.expr.types.StructValue
        The ``Struct`` typed expression to read the field from.
    field_name : str
        The name of the field to access. Must be a Python ``str``;
        programmatic struct field access is not yet supported.

    Returns
    -------
    value_expr : ibis.expr.types.ValueExpr
        An expression with the type of the field being accessed.
    """
    field_op = _ops.StructField(expr, field_name)
    return field_op.to_expr()


# Attribute access (``expr.field``) and item access (``expr['field']``) are
# both served by the same field-lookup helper.
# NOTE(review): ``__getattr__`` is the fallback for *every* missing attribute,
# so unknown names are forwarded to ``StructField`` as well -- confirm that
# path raises ``AttributeError`` so that ``hasattr()`` keeps working.
_struct_column_methods = {
    '__getattr__': _struct_get_field,
    '__getitem__': _struct_get_field,
}

_add_methods(StructValue, _struct_column_methods)


# ---------------------------------------------------------------------
# Timestamp API
Expand Down
36 changes: 30 additions & 6 deletions ibis/expr/datatypes.py
Expand Up @@ -502,8 +502,21 @@ class Struct(DataType):

def __init__(self, names, types, nullable=True):
    """Create a struct type from parallel field names and field types.

    Parameters
    ----------
    names : iterable of str
        Field names, in order.
    types : iterable
        Field types, in the same order as `names`. Each entry is passed
        through ``validate_type``.
    nullable : bool, default True
        Forwarded to the base class constructor.
    """
    super(Struct, self).__init__(nullable=nullable)
    # ``names`` and ``types`` are read-only properties derived from
    # ``pairs``, so ``pairs`` is the only state stored here. Every field
    # type is validated before it is stored.
    validated_types = [validate_type(t) for t in types]
    self.pairs = OrderedDict(zip(names, validated_types))

@property
def names(self):
    """Field names in declaration order, as a new ``list``.

    A list (rather than a dict keys view) keeps the value indexable and
    makes ``==`` between two ``names`` results order-sensitive.
    """
    return list(self.pairs)

@property
def types(self):
    """Field types in declaration order, as a new ``list``.

    A list is returned because dict values views do not support
    element-wise ``==`` or indexing.
    """
    return list(self.pairs.values())

def __getitem__(self, key):
    """Return the type stored in ``self.pairs`` for the field named `key`.

    A missing `key` propagates whatever lookup error ``self.pairs`` raises.
    """
    field_type = self.pairs[key]
    return field_type

def __iter__(self):
    """Return an iterator over the ``(name, type)`` entries of ``pairs``."""
    name_type_entries = self.pairs.items()
    return iter(name_type_entries)

def __repr__(self):
return '{0}({1})'.format(
Expand All @@ -520,12 +533,18 @@ def __str__(self):
)

def _equal_part(self, other, cache=None):
    """Return whether `other` has the same fields, in the same order.

    Parameters
    ----------
    other : Struct
        The struct type to compare against.
    cache : optional
        Forwarded unchanged to each field type's ``equals`` call.

    Returns
    -------
    bool
        ``True`` when the field names match in order and every pair of
        corresponding field types reports ``equals``.
    """
    # Compare names as lists: dict key views compare like sets and would
    # ignore field order.
    if list(self.names) != list(other.names):
        return False
    # ``all`` is required here -- a bare generator expression is always
    # truthy, which would make any two structs with equal names compare equal.
    return all(
        left.equals(right, cache=cache)
        for left, right in zip(self.types, other.types)
    )

@classmethod
def from_tuples(cls, pairs):
    """Build a struct type from an iterable of ``(name, type)`` pairs.

    Parameters
    ----------
    pairs : iterable of (str, type) tuples
        Field definitions, in order. May be empty or a one-shot iterator.

    Returns
    -------
    An instance of `cls` constructed from the separated names and types.
    """
    pairs = list(pairs)
    if not pairs:
        # ``zip(*[])`` yields nothing, which would call the constructor
        # without its required arguments.
        return cls([], [])
    names, types = zip(*pairs)
    return cls(list(names), list(types))

def valid_literal(self, value):
    """Return whether `value` is acceptable as a literal of this type.

    Only ``OrderedDict`` instances (subclasses included) qualify.
    """
    is_ordered_mapping = isinstance(value, OrderedDict)
    return is_ordered_mapping


@parametric
class Array(Variadic):
Expand All @@ -544,7 +563,7 @@ def _equal_part(self, other, cache=None):
return self.value_type.equals(other.value_type, cache=cache)

def valid_literal(self, value):
return isinstance(value, (list, tuple))
return isinstance(value, list)


@parametric
Expand All @@ -563,7 +582,7 @@ def _equal_part(self, other, cache=None):


@parametric
class Map(DataType):
class Map(Variadic):

def __init__(self, key_type, value_type, nullable=True):
super(Map, self).__init__(nullable=nullable)
Expand All @@ -590,6 +609,9 @@ def _equal_part(self, other, cache=None):
self.value_type.equals(other.value_type, cache=cache)
)

def valid_literal(self, value):
    """Return whether `value` is acceptable as a literal of this type.

    Any ``dict`` instance (subclasses included) qualifies.
    """
    is_mapping_literal = isinstance(value, dict)
    return is_mapping_literal


# ---------------------------------------------------------------------

Expand Down Expand Up @@ -939,7 +961,9 @@ def type(self):
def validate_type(t):
    """Coerce `t` to a ``DataType``.

    Parameters
    ----------
    t : DataType or str
        An existing type instance, or a string handed to ``TypeParser``.

    Returns
    -------
    DataType
        `t` itself when it is already a ``DataType``; otherwise the result
        of ``TypeParser(t).parse()``.

    Raises
    ------
    TypeError
        If `t` is neither a ``DataType`` nor a string.
    """
    if isinstance(t, DataType):
        return t
    # Reject non-strings up front instead of handing arbitrary objects to
    # the parser.
    if not isinstance(t, six.string_types):
        raise TypeError('Value {!r} is not a valid type or string'.format(t))
    return TypeParser(t).parse()


def array_type(t):
Expand Down
73 changes: 64 additions & 9 deletions ibis/expr/operations.py
Expand Up @@ -677,7 +677,7 @@ def _sum_output_type(self):
elif isinstance(arg, ir.FloatingValue):
t = 'double'
elif isinstance(arg, ir.DecimalValue):
t = dt.Decimal(arg._precision, 38)
t = dt.Decimal(arg.meta.precision, 38)
else:
raise TypeError(arg)
return t
Expand All @@ -686,7 +686,7 @@ def _sum_output_type(self):
def _mean_output_type(self):
arg = self.args[0]
if isinstance(arg, ir.DecimalValue):
t = dt.Decimal(arg._precision, 38)
t = dt.Decimal(arg.meta.precision, 38)
elif isinstance(arg, ir.NumericValue):
t = 'double'
else:
Expand Down Expand Up @@ -763,13 +763,13 @@ class Variance(VarianceBase):

def _decimal_scalar_ctor(precision, scale):
out_type = dt.Decimal(precision, scale)
return ir.DecimalScalar._make_constructor(out_type)
return out_type.scalar_type()


def _min_max_output_rule(self):
arg = self.args[0]
if isinstance(arg, ir.DecimalValue):
t = dt.Decimal(arg._precision, 38)
t = dt.Decimal(arg.meta.precision, 38)
else:
t = arg.type()

Expand Down Expand Up @@ -2440,9 +2440,10 @@ class ArraySlice(ValueOp):
class ArrayIndex(ValueOp):
    """Operation taking an array expression and an integer ``index``.

    The result has the array's ``value_type`` and is shaped like the array
    argument.
    """

    input_type = [rules.array(dt.any), rules.integer(name='index')]

    # A single ``output_type`` definition: an earlier class-level assignment
    # of the same name would be shadowed by this method anyway.
    def output_type(self):
        array_arg = self.args[0]
        value_type = array_arg.type().value_type
        return rules.shape_like(array_arg, value_type)


def _array_binop_invariant_output_type(self):
Expand All @@ -2465,16 +2466,70 @@ def _array_binop_invariant_output_type(self):
class ArrayConcat(ValueOp):
    """Operation taking two array expressions.

    The result type comes from ``_array_binop_invariant_output_type`` and
    the result is shaped like the first argument.
    """

    input_type = [rules.array(dt.any), rules.array(dt.any)]

    # A single ``output_type`` definition: an earlier class-level assignment
    # of the same name would be shadowed by this method anyway.
    def output_type(self):
        result_type = _array_binop_invariant_output_type(self)
        return rules.shape_like(self.args[0], result_type)


class ArrayRepeat(ValueOp):
    """Operation taking an array expression and an integer ``times``.

    The result keeps the type of the array argument and is shaped like it.
    """

    input_type = [rules.array(dt.any), integer(name='times')]

    # A single ``output_type`` definition: an earlier class-level assignment
    # of the same name would be shadowed by this method anyway.
    def output_type(self):
        array_arg = self.args[0]
        return rules.shape_like(array_arg, array_arg.type())


class ArrayCollect(Reduction):
    """Reduction whose single argument must satisfy ``rules.column``.

    The output is declared via ``rules.scalar_output`` applied to
    ``_array_reduced_type``.
    """

    input_type = [rules.column]
    output_type = rules.scalar_output(_array_reduced_type)


class MapLength(ValueOp):
    """Operation taking one map expression (any key and value type).

    The output is declared as ``'int64'`` via ``rules.shape_like_arg`` on
    argument 0.
    """

    input_type = [rules.map(dt.any, dt.any)]
    output_type = rules.shape_like_arg(0, 'int64')


class MapValueForKey(ValueOp):
    """Operation taking a map expression and a string or integer ``key``.

    The result has the map's ``value_type`` and is shaped like the map
    argument.
    """

    input_type = [
        rules.map(dt.any, dt.any),
        rules.one_of((dt.string, dt.int_), name='key'),
    ]

    def output_type(self):
        map_arg = self.args[0]
        value_type = map_arg.type().value_type
        return rules.shape_like(map_arg, value_type)


class MapKeys(ValueOp):
    """Operation taking one map expression and producing its keys.

    The result is an array whose element type is the map's ``key_type``,
    shaped like the map argument. Reusing the map's own type for the output
    would mistype the result as a map.
    """

    input_type = [rules.map(dt.any, dt.any)]

    def output_type(self):
        map_arg = self.args[0]
        # NOTE(review): assumes ``dt.Map`` stores its first constructor
        # argument as ``key_type`` and that ``dt.Array`` takes the element
        # type as its first argument -- confirm against datatypes.py.
        keys_type = dt.Array(map_arg.type().key_type)
        return rules.shape_like(map_arg, keys_type)


class MapValues(ValueOp):
    """Operation taking one map expression and producing its values.

    The result is an array whose element type is the map's ``value_type``,
    shaped like the map argument. Reusing the map's own type for the output
    would mistype the result as a map.
    """

    input_type = [rules.map(dt.any, dt.any)]

    def output_type(self):
        map_arg = self.args[0]
        # NOTE(review): assumes ``dt.Array`` takes the element type as its
        # first argument -- confirm against datatypes.py.
        values_type = dt.Array(map_arg.type().value_type)
        return rules.shape_like(map_arg, values_type)


class MapConcat(ValueOp):
    """Operation taking two map expressions (any key and value types).

    The output is declared with ``rules.type_of_arg(0)``.
    NOTE(review): presumably this means the result takes the first map's
    type even when the second map's type differs -- confirm that mismatched
    key/value types are rejected or promoted somewhere.
    """

    input_type = [rules.map(dt.any, dt.any), rules.map(dt.any, dt.any)]
    output_type = rules.type_of_arg(0)


class StructField(ValueOp):
    """Operation taking a struct expression and a string ``field`` name.

    The result type is looked up by indexing the struct's type with the
    field name, and the result is shaped like the struct argument.
    """

    input_type = [
        rules.struct,
        rules.instance_of(six.string_types, name='field'),
    ]

    def output_type(self):
        struct_arg = self.args[0]
        field_type = struct_arg.type()[self.field]
        return rules.shape_like(struct_arg, field_type)