Skip to content

Commit

Permalink
Avoid iterative tiling for df.loc[:, fields]
Browse files Browse the repository at this point in the history
  • Loading branch information
继盛 committed Feb 8, 2022
1 parent 56efd30 commit 2c6b354
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 25 deletions.
68 changes: 43 additions & 25 deletions mars/dataframe/indexing/index_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,12 @@ def preprocess(self, index_info: IndexInfo, context: IndexHandlerContext) -> Non
index_value = [tileable.index_value, tileable.columns_value][input_axis]

# check if chunks have unknown shape
check = False
if index_value.has_value():
# index_value has value,
check = True
elif self._slice_all(index_info.raw_index):
# if slice on all data
check = True

if check:
if any(np.isnan(ns) for ns in tileable.nsplits[input_axis]):
yield []
if (
not self._slice_all(index_info.raw_index)
and index_value.has_value()
and any(np.isnan(ns) for ns in tileable.nsplits[input_axis])
):
yield []

def set_chunk_index_info(
cls,
Expand Down Expand Up @@ -297,6 +292,27 @@ def set_chunk_index_info(

chunk_index_info.set(ChunkIndexAxisInfo(**kw))

def _process_slice_all_index(
    self,
    tileable: Tileable,
    index_info: IndexInfo,
    input_axis: int,
    context: IndexHandlerContext,
) -> None:
    """Record a full ``slice(None)`` selection for every chunk in the context.

    For each ``(chunk_index, chunk_index_info)`` pair currently held in
    ``context.chunk_index_to_info``, this calls ``set_chunk_index_info``
    with the chunk's position along ``input_axis``, a ``slice(None)`` as
    the per-chunk index, and the chunk's extent along that axis taken
    from ``tileable.nsplits``.

    Parameters
    ----------
    tileable
        Object whose ``nsplits[input_axis]`` supplies the per-chunk sizes.
    index_info
        Index description forwarded unchanged to ``set_chunk_index_info``.
    input_axis
        Axis of ``tileable`` that the slice applies to.
    context
        Handler context providing ``chunk_index_to_info``.
    """
    # Walk a copy of the mapping rather than the live one -- presumably so
    # that ``set_chunk_index_info`` may touch the context while we iterate
    # (NOTE(review): confirm against ``set_chunk_index_info``).
    snapshot = context.chunk_index_to_info.copy()
    for idx, info in snapshot.items():
        axis_pos = idx[input_axis]
        # The whole chunk is selected, so its output size along the axis
        # is simply the chunk's own split size.
        axis_size = tileable.nsplits[input_axis][axis_pos]
        self.set_chunk_index_info(
            context, index_info, idx, info, axis_pos, slice(None), axis_size
        )

def _process_has_value_index(
self,
tileable: Tileable,
Expand All @@ -306,17 +322,14 @@ def _process_has_value_index(
context: IndexHandlerContext,
) -> None:
pd_index = index_value.to_pandas()
if self._slice_all(index_info.raw_index):
slc = slice(None)
else:
# turn label-based slice into position-based slice
start, end = pd_index.slice_locs(
index_info.raw_index.start,
index_info.raw_index.stop,
index_info.raw_index.step,
kind="loc",
)
slc = slice(start, end, index_info.raw_index.step)
# turn label-based slice into position-based slice
start, end = pd_index.slice_locs(
index_info.raw_index.start,
index_info.raw_index.stop,
index_info.raw_index.step,
kind="loc",
)
slc = slice(start, end, index_info.raw_index.step)

cum_nsplit = [0] + np.cumsum(tileable.nsplits[index_info.input_axis]).tolist()
# split position-based slice into chunk slices
Expand Down Expand Up @@ -379,7 +392,9 @@ def process(self, index_info: IndexInfo, context: IndexHandlerContext) -> None:
else:
index_value = [tileable.index_value, tileable.columns_value][input_axis]

if index_value.has_value() or self._slice_all(index_info.raw_index):
if self._slice_all(index_info.raw_index):
self._process_slice_all_index(tileable, index_info, input_axis, context)
elif index_value.has_value():
self._process_has_value_index(
tileable, index_info, index_value, input_axis, context
)
Expand Down Expand Up @@ -829,10 +844,13 @@ def parse(self, raw_index, context: IndexHandlerContext) -> IndexInfo:
def preprocess(self, index_info: IndexInfo, context: IndexHandlerContext) -> None:
tileable = context.tileable
op = context.op
if has_unknown_shape(tileable):
yield []

input_axis = index_info.input_axis

# check unknown shape
if any(np.isnan(s) for s in tileable.nsplits[input_axis]):
yield []

if tileable.ndim == 2:
index_value = [tileable.index_value, tileable.columns_value][input_axis]
else:
Expand Down
6 changes: 6 additions & 0 deletions mars/dataframe/indexing/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,12 @@ def test_dataframe_loc():
for loc_chunk, chunk in zip(tiled_loc_df.chunks, tiled_df.chunks):
assert loc_chunk.index_value.key == chunk.index_value.key

# test loc on filtered df
df2 = df[df["x"] < 1]
loc_df = df2.loc[:, ["y", "x"]]
tiled_loc_df = tile(loc_df)
assert len(tiled_loc_df.chunks) == 3


def test_loc_use_iloc():
raw = pd.DataFrame([[1, 3, 3], [4, 2, 6], [7, 8, 9]], columns=["x", "y", "z"])
Expand Down
4 changes: 4 additions & 0 deletions mars/dataframe/indexing/tests/test_indexing_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ def test_loc_getitem(setup):
result = df.execute().fetch()
expected = raw2.loc[:, "b"]
pd.testing.assert_series_equal(result, expected)
df = df2.loc[:, ["b", "a"]]
result = df.execute().fetch()
expected = raw2.loc[:, ["b", "a"]]
pd.testing.assert_frame_equal(result, expected)

# 'b' is non-unique
df = df3.loc[:, "b"]
Expand Down

0 comments on commit 2c6b354

Please sign in to comment.