Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid iterative tiling for df.loc[:, fields] #2685

Merged
merged 1 commit into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 43 additions & 25 deletions mars/dataframe/indexing/index_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,12 @@ def preprocess(self, index_info: IndexInfo, context: IndexHandlerContext) -> Non
index_value = [tileable.index_value, tileable.columns_value][input_axis]

# check if chunks have unknown shape
check = False
if index_value.has_value():
# index_value has value,
check = True
elif self._slice_all(index_info.raw_index):
# if slice on all data
check = True

if check:
if any(np.isnan(ns) for ns in tileable.nsplits[input_axis]):
yield []
if (
not self._slice_all(index_info.raw_index)
and index_value.has_value()
and any(np.isnan(ns) for ns in tileable.nsplits[input_axis])
): # pragma: no cover
yield []

def set_chunk_index_info(
cls,
Expand Down Expand Up @@ -297,6 +292,27 @@ def set_chunk_index_info(

chunk_index_info.set(ChunkIndexAxisInfo(**kw))

def _process_slice_all_index(
    self,
    tileable: Tileable,
    index_info: IndexInfo,
    input_axis: int,
    context: IndexHandlerContext,
) -> None:
    """Register a whole-chunk ``slice(None)`` selection for each chunk entry.

    For every ``(chunk index, chunk index info)`` pair currently held in
    ``context.chunk_index_to_info``, look up the chunk's extent along
    ``input_axis`` in ``tileable.nsplits`` and pass it, together with a
    ``slice(None)``, to ``set_chunk_index_info``.

    Args:
        tileable: Object whose ``nsplits[input_axis]`` gives per-chunk sizes.
        index_info: Index description forwarded unchanged to
            ``set_chunk_index_info``.
        input_axis: Axis of ``tileable`` that is being sliced in full.
        context: Handler context providing ``chunk_index_to_info``.
    """
    # Walk a copy of the mapping rather than the live one.
    # NOTE(review): presumably set_chunk_index_info may touch
    # context.chunk_index_to_info while we iterate -- confirm.
    snapshot = context.chunk_index_to_info.copy()

    for idx, info in snapshot.items():
        # Position of this chunk along the sliced axis, and its extent there.
        axis_pos = idx[input_axis]
        extent = tileable.nsplits[input_axis][axis_pos]

        # The selection is always the full chunk, so no positional slicing
        # needs to be computed here.
        whole_chunk = slice(None)
        self.set_chunk_index_info(
            context, index_info, idx, info, axis_pos, whole_chunk, extent
        )

def _process_has_value_index(
self,
tileable: Tileable,
Expand All @@ -306,17 +322,14 @@ def _process_has_value_index(
context: IndexHandlerContext,
) -> None:
pd_index = index_value.to_pandas()
if self._slice_all(index_info.raw_index):
slc = slice(None)
else:
# turn label-based slice into position-based slice
start, end = pd_index.slice_locs(
index_info.raw_index.start,
index_info.raw_index.stop,
index_info.raw_index.step,
kind="loc",
)
slc = slice(start, end, index_info.raw_index.step)
# turn label-based slice into position-based slice
start, end = pd_index.slice_locs(
index_info.raw_index.start,
index_info.raw_index.stop,
index_info.raw_index.step,
kind="loc",
)
slc = slice(start, end, index_info.raw_index.step)

cum_nsplit = [0] + np.cumsum(tileable.nsplits[index_info.input_axis]).tolist()
# split position-based slice into chunk slices
Expand Down Expand Up @@ -379,7 +392,9 @@ def process(self, index_info: IndexInfo, context: IndexHandlerContext) -> None:
else:
index_value = [tileable.index_value, tileable.columns_value][input_axis]

if index_value.has_value() or self._slice_all(index_info.raw_index):
if self._slice_all(index_info.raw_index):
self._process_slice_all_index(tileable, index_info, input_axis, context)
elif index_value.has_value():
self._process_has_value_index(
tileable, index_info, index_value, input_axis, context
)
Expand Down Expand Up @@ -829,10 +844,13 @@ def parse(self, raw_index, context: IndexHandlerContext) -> IndexInfo:
def preprocess(self, index_info: IndexInfo, context: IndexHandlerContext) -> None:
tileable = context.tileable
op = context.op
if has_unknown_shape(tileable):
yield []

input_axis = index_info.input_axis

# check unknown shape
if any(np.isnan(s) for s in tileable.nsplits[input_axis]):
yield []

if tileable.ndim == 2:
index_value = [tileable.index_value, tileable.columns_value][input_axis]
else:
Expand Down
6 changes: 6 additions & 0 deletions mars/dataframe/indexing/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,12 @@ def test_dataframe_loc():
for loc_chunk, chunk in zip(tiled_loc_df.chunks, tiled_df.chunks):
assert loc_chunk.index_value.key == chunk.index_value.key

# test loc on filtered df
df2 = df[df["x"] < 1]
loc_df = df2.loc[:, ["y", "x"]]
tiled_loc_df = tile(loc_df)
assert len(tiled_loc_df.chunks) == 3


def test_loc_use_iloc():
raw = pd.DataFrame([[1, 3, 3], [4, 2, 6], [7, 8, 9]], columns=["x", "y", "z"])
Expand Down
9 changes: 9 additions & 0 deletions mars/dataframe/indexing/tests/test_indexing_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ def test_loc_getitem(setup):
result = df.execute().fetch()
expected = raw2.loc[:, "b"]
pd.testing.assert_series_equal(result, expected)
df = df2.loc[:, ["b", "a"]]
result = df.execute().fetch()
expected = raw2.loc[:, ["b", "a"]]
pd.testing.assert_frame_equal(result, expected)

# 'b' is non-unique
df = df3.loc[:, "b"]
Expand All @@ -336,6 +340,11 @@ def test_loc_getitem(setup):
result = df.execute().fetch()
expected = raw2.loc[[3, 0, 1], ["c", "a", "d"]]
pd.testing.assert_frame_equal(result, expected)
df = df2[df2["a"] < 10]
df = df.loc[[3, 0, 1], ["c", "a", "d"]]
result = df.execute().fetch()
expected = raw2.loc[[3, 0, 1], ["c", "a", "d"]]
pd.testing.assert_frame_equal(result, expected)

# label-based fancy index, asc sorted
df = df2.loc[[0, 1, 3], ["a", "c", "d"]]
Expand Down