Skip to content

Commit

Permalink
Patch pandas magic functions to allow reverse operands (#3155)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Jun 20, 2022
1 parent 8ca0d84 commit 03f0f12
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 133 deletions.
79 changes: 75 additions & 4 deletions mars/dataframe/arithmetic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@

import functools

from ..core import DATAFRAME_TYPE, is_build_mode
import pandas as pd

try:
from pandas.core.arraylike import OpsMixin as PdOpsMixin
except ImportError: # pragma: no cover
PdOpsMixin = None

from ..core import DATAFRAME_TYPE, SERIES_TYPE, INDEX_TYPE, is_build_mode
from ..utils import wrap_notimplemented_exception
from ..ufunc.tensor import register_tensor_ufunc
from .abs import abs_, DataFrameAbs
Expand Down Expand Up @@ -61,7 +68,7 @@
from .exp import DataFrameExp
from .exp2 import DataFrameExp2
from .expm1 import DataFrameExpm1
from .dot import dot
from .dot import dot, rdot


def _wrap_eq():
Expand Down Expand Up @@ -95,9 +102,42 @@ def call(df, other, **kw):
return call


def _install():
from ..core import DATAFRAME_TYPE, SERIES_TYPE, INDEX_TYPE
_reverse_magic_names = {
"eq": "eq",
"ne": "ne",
"lt": "ge",
"le": "gt",
"gt": "le",
"ge": "lt",
}


def _wrap_pandas_magics(cls, magic_name: str):
    """
    Patch ``cls.__<magic_name>__`` so Mars operands get a chance to handle it.

    The replacement method forwards to the original implementation unless
    ``other`` is a Mars dataframe, series or index. In that case the
    reflected method is looked up on ``other`` (taken from
    ``_reverse_magic_names``, defaulting to ``__r<magic_name>__``) and called
    with the left operand. The original implementation is used as a fallback
    when that lookup/call raises ``AttributeError`` or returns
    ``NotImplemented``.

    Nothing is patched when ``cls`` does not define the magic method.

    Parameters
    ----------
    cls : type
        Class whose magic method is replaced in place.
    magic_name : str
        Magic name without the surrounding double underscores, e.g. ``"add"``.
    """
    forward_name = f"__{magic_name}__"
    reflected_name = _reverse_magic_names.get(magic_name, f"__r{magic_name}__")
    try:
        original_impl = getattr(cls, forward_name)
    except AttributeError:
        # cls has no such magic method, nothing to intercept
        return

    @functools.wraps(original_impl)
    def wrapped(self, other):
        if isinstance(other, (DATAFRAME_TYPE, SERIES_TYPE, INDEX_TYPE)):
            try:
                result = getattr(other, reflected_name)(self)
            except AttributeError:  # pragma: no cover
                result = NotImplemented

            if result is not NotImplemented:  # pragma: no branch
                return result
        # non-Mars operand, or the Mars side declined the operation
        return original_impl(self, other)

    setattr(cls, forward_name, wrapped)


def _install():
def _register_method(cls, name, func, wrapper=None):
if wrapper is None:

Expand Down Expand Up @@ -260,6 +300,7 @@ def call_series_no_fill(df, other, level=None, axis=0):
_register_bin_method(entity, "le", le)

setattr(entity, "__matmul__", dot)
setattr(entity, "__rmatmul__", rdot)
_register_method(entity, "dot", dot)

setattr(entity, "__and__", wrap_notimplemented_exception(bitand))
Expand All @@ -276,6 +317,36 @@ def call_series_no_fill(df, other, level=None, axis=0):
for entity in INDEX_TYPE:
setattr(entity, "__eq__", _wrap_eq())

if PdOpsMixin is not None and not hasattr(
pd, "_mars_df_arith_wrapped"
): # pragma: no branch
# wrap pandas magic functions to intercept reverse operands
for magic_name in [
"add",
"sub",
"mul",
"div",
"truediv",
"floordiv",
"mod",
"pow",
"and",
"or",
"xor",
"eq",
"ne",
"lt",
"le",
"gt",
"ge",
]:
_wrap_pandas_magics(PdOpsMixin, magic_name)

for pd_cls in (pd.DataFrame, pd.Series):
_wrap_pandas_magics(pd_cls, "matmul")

pd._mars_df_arith_wrapped = True


_install()
del _install
14 changes: 0 additions & 14 deletions mars/dataframe/arithmetic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ...core import ENTITY_TYPE, CHUNK_TYPE, recursive_tile
from ...serialization.serializables import AnyField
from ...tensor.core import TENSOR_TYPE, TENSOR_CHUNK_TYPE, ChunkData, Chunk
from ...tensor.datasource import tensor as astensor
from ...utils import classproperty, get_dtype
from ..align import (
align_series_series,
Expand All @@ -36,7 +35,6 @@
SERIES_CHUNK_TYPE,
is_chunk_meta_lazy,
)
from ..initializer import Series, DataFrame
from ..operands import DataFrameOperandMixin, DataFrameOperand
from ..ufunc.tensor import TensorUfuncMixin
from ..utils import (
Expand Down Expand Up @@ -639,18 +637,6 @@ def _new_chunks(self, inputs, kws=None, **kw):

return super()._new_chunks(inputs, shape=shape, kws=kws, **kw)

@staticmethod
def _process_input(x):
if isinstance(x, (DATAFRAME_TYPE, SERIES_TYPE)) or pd.api.types.is_scalar(x):
return x
elif isinstance(x, pd.Series):
return Series(x)
elif isinstance(x, pd.DataFrame):
return DataFrame(x)
elif isinstance(x, (list, tuple, np.ndarray, TENSOR_TYPE)):
return astensor(x)
raise NotImplementedError

def _check_inputs(self, x1, x2):
if isinstance(x1, TENSOR_TYPE) or isinstance(x2, TENSOR_TYPE):
tensor, other = (x1, x2) if isinstance(x1, TENSOR_TYPE) else (x2, x1)
Expand Down
33 changes: 16 additions & 17 deletions mars/dataframe/arithmetic/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,20 @@
class DataFrameDot(DataFrameOperand, DataFrameOperandMixin):
_op_type_ = OperandDef.DOT

_lhs = KeyField("lhs")
_rhs = AnyField("rhs")
lhs = KeyField("lhs")
rhs = AnyField("rhs")

def __init__(self, output_types=None, lhs=None, rhs=None, **kw):
super().__init__(_output_types=output_types, _lhs=lhs, _rhs=rhs, **kw)

@property
def lhs(self):
return self._lhs

@property
def rhs(self):
return self._rhs
def __init__(self, output_types=None, **kw):
super().__init__(_output_types=output_types, **kw)

def _set_inputs(self, inputs):
super()._set_inputs(inputs)
self._lhs = self._inputs[0]
self._rhs = self._inputs[1]
self.lhs = self._inputs[0]
self.rhs = self._inputs[1]

def __call__(self, lhs, rhs):
lhs = self._process_input(lhs)
rhs = self._process_input(rhs)
if not isinstance(rhs, (DATAFRAME_TYPE, SERIES_TYPE)):
rhs = astensor(rhs)
test_rhs = rhs
Expand Down Expand Up @@ -171,9 +165,14 @@ def tile(cls, op):
return [tiled]


def dot(df_or_seris, other):
op = DataFrameDot(lhs=df_or_seris, rhs=other)
return op(df_or_seris, other)
def dot(df_or_series, other):
    # Build a DOT operand with ``df_or_series`` as the left side and ``other``
    # as the right side, then invoke it on that same pair.
    return DataFrameDot(lhs=df_or_series, rhs=other)(df_or_series, other)


def rdot(df_or_series, other):
    # Reflected variant of ``dot``: ``other`` becomes the left operand and
    # ``df_or_series`` the right one.
    return DataFrameDot(lhs=other, rhs=df_or_series)(other, df_or_series)


dot.__frame_doc__ = """
Expand Down
6 changes: 6 additions & 0 deletions mars/dataframe/arithmetic/tests/test_arithmetic_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,12 @@ def test_dataframe_and_scalar(setup, func_name, func_opts):
result6 = getattr(df, func_opts.rfunc_name)(1).execute().fetch()
pd.testing.assert_frame_equal(expected2, result6)

# test pandas series and dataframe
pdf2 = pd.DataFrame(np.random.rand(10, 10))
expected = func_opts.func(pdf2, pdf)
result = func_opts.func(pdf2, df).execute().fetch()
pd.testing.assert_frame_equal(expected, result)


@pytest.mark.parametrize("func_name, func_opts", binary_functions.items())
def test_with_shuffle_on_string_index(setup, func_name, func_opts):
Expand Down
10 changes: 8 additions & 2 deletions mars/dataframe/arithmetic/tests/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@


def test_comp(setup):
df1 = DataFrame(pd.DataFrame(np.random.rand(4, 3)))
df2 = DataFrame(pd.DataFrame(np.random.rand(4, 3)))
raw_df1 = pd.DataFrame(np.random.rand(4, 3))
raw_df2 = pd.DataFrame(np.random.rand(4, 3))
df1 = DataFrame(raw_df1)
df2 = DataFrame(raw_df2)

with enter_mode(build=True):
assert not df1.data == df2.data
Expand All @@ -43,6 +45,10 @@ def test_comp(setup):
pd.testing.assert_index_equal(
eq_df.index_value.to_pandas(), df1.index_value.to_pandas()
)
eq_df = op(raw_df1, df2)
pd.testing.assert_index_equal(
eq_df.index_value.to_pandas(), df1.index_value.to_pandas()
)

# index not identical
df3 = DataFrame(pd.DataFrame(np.random.rand(4, 3), index=[1, 2, 3, 4]))
Expand Down
6 changes: 6 additions & 0 deletions mars/dataframe/arithmetic/tests/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def test_dot_execution(setup):
expected = df1_raw @ df2_raw
pd.testing.assert_frame_equal(result, expected)

# test reversed @
r = df1_raw @ df2
result = r.execute().fetch()
expected = df1_raw @ df2_raw
pd.testing.assert_frame_equal(result, expected)

series1 = Series(s1_raw, chunk_size=5)

# df.dot(series)
Expand Down
13 changes: 13 additions & 0 deletions mars/dataframe/base/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from typing import Iterable

import pandas as pd
from pandas.api.types import (
is_datetime64_dtype,
is_datetime64tz_dtype,
is_timedelta64_dtype,
is_period_dtype,
)

from ...utils import adapt_mars_docstring
from .string_ import _string_method_to_handlers, SeriesStringMethod
Expand Down Expand Up @@ -218,6 +224,13 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"):

class DatetimeAccessor:
def __init__(self, series):
if (
not is_datetime64_dtype(series.dtype)
and not is_datetime64tz_dtype(series.dtype)
and not is_timedelta64_dtype(series.dtype)
and not is_period_dtype(series.dtype)
):
raise AttributeError("Can only use .dt accessor with datetimelike values")
self._series = series

@classmethod
Expand Down
2 changes: 2 additions & 0 deletions mars/dataframe/base/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,8 @@ def test_datetime_method():
)
assert c.shape == (2,) if i == 0 else (1,)

with pytest.raises(AttributeError):
_ = from_pandas_series(pd.Series([1])).dt
with pytest.raises(AttributeError):
_ = series.dt.non_exist

Expand Down
16 changes: 16 additions & 0 deletions mars/dataframe/operands.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import OrderedDict
from functools import reduce

import numpy as np
import pandas as pd

from ..core import FuseChunkData, FuseChunk, ENTITY_TYPE, OutputType
Expand All @@ -26,6 +27,7 @@
FuseChunkMixin,
)
from ..tensor.core import TENSOR_TYPE
from ..tensor.datasource import tensor as astensor
from ..tensor.operands import TensorOperandMixin
from ..utils import calc_nsplits
from .core import (
Expand Down Expand Up @@ -434,6 +436,20 @@ def _calc_series_index_params(cls, chunks):
def get_fuse_op_cls(self, _):
return DataFrameFuseChunk

@staticmethod
def _process_input(x):
    """
    Normalize an operand into something Mars operands can consume.

    Mars dataframes/series and scalars are returned untouched, pandas
    objects are wrapped into their Mars counterparts, and lists, tuples,
    numpy arrays and Mars tensors are passed through ``astensor``.

    Raises
    ------
    NotImplementedError
        If ``x`` is none of the supported kinds.
    """
    # imported lazily inside the function rather than at module level
    from .initializer import DataFrame, Series

    if isinstance(x, (DATAFRAME_TYPE, SERIES_TYPE)) or pd.api.types.is_scalar(x):
        return x

    # (accepted raw types, converter) pairs, checked in order
    converters = (
        (pd.Series, Series),
        (pd.DataFrame, DataFrame),
        ((list, tuple, np.ndarray, TENSOR_TYPE), astensor),
    )
    for raw_types, convert in converters:
        if isinstance(x, raw_types):
            return convert(x)
    raise NotImplementedError


DataFrameOperand = Operand

Expand Down

0 comments on commit 03f0f12

Please sign in to comment.