Skip to content

Commit

Permalink
FIX-#3219: delegate 'apply' result type inference to backend (#3746)
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
  • Loading branch information
dchigarev committed Jan 13, 2022
1 parent 0315823 commit ac17ca1
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 83 deletions.
26 changes: 23 additions & 3 deletions modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2077,7 +2077,7 @@ def drop(self, index=None, columns=None):
# UDF (apply and agg) methods
# There is a wide range of behaviors that are supported, so a lot of the
# logic can get a bit convoluted.
def apply(self, func, axis, *args, **kwargs):
def apply(self, func, axis, raw=False, result_type=None, *args, **kwargs):
"""
Apply passed function across given axis.
Expand All @@ -2088,6 +2088,17 @@ def apply(self, func, axis, *args, **kwargs):
axis : {0, 1}
Target axis to apply the function along.
0 is for index, 1 is for columns.
raw : bool, default: False
Whether to pass a high-level Series object (False) or a raw representation
of the data (True).
result_type : {"expand", "reduce", "broadcast", None}, default: None
Determines how to treat a list-like return value of `func` (works only if
a single function was passed):
- "expand": expand the list-like result into columns.
- "reduce": keep the result in a single cell (opposite of "expand").
- "broadcast": broadcast the result to the original data shape (overwrite the
existing column/row with the function result).
- None: use the "expand" strategy if a Series is returned, "reduce" otherwise.
*args : iterable
Positional arguments to pass to `func`.
**kwargs : dict
Expand All @@ -2099,13 +2110,22 @@ def apply(self, func, axis, *args, **kwargs):
QueryCompiler that contains the results of execution and is built by
the following rules:
- Labels of specified axis are the passed functions names.
- Index of the specified axis contains the names of the passed functions if multiple
functions are passed; otherwise it contains the indices of the `func` result if the
"expand" strategy is used, the indices of the original frame if the "broadcast"
strategy is used, or a single label "__reduced__" if the "reduce" strategy is used.
- Labels of the opposite axis are preserved.
- Each element is the result of execution of `func` against
corresponding row/column.
"""
return DataFrameDefault.register(pandas.DataFrame.apply)(
self, func=func, axis=axis, *args, **kwargs
self,
func=func,
axis=axis,
raw=raw,
result_type=result_type,
*args,
**kwargs,
)

def explode(self, column):
Expand Down
45 changes: 6 additions & 39 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2344,47 +2344,12 @@ def apply(self, func, axis, *args, **kwargs):
# convert it to pandas
args = try_cast_to_pandas(args)
kwargs = try_cast_to_pandas(kwargs)
if isinstance(func, str):
return self._apply_text_func_elementwise(func, axis, *args, **kwargs)
elif callable(func):
return self._callable_func(func, axis, *args, **kwargs)
elif isinstance(func, dict):
if isinstance(func, dict):
return self._dict_func(func, axis, *args, **kwargs)
elif is_list_like(func):
return self._list_like_func(func, axis, *args, **kwargs)
else:
pass

# FIXME: `_apply_text_func_elementwise` duplicates most of the logic of `_callable_func`,
# these methods should be combined.
def _apply_text_func_elementwise(self, func, axis, *args, **kwargs):
"""
Apply passed string function to each row/column.
Parameters
----------
func : str
Function name to apply.
axis : {0, 1}
Target axis to apply function along. 0 means apply to columns,
1 means apply to rows.
*args : args
Arguments to pass to the specified function.
**kwargs : kwargs
Arguments to pass to the specified function.
Returns
-------
PandasQueryCompiler
New QueryCompiler containing the results of passed function
for each row/column.
"""
assert isinstance(func, str)
kwargs["axis"] = axis
new_modin_frame = self._modin_frame.apply_full_axis(
axis, lambda df: df.apply(func, *args, **kwargs)
)
return self.__constructor__(new_modin_frame)
return self._callable_func(func, axis, *args, **kwargs)

def _dict_func(self, func, axis, *args, **kwargs):
"""
Expand Down Expand Up @@ -2469,7 +2434,7 @@ def _callable_func(self, func, axis, *args, **kwargs):
Parameters
----------
func : callable
func : callable or str
Function to apply.
axis : {0, 1}
Target axis to apply function along. 0 means apply to columns,
Expand All @@ -2485,7 +2450,9 @@ def _callable_func(self, func, axis, *args, **kwargs):
New QueryCompiler containing the results of passed function
for each row/column.
"""
func = wrap_udf_function(func)
if callable(func):
func = wrap_udf_function(func)

