Skip to content

Commit

Permalink
Add support for {DataFrame,Series}.align (#3147)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Jun 27, 2022
1 parent 31bd6cc commit 60e19c3
Show file tree
Hide file tree
Showing 15 changed files with 949 additions and 318 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/dataframe/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ Reindexing / selection / label manipulation
:toctree: generated/

DataFrame.add_prefix
DataFrame.add_suffix
DataFrame.align
DataFrame.drop
DataFrame.drop_duplicates
DataFrame.duplicated
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/dataframe/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ Reindexing / selection / label manipulation
:toctree: generated/

Series.add_prefix
Series.add_suffix
Series.align
Series.drop
Series.drop_duplicates
Series.duplicated
Expand Down
214 changes: 91 additions & 123 deletions mars/dataframe/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,137 +40,75 @@
filter_index_value,
hash_index,
is_index_value_identical,
validate_axis,
)


class DataFrameIndexAlign(MapReduceOperand, DataFrameOperandMixin):
    """Map-reduce operand that aligns chunk index/column ranges of a DataFrame.

    NOTE(review): the pasted diff kept both the old underscore-prefixed field
    declarations and their public replacements; only the new public
    declarations are retained here.
    """

    _op_type_ = OperandDef.DATAFRAME_INDEX_ALIGN

    # bounds of the index range handled by this chunk; *_close flags mark
    # whether the corresponding bound is inclusive
    index_min = AnyField("index_min")
    index_min_close = BoolField("index_min_close")
    index_max = AnyField("index_max")
    index_max_close = BoolField("index_max_close")
    # number of shuffle partitions along the index axis, if shuffling is used
    index_shuffle_size = Int32Field("index_shuffle_size", default=None)
    # bounds of the column range handled by this chunk
    column_min = AnyField("column_min")
    column_min_close = BoolField("column_min_close")
    column_max = AnyField("column_max")
    column_max_close = BoolField("column_max_close")
    # number of shuffle partitions along the column axis, if shuffling is used
    column_shuffle_size = Int32Field("column_shuffle_size", default=None)
    # per-partition dtype segments produced by hash_dtypes when shuffling columns
    column_shuffle_segments = ListField("column_shuffle_segments", FieldTypes.series)

    input = KeyField("input")

def __init__(
    self, index_min_max=None, column_min_max=None, output_types=None, **kw
):
    """Initialize the operand.

    Parameters
    ----------
    index_min_max : tuple, optional
        4-tuple ``(min, min_close, max, max_close)`` describing the index
        range; unpacked into the ``index_*`` fields.
    column_min_max : tuple, optional
        4-tuple ``(min, min_close, max, max_close)`` describing the column
        range; unpacked into the ``column_*`` fields.
    output_types : list, optional
        Output types forwarded to the base operand.
    """
    if index_min_max is not None:
        kw.update(
            dict(
                index_min=index_min_max[0],
                index_min_close=index_min_max[1],
                index_max=index_min_max[2],
                index_max_close=index_min_max[3],
            )
        )
    if column_min_max is not None:
        kw.update(
            dict(
                column_min=column_min_max[0],
                column_min_close=column_min_max[1],
                column_max=column_min_max[2],
                column_max_close=column_min_max[3],
            )
        )
    super().__init__(_output_types=output_types, **kw)

@property
def index_min_max(self):
    """Return ``(min, min_close, max, max_close)`` for the index range.

    Returns ``None`` when no index range was configured (``index_min`` unset).
    """
    if getattr(self, "index_min", None) is None:
        return None
    return (
        self.index_min,
        self.index_min_close,
        self.index_max,
        self.index_max_close,
    )

@property
def column_min_max(self):
    """Return ``(min, min_close, max, max_close)`` for the column range.

    Returns ``None`` when no column range was configured (``column_min``
    unset). The stale underscore-backed accessor properties from the old
    version are dropped; the public fields are read directly.
    """
    if getattr(self, "column_min", None) is None:
        return None
    return (
        self.column_min,
        self.column_min_close,
        self.column_max,
        self.column_max_close,
    )

