Skip to content

Commit

Permalink
FEAT-#3303: Add __getitem__ for Resampler (#3613)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexey Prutskov <alexey.prutskov@intel.com>
Signed-off-by: Maria Rubtsova <maria.rubtsova@intel.com>
  • Loading branch information
Rubtsowa and prutskov committed Nov 25, 2021
1 parent 5ff0669 commit 1a1edfd
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
36 changes: 36 additions & 0 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3127,6 +3127,42 @@ def __init__(
]
self.__groups = self.__get_groups(*self.resample_args)

def __getitem__(self, key):
    """
    Build a ``Resampler`` restricted to the `key` columns of the source dataframe.

    Parameters
    ----------
    key : str or list
        A single column label, or a list-like of column labels, to select.

    Returns
    -------
    Resampler
        New object of the same type as `self`, constructed from the selected
        column subset and the original resample arguments.

    Raises
    ------
    KeyError
        If `key` (or any label of a list-like `key`) is not a column of the
        source dataframe, or if the number of matching columns differs from
        ``len(key)``.
    """
    from .series import Series

    list_like_types = (list, tuple, Series, pandas.Series, pandas.Index, np.ndarray)

    if isinstance(key, list_like_types):
        # Comparing against ``len(key)`` (not the number of unique labels)
        # means repeated labels are rejected as well.
        matched = self._dataframe.columns.intersection(key)
        if len(matched) != len(key):
            absent = sorted(set(key).difference(self._dataframe.columns))
            # ``[1:-1]`` strips the surrounding brackets of the list repr.
            raise KeyError(f"Columns not found: {str(absent)[1:-1]}")
        selection = list(key)
    elif key in self._dataframe:
        selection = key
    else:
        raise KeyError(f"Column not found: {key}")

    # Re-create the resampler over the selected subset with the same arguments.
    return type(self)(self._dataframe[selection], *self.resample_args)

def __get_groups(
self,
rule,
Expand Down
37 changes: 37 additions & 0 deletions modin/pandas/test/dataframe/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,43 @@ def test_resample_specific(rule, closed, label, on, level):
)


@pytest.mark.parametrize(
    "columns",
    [
        "volume",
        "date",  # not a column of the test data
        ["volume"],
        ["price", "date"],  # "date" is not a column of the test data
        ("volume",),
        pandas.Series(["volume"]),
        pandas.Index(["volume"]),
        ["volume", "volume", "volume"],
        ["volume", "price", "date"],  # "date" is not a column of the test data
    ],
    ids=[
        "column",
        "missed_column",
        "list",
        "missed_column",
        "tuple",
        "series",
        "index",
        "duplicate_column",
        "missed_columns",
    ],
)
def test_resample_getitem(columns):
    """
    Check ``resample(...)[columns]`` for scalar, list-like, duplicate and missing keys.

    Parameters
    ----------
    columns : str or list-like
        Selection passed to ``__getitem__`` of the resampler.

    Notes
    -----
    ``eval_general`` and ``create_test_dfs`` are project test helpers; presumably
    they build paired frames and compare the outcome of the lambda between them
    (TODO confirm against the helpers' definitions).
    """
    # "min" is the non-deprecated spelling of the minute offset alias; the
    # single-letter "T" alias triggers a FutureWarning on newer pandas.
    index = pandas.date_range("1/1/2013", periods=9, freq="min")
    data = {
        "price": range(9),
        "volume": range(10, 19),
    }
    eval_general(
        *create_test_dfs(data, index=index),
        lambda df: df.resample("3min")[columns].mean(),
    )


@pytest.mark.parametrize("data", test_data_values, ids=test_data_keys)
@pytest.mark.parametrize("index", ["default", "ndarray", "has_duplicates"])
@pytest.mark.parametrize("axis", [0, 1])
Expand Down

0 comments on commit 1a1edfd

Please sign in to comment.