Skip to content

Commit

Permalink
Fix df.loc[:] to make sure same index_value key generated (#2643)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuye (Chris) Qin authored Jan 24, 2022
1 parent 3c0a4ca commit 0df26c5
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 7 deletions.
7 changes: 5 additions & 2 deletions mars/dataframe/base/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,14 @@ def compute_rechunk(a, chunk_size):
calc_sliced_size(s, chunk_slice[0]) for s in old_chunk.shape
)
new_index_value = indexing_index_value(
old_chunk.index_value, chunk_slice[0]
old_chunk.index_value, chunk_slice[0], rechunk=True
)
if is_dataframe:
new_columns_value = indexing_index_value(
old_chunk.columns_value, chunk_slice[1], store_data=True
old_chunk.columns_value,
chunk_slice[1],
store_data=True,
rechunk=True,
)
merge_chunk_op = DataFrameIlocGetItem(
list(chunk_slice),
Expand Down
10 changes: 8 additions & 2 deletions mars/dataframe/indexing/loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ...serialization.serializables import KeyField, ListField
from ...tensor.datasource import asarray
from ...tensor.utils import calc_sliced_size, filter_inputs
from ...utils import lazy_import
from ...utils import lazy_import, is_full_slice
from ..core import IndexValue, DATAFRAME_TYPE
from ..operands import DataFrameOperand, DataFrameOperandMixin
from ..utils import parse_index
Expand Down Expand Up @@ -154,7 +154,13 @@ def _calc_slice_param(
axis: int,
) -> Dict:
param = dict()
if input_index_value.has_value():
if is_full_slice(index):
# full slice on this axis
param["shape"] = inp.shape[axis]
param["index_value"] = input_index_value
if axis == 1:
param["dtypes"] = inp.dtypes
elif input_index_value.has_value():
start, end = pd_index.slice_locs(
index.start, index.stop, index.step, kind="loc"
)
Expand Down
2 changes: 2 additions & 0 deletions mars/dataframe/indexing/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def test_iloc_getitem():
df4 = tile(df4)
assert isinstance(df4, DATAFRAME_TYPE)
assert isinstance(df4.op, DataFrameIlocGetItem)
assert df4.index_value.key == df2.index_value.key
assert df4.shape == (3, 1)
assert df4.chunk_shape == (2, 1)
assert df4.chunks[0].shape == (2, 1)
Expand Down Expand Up @@ -479,6 +480,7 @@ def test_dataframe_loc():
df2.index_value.to_pandas(), df.index_value.to_pandas()
)
assert df2.name == "y"
assert df2.index_value.key == df.index_value.key

df2 = tile(df2)
assert len(df2.chunks) == 2
Expand Down
10 changes: 7 additions & 3 deletions mars/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ..core import Entity, ExecutableTuple
from ..lib.mmh3 import hash as mmh_hash
from ..tensor.utils import dictify_chunk_size, normalize_chunk_sizes
from ..utils import tokenize, sbytes, lazy_import, ModulePlaceholder
from ..utils import tokenize, sbytes, lazy_import, ModulePlaceholder, is_full_slice

try:
import pyarrow as pa
Expand Down Expand Up @@ -804,9 +804,13 @@ def filter_index_value(index_value, min_max, store_data=False):
return parse_index(pd_index[f], store_data=store_data)


def indexing_index_value(index_value, indexes, store_data=False):
def indexing_index_value(index_value, indexes, store_data=False, rechunk=False):
pd_index = index_value.to_pandas()
if not index_value.has_value():
# when rechunk is True, the output index shall be treated
# different from the input one
if not rechunk and isinstance(indexes, slice) and is_full_slice(indexes):
return index_value
elif not index_value.has_value():
new_index_value = parse_index(pd_index, indexes, store_data=store_data)
new_index_value._index_value._min_val = index_value.min_val
new_index_value._index_value._min_val_close = index_value.min_val_close
Expand Down
10 changes: 10 additions & 0 deletions mars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,3 +1519,13 @@ def flatten_dict_to_nested_dict(flatten_dict: Dict, sep=".") -> Dict:
else:
sub_nested_dict = sub_nested_dict[sub_key]
return nested_dict


def is_full_slice(slc: Any) -> bool:
"""Check if the input is a full slice ((:) or (0:))"""
return (
isinstance(slc, slice)
and (slc.start == 0 or slc.start is None)
and slc.stop is None
and slc.step is None
)

0 comments on commit 0df26c5

Please sign in to comment.