def _set_inputs(self, inputs):
    """Propagate inputs and keep the ``input`` field pointing at the first one."""
    super()._set_inputs(inputs)
    # the pasted diff contained both the old `self._input = ...` and the new
    # assignment; only the public-field assignment is correct here
    self.input = self._inputs[0]

def build_map_chunk_kw(self, inputs, **kw):
if kw.get("index_value", None) is None and inputs[0].index_value is not None:
Expand Down Expand Up @@ -205,7 +143,7 @@ def build_map_chunk_kw(self, inputs, **kw):
kw["dtypes"] = input_dtypes[kw["columns_value"].to_pandas()]
column_shuffle_size = self.column_shuffle_size
if column_shuffle_size is not None:
self._column_shuffle_segments = hash_dtypes(
self.column_shuffle_segments = hash_dtypes(
input_dtypes, column_shuffle_size
)
else:
Expand Down Expand Up @@ -873,46 +811,77 @@ def _gen_dataframe_chunks(splits, out_shape, left_or_right, df):
return out_chunks


def align_dataframe_dataframe(left, right, axis=None):
    """Align the chunks of two DataFrames along one or both axes.

    Parameters
    ----------
    left, right
        Tiled DataFrames whose chunks are to be aligned.
    axis : int or str, optional
        Axis to align on (0/"index" or 1/"columns"). ``None`` (default)
        aligns both axes, which preserves the pre-``align`` behavior where
        left and right share one output chunk shape.

    Returns
    -------
    tuple
        ``(nsplits, (left_chunk_shape, right_chunk_shape), left_chunks,
        right_chunks)``. Entries of ``nsplits`` for an un-aligned axis are
        ``None``; entries for an aligned axis with differing chunk indexes
        are NaN placeholders.
    """
    left_index_chunks = [c.index_value for c in left.cix[:, 0]]
    right_index_chunks = [c.index_value for c in right.cix[:, 0]]
    left_columns_chunks = [c.columns_value for c in left.cix[0, :]]
    right_columns_chunks = [c.columns_value for c in right.cix[0, :]]

    axis = validate_axis(axis) if axis is not None else None
    # compute split points only for the axes we actually align
    if axis is None or axis == 0:
        index_splits, index_chunk_shape = _calc_axis_splits(
            left.index_value, right.index_value, left_index_chunks, right_index_chunks
        )
    else:
        index_splits, index_chunk_shape = None, None

    if axis is None or axis == 1:
        columns_splits, column_chunk_shape = _calc_axis_splits(
            left.columns_value,
            right.columns_value,
            left_columns_chunks,
            right_columns_chunks,
        )
    else:
        columns_splits, column_chunk_shape = None, None

    splits = _MinMaxSplitInfo(index_splits, columns_splits)
    # an un-aligned axis keeps the operand's own chunking
    out_left_chunk_shape = (
        len(index_chunk_shape or list(itertools.chain(*index_splits._left_split)))
        if index_splits is not None
        else left.chunk_shape[0],
        len(column_chunk_shape or list(itertools.chain(*columns_splits._left_split)))
        if columns_splits is not None
        else left.chunk_shape[1],
    )
    if axis is None:
        # both axes aligned: left and right end up with identical chunk grids
        out_right_chunk_shape = out_left_chunk_shape
    else:
        out_right_chunk_shape = (
            len(index_chunk_shape or list(itertools.chain(*index_splits._right_split)))
            if index_splits is not None
            else right.chunk_shape[0],
            len(
                column_chunk_shape
                or list(itertools.chain(*columns_splits._right_split))
            )
            if columns_splits is not None
            else right.chunk_shape[1],
        )
    left_chunks = _gen_dataframe_chunks(splits, out_left_chunk_shape, 0, left)
    right_chunks = _gen_dataframe_chunks(splits, out_right_chunk_shape, 1, right)

    index_nsplits = columns_nsplits = None
    if axis is None or axis == 0:
        if _is_index_identical(left_index_chunks, right_index_chunks):
            index_nsplits = left.nsplits[0]
        else:
            # sizes unknown until execution when chunk indexes differ
            index_nsplits = [np.nan for _ in range(out_left_chunk_shape[0])]
    if axis is None or axis == 1:
        if _is_index_identical(left_columns_chunks, right_columns_chunks):
            columns_nsplits = left.nsplits[1]
        else:
            columns_nsplits = [np.nan for _ in range(out_left_chunk_shape[1])]

    nsplits = [index_nsplits, columns_nsplits]

    out_chunk_shapes = (out_left_chunk_shape, out_right_chunk_shape)
    return nsplits, out_chunk_shapes, left_chunks, right_chunks


def align_dataframe_series(left, right, axis="columns"):
if axis == "columns" or axis == 1:
axis = validate_axis(axis)
if axis == 1:
left_columns_chunks = [c.columns_value for c in left.cix[0, :]]
right_index_chunks = [c.index_value for c in right.chunks]
index_splits, chunk_shape = _calc_axis_splits(
Expand Down Expand Up @@ -941,7 +910,6 @@ def align_dataframe_series(left, right, axis="columns"):
index_nsplits = [np.nan for _ in range(out_chunk_shape[1])]
nsplits = [dummy_nsplits, index_nsplits]
else:
assert axis == "index" or axis == 0
left_index_chunks = [c.index_value for c in left.cix[:, 0]]
right_index_chunks = [c.index_value for c in right.chunks]
index_splits, index_chunk_shape = _calc_axis_splits(
Expand Down
67 changes: 15 additions & 52 deletions mars/dataframe/arithmetic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def _tile_both_dataframes(cls, op):
left, right = op.lhs, op.rhs
df = op.outputs[0]

nsplits, out_shape, left_chunks, right_chunks = align_dataframe_dataframe(
nsplits, out_shapes, left_chunks, right_chunks = align_dataframe_dataframe(
left, right
)
out_chunk_indexes = itertools.product(*(range(s) for s in out_shape))
out_chunk_indexes = itertools.product(*(range(s) for s in out_shapes[0]))

out_chunks = []
for idx, left_chunk, right_chunk in zip(
Expand Down Expand Up @@ -706,62 +706,25 @@ def rcall(self, x1, x2):


class DataFrameBinOp(DataFrameOperand, DataFrameBinOpMixin):
    """Base operand for binary DataFrame/Series arithmetic operations.

    NOTE(review): the pasted diff interleaved the old underscore-backed
    fields, ``__init__`` and accessor properties with their public-field
    replacements; only the new declarations are retained here.
    """

    # alignment axis / level and fill value, per pandas binary-op semantics
    axis = AnyField("axis", default=None)
    level = AnyField("level", default=None)
    fill_value = AnyField("fill_value", default=None)
    # operands; either may be a scalar, in which case it is not an input
    lhs = AnyField("lhs")
    rhs = AnyField("rhs")

    def __init__(self, output_types=None, **kw):
        super().__init__(_output_types=output_types, **kw)

    def _set_inputs(self, inputs):
        """Rebind ``lhs``/``rhs`` to the freshly-set entity inputs."""
        super()._set_inputs(inputs)
        if len(self._inputs) == 2:
            self.lhs = self._inputs[0]
            self.rhs = self._inputs[1]
        else:
            # exactly one side is an entity; figure out which one it is
            if isinstance(self.lhs, ENTITY_TYPE):
                self.lhs = self._inputs[0]
            elif pd.api.types.is_scalar(self.lhs):
                self.rhs = self._inputs[0]


class DataFrameUnaryOpMixin(DataFrameOperandMixin):
Expand Down

0 comments on commit 60e19c3

Please sign in to comment.