Fix loc with MultiIndex (#508)
* Fix loc with MultiIndex

* Add tests and fix cases

* Add test for MultiIndex in index

* Remove dead code + lint

* Add comments
devin-petersohn committed Mar 22, 2019
1 parent 5e70be5 commit bf15103
Showing 5 changed files with 118 additions and 23 deletions.
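For context on what the diffs below reproduce: plain pandas drops the `MultiIndex` levels consumed by a `.loc` key and returns only the remaining levels, and a full key plus a column label yields a scalar. A minimal, illustrative snippet (hypothetical data, not part of the commit):

import pandas as pd

# Hypothetical two-level index; Modin's .loc is being aligned with this pandas behavior.
idx = pd.MultiIndex.from_tuples(
    [("bar", "one"), ("bar", "two"), ("baz", "one"), ("baz", "two")],
    names=["first", "second"],
)
df = pd.DataFrame({"col1": range(4), "col2": range(4, 8)}, index=idx)

# Partial key: the "first" level is consumed and dropped from the result.
print(df.loc["bar"].index.name)        # 'second'

# Full key plus a column label: a scalar comes back.
print(df.loc[("bar", "one"), "col1"])  # 0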
16 changes: 10 additions & 6 deletions modin/data_management/query_compiler/pandas_query_compiler.py
@@ -2775,16 +2775,20 @@ def view(self, index=None, columns=None):
         )

     def squeeze(self, ndim=0, axis=None):
-        to_squeeze = self.data.to_pandas()
+        to_squeeze = self.to_pandas()
         # This is the case for 1xN or Nx1 DF - Need to call squeeze
         if ndim == 1:
             if axis is None:
                 axis = 0 if self.data.shape[1] > 1 else 1
-            squeezed = pandas.Series(to_squeeze.squeeze(axis))
-            scaler_axis = self.columns if axis else self.index
-            non_scaler_axis = self.index if axis else self.columns
-            squeezed.name = scaler_axis[0]
-            squeezed.index = non_scaler_axis
+            squeezed = pandas.Series(to_squeeze.squeeze())
+            # In the case of `MultiIndex`, we already have the correct index and naming
+            # because we are going from pandas above. This step is to correct the
+            # `Series` to have the correct name and index.
+            if not isinstance(squeezed.index, pandas.MultiIndex):
+                scaler_axis = self.columns if axis else self.index
+                non_scaler_axis = self.index if axis else self.columns
+                squeezed.name = scaler_axis[0]
+                squeezed.index = non_scaler_axis
             return squeezed
         # This is the case for a 1x1 DF - We don't need to squeeze
         else:
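The `squeeze` change above hinges on a pandas detail: squeezing a 1xN frame whose columns are a `MultiIndex` already yields a `Series` with the right multi-level index and name, so re-assigning them would clobber the result; only the flat-index case still needs the `scaler_axis` correction. A short, illustrative check (hypothetical data, not part of the commit):

import pandas as pd

# 1xN frame with MultiIndex columns: squeeze() already returns a Series whose
# index is the column MultiIndex, so no name/index correction is needed.
cols = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
wide = pd.DataFrame([[1, 2]], columns=cols)
print(isinstance(wide.squeeze().index, pd.MultiIndex))  # True

# 1xN frame with a flat column index: the non-MultiIndex branch above then
# restores the Series name and index explicitly.
flat = pd.DataFrame([[1, 2]], columns=["x", "y"])
print(isinstance(flat.squeeze().index, pd.MultiIndex))  # False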
50 changes: 33 additions & 17 deletions modin/pandas/indexing.py
@@ -230,8 +230,26 @@ def __getitem__(self, key):
         row_loc, col_loc, ndim, self.row_scaler, self.col_scaler = _parse_tuple(key)
         self._handle_enlargement(row_loc, col_loc)
         row_lookup, col_lookup = self._compute_lookup(row_loc, col_loc)
-        ndim = self._expand_dim(row_lookup, col_lookup, ndim)
+        ndim = (0 if len(row_lookup) == 1 else 1) + (0 if len(col_lookup) == 1 else 1)
         result = super(_LocIndexer, self).__getitem__(row_lookup, col_lookup, ndim)
+        # Pandas drops the levels that are in the `loc`, so we have to as well.
+        if hasattr(result, "index") and isinstance(result.index, pandas.MultiIndex):
+            if (
+                isinstance(result, pandas.Series)
+                and not isinstance(col_loc, slice)
+                and all(
+                    col_loc[i] in result.index.levels[i] for i in range(len(col_loc))
+                )
+            ):
+                result.index = result.index.droplevel(list(range(len(col_loc))))
+            elif all(row_loc[i] in result.index.levels[i] for i in range(len(row_loc))):
+                result.index = result.index.droplevel(list(range(len(row_loc))))
+        if (
+            hasattr(result, "columns")
+            and isinstance(result.columns, pandas.MultiIndex)
+            and all(col_loc[i] in result.columns.levels[i] for i in range(len(col_loc)))
+        ):
+            result.columns = result.columns.droplevel(list(range(len(col_loc))))
         return result

     def __setitem__(self, key, item):
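The level-dropping block above leans on two `MultiIndex` members: `levels`, which lists the label values available per level, and `droplevel`, which removes the levels a lookup has already consumed. An illustrative rerun of the same calls (hypothetical data, not part of the commit):

import pandas as pd

mi = pd.MultiIndex.from_product([["Presidents"], ["Pure mentions"], ["IND", "DEM"]])
s = pd.Series([1, 2], index=mi)

# A partial key is "covered" when each of its components appears in the
# corresponding level -- the same membership test used in __getitem__ above.
key = ("Presidents", "Pure mentions")
print(all(key[i] in s.index.levels[i] for i in range(len(key))))   # True

# Dropping the consumed levels leaves only the unresolved one, which is what
# pandas .loc itself would return for such a key.
print(s.index.droplevel(list(range(len(key)))).tolist())           # ['IND', 'DEM']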
@@ -293,28 +311,26 @@ def _compute_enlarge_labels(self, locator, base_index):
             )
         return nan_labels

-    def _expand_dim(self, row_lookup, col_lookup, ndim):
-        """Expand the dimension if necessary.
-        This method is for cases like duplicate labels.
-        """
-        many_rows = len(row_lookup) > 1
-        many_cols = len(col_lookup) > 1
-
-        if ndim == 0 and (many_rows or many_cols):
-            ndim = 1
-        if ndim == 1 and (many_rows and many_cols):
-            ndim = 2
-        return ndim
-
     def _compute_lookup(self, row_loc, col_loc) -> Tuple[pandas.Index, pandas.Index]:
-        if isinstance(row_loc, list) and len(row_loc) == 1:
+        if is_list_like(row_loc) and len(row_loc) == 1:
             if (
                 isinstance(self.qc.index.values[0], np.datetime64)
                 and type(row_loc[0]) != np.datetime64
             ):
                 row_loc = [pandas.to_datetime(row_loc[0])]
-        row_lookup = self.qc.index.to_series().loc[row_loc].index
-        col_lookup = self.qc.columns.to_series().loc[col_loc].index
+
+        if isinstance(row_loc, slice):
+            row_lookup = self.qc.index.to_series().loc[row_loc].values
+        elif isinstance(self.qc.index, pandas.MultiIndex):
+            row_lookup = self.qc.index[self.qc.index.get_locs(row_loc)]
+        else:
+            row_lookup = self.qc.index[self.qc.index.get_indexer_for(row_loc)]
+        if isinstance(col_loc, slice):
+            col_lookup = self.qc.columns.to_series().loc[col_loc].values
+        elif isinstance(self.qc.columns, pandas.MultiIndex):
+            col_lookup = self.qc.columns[self.qc.columns.get_locs(col_loc)]
+        else:
+            col_lookup = self.qc.columns[self.qc.columns.get_indexer_for(col_loc)]
         return row_lookup, col_lookup


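The rewritten `_compute_lookup` above picks the positional translation by index type: `MultiIndex.get_locs` resolves partial or tuple keys across levels, while `Index.get_indexer_for` maps flat labels to positions and tolerates duplicates. A small, illustrative comparison (hypothetical data, not part of the commit):

import pandas as pd

mi = pd.MultiIndex.from_tuples([("bar", "one"), ("bar", "two"), ("baz", "one")])
flat = pd.Index(["a", "b", "a"])

# MultiIndex: get_locs expands a partial key to every matching position.
print(mi[mi.get_locs(["bar"])].tolist())           # [('bar', 'one'), ('bar', 'two')]

# Flat (possibly non-unique) index: get_indexer_for keeps duplicate hits.
print(flat[flat.get_indexer_for(["a"])].tolist())  # ['a', 'a']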
11 changes: 11 additions & 0 deletions modin/pandas/series.py
@@ -41,6 +41,14 @@ def __init__(self, series, parent_df, loc):
         self.parent_df = parent_df
         self._loc = loc

+    def _get_index(self):
+        return self.series.index
+
+    def _set_index(self, index):
+        self.series.index = index
+
+    index = property(_get_index, _set_index)
+
     def __repr__(self):
         return repr(self.series)

@@ -176,6 +184,9 @@ def __getattribute__(self, item):
             "__arithmetic_op__",
             "__comparisons__",
             "__class__",
+            "index",
+            "_get_index",
+            "_set_index",
         ]
         if item not in default_behaviors:
             method = self.series.__getattribute__(item)
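The `series.py` hunks expose `index` on the view through a property so that reads and writes both route to the wrapped pandas `Series`, and they register "index", "_get_index", and "_set_index" in `default_behaviors` so `__getattribute__` does not proxy them away. A stripped-down, illustrative sketch of that delegation pattern (hypothetical `View` class, not the real SeriesView):

import pandas as pd

class View:
    """Toy wrapper mirroring the property-based index delegation."""

    def __init__(self, series):
        self.series = series

    def _get_index(self):
        return self.series.index

    def _set_index(self, index):
        self.series.index = index

    index = property(_get_index, _set_index)

v = View(pd.Series([1, 2, 3]))
v.index = ["a", "b", "c"]       # assignment is routed through _set_index
print(v.index.tolist())         # ['a', 'b', 'c']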
10 changes: 10 additions & 0 deletions modin/pandas/test/data/blah.csv

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions modin/pandas/test/test_dataframe.py
@@ -2831,6 +2831,60 @@ def test_loc(request, data):
     df_equals(modin_df_copy, pandas_df_copy)


+def test_loc_multi_index():
+    modin_df = pd.read_csv(
+        "modin/pandas/test/data/blah.csv", header=[0, 1, 2, 3], index_col=0
+    )
+    pandas_df = pandas.read_csv(
+        "modin/pandas/test/data/blah.csv", header=[0, 1, 2, 3], index_col=0
+    )
+
+    df_equals(modin_df.loc[1], pandas_df.loc[1])
+    assert modin_df.loc[1, "Presidents"].equals(pandas_df.loc[1, "Presidents"])
+    assert modin_df.loc[1, ("Presidents", "Pure mentions")].equals(
+        pandas_df.loc[1, ("Presidents", "Pure mentions")]
+    )
+    assert (
+        modin_df.loc[1, ("Presidents", "Pure mentions", "IND", "all")]
+        == pandas_df.loc[1, ("Presidents", "Pure mentions", "IND", "all")]
+    )
+    df_equals(modin_df.loc[(1, 2), "Presidents"], pandas_df.loc[(1, 2), "Presidents"])
+
+    tuples = [
+        ("bar", "one"),
+        ("bar", "two"),
+        ("bar", "three"),
+        ("bar", "four"),
+        ("baz", "one"),
+        ("baz", "two"),
+        ("baz", "three"),
+        ("baz", "four"),
+        ("foo", "one"),
+        ("foo", "two"),
+        ("foo", "three"),
+        ("foo", "four"),
+        ("qux", "one"),
+        ("qux", "two"),
+        ("qux", "three"),
+        ("qux", "four"),
+    ]
+
+    modin_index = pd.MultiIndex.from_tuples(tuples, names=["first", "second"])
+    pandas_index = pandas.MultiIndex.from_tuples(tuples, names=["first", "second"])
+    frame_data = np.random.randint(0, 100, size=(16, 100))
+    modin_df = pd.DataFrame(
+        frame_data, index=modin_index, columns=["col{}".format(i) for i in range(100)]
+    )
+    pandas_df = pandas.DataFrame(
+        frame_data, index=pandas_index, columns=["col{}".format(i) for i in range(100)]
+    )
+    assert modin_df.loc["bar", "col1"].equals(pandas_df.loc["bar", "col1"])
+    assert modin_df.loc[("bar", "one"), "col1"] == pandas_df.loc[("bar", "one"), "col1"]
+    df_equals(
+        modin_df.loc["bar", ("col1", "col2")], pandas_df.loc["bar", ("col1", "col2")]
+    )
+
+
 def test_lookup():
     data = test_data_values[0]
     with pytest.warns(UserWarning):