new_modin_frame = self._modin_frame.apply_full_axis(
axis, lambda df: df.apply(func, axis=axis, *args, **kwargs)
)
Expand Down
51 changes: 17 additions & 34 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,42 +363,25 @@ def apply(
func, axis=axis, raw=raw, result_type=result_type, args=args, **kwargs
)
if not isinstance(query_compiler, type(self._query_compiler)):
# A scalar was returned
return query_compiler
# This is the simplest way to determine the return type, but there are checks
# in pandas that verify that some results are created. This is a challenge for
# empty DataFrames, but fortunately they only happen when the `func` type is
# a list or a dictionary, which means that the return type won't change from
# type(self), so we catch that error and use `type(self).__name__` for the return
# type.
try:
if axis == 0:
init_kwargs = {"index": self.index}
else:
init_kwargs = {"columns": self.columns}
return_type = type(
getattr(pandas, type(self).__name__)(**init_kwargs).apply(
func,
axis=axis,
raw=raw,
result_type=result_type,
args=args,
**kwargs,
)
).__name__
except Exception:
return_type = type(self).__name__
if return_type not in ["DataFrame", "Series"]:
return query_compiler.to_pandas().squeeze()

if result_type == "reduce":
output_type = Series
elif result_type == "broadcast":
output_type = DataFrame
# the 'else' branch also handles 'result_type == "expand"' since it makes the output type
# depend on the `func` result (Series for a scalar, DataFrame for list-like)
else:
result = getattr(sys.modules[self.__module__], return_type)(
query_compiler=query_compiler
)
if isinstance(result, Series):
if axis == 0 and result.name == self.index[0] or result.name == 0:
result.name = None
elif axis == 1 and result.name == self.columns[0] or result.name == 0:
result.name = None
return result
reduced_index = pandas.Index(["__reduced__"])
if query_compiler.get_axis(axis).equals(
reduced_index
) or query_compiler.get_axis(axis ^ 1).equals(reduced_index):
output_type = Series
else:
output_type = DataFrame

return output_type(query_compiler=query_compiler)

def groupby(
self,
Expand Down
3 changes: 1 addition & 2 deletions modin/pandas/test/dataframe/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ def add(a, b, c):
def test_apply_udf(data, func):
eval_general(
*create_test_dfs(data),
lambda df, *args, **kwargs: df.apply(*args, **kwargs),
func=func,
lambda df, *args, **kwargs: df.apply(func, *args, **kwargs),
other=lambda df: df,
)

Expand Down
14 changes: 9 additions & 5 deletions modin/pandas/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@
"str": str,
"sum mean": ["sum", "mean"],
"sum df sum": ["sum", lambda df: df.sum()],
# The case verifies that returning a scalar that is based on a frame's data doesn't cause a problem
"sum of certain elements": lambda axis: (
axis.iloc[0] + axis.iloc[-1] if isinstance(axis, pandas.Series) else axis + axis
),
"should raise TypeError": 1,
}
agg_func_keys = list(agg_func.keys())
Expand All @@ -311,13 +315,13 @@
numeric_agg_funcs = ["sum mean", "sum sum", "sum df sum"]

udf_func = {
"return self": lambda df: lambda x, *args, **kwargs: type(x)(x.values),
"change index": lambda df: lambda x, *args, **kwargs: pandas.Series(
"return self": lambda x, *args, **kwargs: type(x)(x.values),
"change index": lambda x, *args, **kwargs: pandas.Series(
x.values, index=np.arange(-1, len(x.index) - 1)
),
"return none": lambda df: lambda x, *args, **kwargs: None,
"return empty": lambda df: lambda x, *args, **kwargs: pandas.Series(),
"access self": lambda df: lambda x, other, *args, **kwargs: pandas.Series(
"return none": lambda x, *args, **kwargs: None,
"return empty": lambda x, *args, **kwargs: pandas.Series(),
"access self": lambda x, other, *args, **kwargs: pandas.Series(
x.values, index=other.index
),
}
Expand Down

0 comments on commit ac17ca1

Please sign in to comment.