Always use column_index. (#648)
Now that we have introduced multi-index columns, we should use `column_index` rather than `data_columns` so that access is consistent between single- and multi-index columns.
There is still a way to go, so this change includes a workaround to keep the `data_columns`-based approach working.
I will address the remaining pieces in separate PRs.
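
For context, a minimal sketch (not part of this commit; `spark_column_name` is a hypothetical helper) of the invariant the change moves toward: `column_index` always holds one tuple per data column, for single- and multi-level columns alike.

```python
def spark_column_name(idx):
    # The Spark-side column name is str(idx) for multi-level columns and the
    # bare name for single-level ones (mirrors the aliasing in
    # `_reduce_for_stat_function` and `spark_df` below).
    return str(idx) if len(idx) > 1 else idx[0]

# single level: column_index == [('a',), ('b',)]
assert spark_column_name(('a',)) == 'a'
# multi level: column_index == [('x', 'a'), ('x', 'b')]
assert spark_column_name(('x', 'a')) == "('x', 'a')"
```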
ueshin authored and HyukjinKwon committed Aug 19, 2019
1 parent f0f1859 commit e694826
Showing 6 changed files with 114 additions and 89 deletions.
86 changes: 35 additions & 51 deletions databricks/koalas/frame.py
@@ -49,7 +49,7 @@
from databricks.koalas.internal import _InternalFrame, IndexMap
from databricks.koalas.missing.frame import _MissingPandasLikeDataFrame
from databricks.koalas.ml import corr
from databricks.koalas.utils import scol_for
from databricks.koalas.utils import column_index_level, scol_for
from databricks.koalas.typedef import as_spark_type

# These regular expression patterns are compiled and defined here to avoid compiling the same
@@ -330,7 +330,7 @@ def _reduce_for_stat_function(self, sfun, name, axis=None, numeric_only=False):
if axis in ('index', 0, None):
exprs = []
num_args = len(signature(sfun).parameters)
for col in self._internal.data_columns:
for col, idx in zip(self._internal.data_columns, self._internal.column_index):
col_sdf = self._internal.scol_for(col)
col_type = self._internal.spark_type_for(col)

@@ -351,12 +351,12 @@ def _reduce_for_stat_function(self, sfun, name, axis=None, numeric_only=False):
assert num_args == 2
# Pass in both the column and its data type if sfun accepts two args
col_sdf = sfun(col_sdf, col_type)
exprs.append(col_sdf.alias(col))
exprs.append(col_sdf.alias(str(idx) if len(idx) > 1 else idx[0]))

sdf = self._sdf.select(*exprs)
pdf = sdf.toPandas()

if self._internal.column_index is not None:
if self._internal.column_index_level > 1:
pdf.columns = pd.MultiIndex.from_tuples(self._internal.column_index)

assert len(pdf) == 1, (sdf, pdf)
@@ -1923,7 +1923,7 @@ def rename(index):
index_map=index_map,
column_index=None)

if self._internal.column_index is not None:
if self._internal.column_index_level > 1:
column_depth = len(self._internal.column_index[0])
if col_level >= column_depth:
raise IndexError('Too many levels: Index has only {} levels, not {}'
@@ -3788,10 +3788,10 @@ def pivot(self, index=None, columns=None, values=None):
@property
def columns(self):
"""The column labels of the DataFrame."""
if self._internal.column_index is not None:
if self._internal.column_index_level > 1:
columns = pd.MultiIndex.from_tuples(self._internal.column_index)
else:
columns = pd.Index(self._internal.data_columns)
columns = pd.Index([idx[0] for idx in self._internal.column_index])
if self._internal.column_index_names is not None:
columns.names = self._internal.column_index_names
return columns
@@ -5566,7 +5566,9 @@ def _cum(self, func, skipna: bool):

sdf = self._sdf.select(
self._internal.index_scols + [c._scol for c in applied])
internal = self._internal.copy(sdf=sdf, data_columns=[c.name for c in applied])
# FIXME(ueshin): no need to specify `column_index`.
internal = self._internal.copy(sdf=sdf, data_columns=[c.name for c in applied],
column_index=self._internal.column_index)
return DataFrame(internal)

# TODO: implements 'keep' parameters
@@ -6336,61 +6338,49 @@ def _get_from_multiindex_column(self, key):
if len(columns) == 0:
raise KeyError(k)
recursive = False
if all(len(idx) == 0 or idx[0] == '' for _, idx in columns):
# If idx is empty or the head is '', drill down recursively.
if all(len(idx) > 0 and idx[0] == '' for _, idx in columns):
# If the head is '', drill down recursively.
recursive = True
for i, (col, idx) in enumerate(columns):
columns[i] = (col, tuple([str(key), *idx[1:]]))

column_index_names = None
if self._internal.column_index_names is not None:
# Manage column index names
column_index_level = set(len(idx) for _, idx in columns)
assert len(column_index_level) == 1
column_index_level = list(column_index_level)[0]
column_index_names = self._internal.column_index_names[-column_index_level:]
if all(len(idx) == 1 for _, idx in columns):
# If len(idx) == 1, then the result is not MultiIndex anymore
sdf = self._sdf.select(self._internal.index_scols +
[self._internal.scol_for(col).alias(idx[0])
for col, idx in columns])
kdf_or_ser = DataFrame(self._internal.copy(
sdf=sdf,
data_columns=[idx[0] for _, idx in columns],
column_index=None,
column_index_names=column_index_names))
level = column_index_level([idx for _, idx in columns])
column_index_names = self._internal.column_index_names[-level:]

if all(len(idx) == 0 for _, idx in columns):
try:
cols = set(col for col, _ in columns)
assert len(cols) == 1
kdf_or_ser = \
Series(self._internal.copy(scol=self._internal.scol_for(list(cols)[0])),
anchor=self)
except AnalysisException:
raise KeyError(key)
else:
# Otherwise, the result is still MultiIndex and need to manage column_index.
sdf = self._sdf.select(self._internal.index_scols +
[self._internal.scol_for(col) for col, _ in columns])
kdf_or_ser = DataFrame(self._internal.copy(
sdf=sdf,
data_columns=[col for col, _ in columns],
column_index=[idx for _, idx in columns],
column_index_names=column_index_names))

if recursive:
kdf_or_ser = kdf_or_ser._pd_getitem(str(key))
kdf_or_ser = kdf_or_ser._get_from_multiindex_column((str(key),))
if isinstance(kdf_or_ser, Series):
kdf_or_ser.name = str(key)
name = str(key) if len(key) > 1 else key[0]
if kdf_or_ser.name != name:
kdf_or_ser.name = name
return kdf_or_ser

def _pd_getitem(self, key):
from databricks.koalas.series import Series
if key is None:
raise KeyError("none key")
if isinstance(key, str):
if self._internal.column_index is not None:
return self._get_from_multiindex_column((key,))
else:
try:
return Series(self._internal.copy(scol=self._internal.scol_for(key)),
anchor=self)
except AnalysisException:
raise KeyError(key)
return self._get_from_multiindex_column((key,))
if isinstance(key, tuple):
if self._internal.column_index is not None:
return self._get_from_multiindex_column(key)
else:
raise NotImplementedError(key)
return self._get_from_multiindex_column(key)
elif np.isscalar(key):
raise NotImplementedError(key)
elif isinstance(key, slice):
@@ -6478,7 +6468,6 @@ def assign_columns(kdf, this_columns, that_columns):
self._internal = kdf._internal

def __getattr__(self, key: str) -> Any:
from databricks.koalas.series import Series
if key.startswith("__") or key.startswith("_pandas_") or key.startswith("_spark_"):
raise AttributeError(key)
if hasattr(_MissingPandasLikeDataFrame, key):
@@ -6488,16 +6477,11 @@ def __getattr__(self, key: str) -> Any:
else:
return partial(property_or_func, self)

if self._internal.column_index is not None:
try:
return self._get_from_multiindex_column((key,))
except KeyError:
raise AttributeError(
"'%s' object has no attribute '%s'" % (self.__class__.__name__, key))
if key not in self._internal.data_columns:
try:
return self._get_from_multiindex_column((key,))
except KeyError:
raise AttributeError(
"'%s' object has no attribute '%s'" % (self.__class__.__name__, key))
return Series(self._internal.copy(scol=self._internal.scol_for(key)), anchor=self)

def __len__(self):
return self._sdf.count()
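
To illustrate the access paths that `_get_from_multiindex_column`, `_pd_getitem`, and `__getattr__` now share, a hypothetical session (the frame is made up; the behavior follows the pandas semantics the diff implements):

```python
import databricks.koalas as ks

kdf = ks.DataFrame({('x', 'a'): [1, 2], ('x', 'b'): [3, 4]})
kdf[('x', 'a')]  # full tuple key -> a single Series
kdf['x']         # partial key -> drills down to a DataFrame with columns ['a', 'b']
kdf.x            # attribute access goes through the same multi-index path
```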
2 changes: 1 addition & 1 deletion databricks/koalas/indexes.py
@@ -113,7 +113,7 @@ def to_pandas(self) -> pd.Index:
internal = self._kdf._internal.copy(
sdf=sdf,
index_map=[(sdf.schema[0].name, self._kdf._internal.index_names[0])],
data_columns=[])
data_columns=[], column_index=[], column_index_names=None)
return DataFrame(internal).to_pandas().index

toPandas = to_pandas
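
The `to_pandas` fix only keeps the empty internal frame consistent: `column_index` is now explicitly empty alongside `data_columns`. A short usage sketch (hypothetical frame):

```python
import databricks.koalas as ks

kdf = ks.DataFrame({'a': [1, 2, 3]})
kdf.index.to_pandas()  # builds an internal frame with no data columns,
                       # then returns the index of its pandas conversion
```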
43 changes: 25 additions & 18 deletions databricks/koalas/indexing.py
@@ -27,6 +27,7 @@
from pyspark.sql.utils import AnalysisException

from databricks.koalas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError
from databricks.koalas.utils import column_index_level


def _make_col(c):
@@ -145,7 +146,7 @@ def __getitem__(self, key):
raise ValueError("'.at' only supports indices with level 1 right now")

if self._ks is None:
if self._kdf._internal.column_index is not None:
if self._kdf._internal.column_index_level > 1:
column = dict(zip(self._kdf._internal.column_index,
self._kdf._internal.data_columns)).get(key[1], None)
if column is None:
@@ -416,10 +417,8 @@ def raiseNotImplemented(description):
# make cols_sel a 1-tuple of string if a single string
column_index = self._kdf._internal.column_index
if isinstance(cols_sel, str):
if column_index is not None:
return self[rows_sel, [cols_sel]]._get_from_multiindex_column((cols_sel,))
else:
cols_sel = _make_col(cols_sel)
kdf = DataFrame(self._kdf._internal.copy(sdf=sdf))
return kdf._get_from_multiindex_column((cols_sel,))
elif isinstance(cols_sel, Series):
cols_sel = _make_col(cols_sel)
elif isinstance(cols_sel, slice) and cols_sel != slice(None):
@@ -431,17 +430,24 @@ def raiseNotImplemented(description):
columns = self._kdf._internal.data_scols
elif isinstance(cols_sel, spark.Column):
columns = [cols_sel]
column_index = None
elif all(isinstance(key, Series) for key in cols_sel):
columns = [_make_col(key) for key in cols_sel]
column_index = None
else:
if column_index is not None:
column_to_index = list(zip(self._kdf._internal.data_columns,
self._kdf._internal.column_index))
columns, column_index = zip(*[(_make_col(column), idx)
for key in cols_sel
for column, idx in column_to_index
if idx[0] == key])
columns, column_index = list(columns), list(column_index)
else:
columns = [_make_col(c) for c in cols_sel]
column_to_index = list(zip(self._kdf._internal.data_columns,
self._kdf._internal.column_index))
columns = []
column_index = []
for key in cols_sel:
found = False
for column, idx in column_to_index:
if idx[0] == key:
columns.append(_make_col(column))
column_index.append(idx)
found = True
if not found:
raise KeyError("['{}'] not in index".format(key))

try:
kdf = DataFrame(sdf.select(self._kdf._internal.index_scols + columns))
@@ -686,11 +692,12 @@ def raiseNotImplemented(description):
.format([col._jc.toString() for col in columns]))

column_index = self._kdf._internal.column_index
if column_index is not None:
if cols_sel is not None and isinstance(cols_sel, (Series, int)):
if cols_sel is not None:
if isinstance(cols_sel, (Series, int)):
column_index = None
else:
column_index = pd.MultiIndex.from_tuples(column_index)[cols_sel].tolist()
column_index = \
pd.MultiIndex.from_tuples(self._kdf._internal.column_index)[cols_sel].tolist()

kdf._internal = kdf._internal.copy(
data_columns=kdf._internal.data_columns[-len(columns):],
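
The rewritten `.loc` column selection resolves each requested key against the first level of `column_index` and now raises a `KeyError` on misses; a hypothetical illustration:

```python
import databricks.koalas as ks

kdf = ks.DataFrame({'a': [1, 2], 'b': [3, 4]})
kdf.loc[:, ['a', 'b']]  # each key is matched against idx[0] of column_index
kdf.loc[:, ['a', 'c']]  # raises KeyError("['c'] not in index")
```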
54 changes: 39 additions & 15 deletions databricks/koalas/internal.py
@@ -33,7 +33,7 @@

from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
from databricks.koalas.typedef import infer_pd_series_spark_type
from databricks.koalas.utils import default_session, lazy_property, scol_for
from databricks.koalas.utils import column_index_level, default_session, lazy_property, scol_for


IndexMap = Tuple[str, Optional[str]]
@@ -386,16 +386,20 @@ def __init__(self, sdf: spark.DataFrame,
if scol is not None:
self._data_columns = sdf.select(scol).columns
column_index = None
column_index_names = None
elif data_columns is None:
index_columns = set(index_column for index_column, _ in self._index_map)
self._data_columns = [column for column in sdf.columns if column not in index_columns]
else:
self._data_columns = data_columns

assert column_index is None or (len(column_index) == len(self._data_columns) and
all(isinstance(i, tuple) for i in column_index) and
len(set(len(i) for i in column_index)) <= 1)
self._column_index = column_index
if column_index is None:
self._column_index = [(col,) for col in self._data_columns]
else:
assert (len(column_index) == len(self._data_columns) and
all(isinstance(i, tuple) for i in column_index) and
len(set(len(i) for i in column_index)) <= 1)
self._column_index = column_index

if column_index_names is not None and not is_list_like(column_index_names):
raise ValueError('Column_index_names should be list-like or None for a MultiIndex')
@@ -550,10 +554,15 @@ def scol(self) -> Optional[spark.Column]:
return self._scol

@property
def column_index(self) -> Optional[List[Tuple[str]]]:
def column_index(self) -> List[Tuple[str]]:
""" Return the managed column index. """
return self._column_index

@lazy_property
def column_index_level(self) -> int:
""" Return the level of the column index. """
return column_index_level(self._column_index)

@property
def column_index_names(self) -> Optional[List[str]]:
""" Return names of the index levels. """
@@ -562,7 +571,16 @@ def column_index_names(self) -> Optional[List[str]]:
@lazy_property
def spark_df(self) -> spark.DataFrame:
""" Return as Spark DataFrame. """
return self._sdf.select(self.scols)
index_columns = set(self.index_columns)
data_columns = []
for column, idx in zip(self._data_columns, self.column_index):
if column not in index_columns:
scol = self.scol_for(column)
name = str(idx) if len(idx) > 1 else idx[0]
if column != name:
scol = scol.alias(name)
data_columns.append(scol)
return self._sdf.select(self.index_scols + data_columns)

@lazy_property
def pandas_df(self):
@@ -580,19 +598,18 @@ def pandas_df(self):
drop = index_field not in self.data_columns
pdf = pdf.set_index(index_field, drop=drop, append=append)
append = True
pdf = pdf[self.data_columns]
pdf = pdf[[str(name) if len(name) > 1 else name[0] for name in self.column_index]]

if self._column_index is not None:
if self.column_index_level > 1:
pdf.columns = pd.MultiIndex.from_tuples(self._column_index)
else:
pdf.columns = [idx[0] for idx in self._column_index]
if self._column_index_names is not None:
pdf.columns.names = self._column_index_names

index_names = self.index_names
if len(index_names) > 0:
if isinstance(pdf.index, pd.MultiIndex):
pdf.index.names = index_names
else:
pdf.index.name = index_names[0]
pdf.index.names = index_names
return pdf

def copy(self, sdf: Union[spark.DataFrame, _NoValueType] = _NoValue,
@@ -617,10 +634,17 @@ def copy(self, sdf: Union[spark.DataFrame, _NoValueType] = _NoValue,
index_map = self._index_map
if scol is _NoValue:
scol = self._scol
# FIXME(ueshin): this if-clause should be removed.
if column_index is _NoValue:
if data_columns is not _NoValue:
column_index = None
else:
column_index = self._column_index
if data_columns is _NoValue:
data_columns = self._data_columns
if column_index is _NoValue:
column_index = self._column_index
# FIXME(ueshin): this if-clause should be used instead of the above.
# if column_index is _NoValue:
# column_index = self._column_index
if column_index_names is _NoValue:
column_index_names = self._column_index_names
return _InternalFrame(sdf, index_map=index_map, scol=scol, data_columns=data_columns,
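
The `column_index_level` helper imported from `databricks.koalas.utils` is part of this commit, but its hunk is missing from this capture (only five of the six changed files appear; the sixth is presumably `databricks/koalas/utils.py`). A plausible minimal sketch, inferred from the `__init__` assertion above that all index tuples share a single length; the empty-list behavior is an assumption:

```python
from typing import List, Tuple

def column_index_level(column_index: List[Tuple[str, ...]]) -> int:
    """Return the level (shared tuple length) of a column index."""
    if len(column_index) == 0:
        return 0  # assumption: treat an empty column index as level 0
    levels = set(len(idx) for idx in column_index)
    assert len(levels) == 1, levels
    return list(levels)[0]

assert column_index_level([('a',), ('b',)]) == 1
assert column_index_level([('x', 'a'), ('x', 'b')]) == 2
```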
4 changes: 2 additions & 2 deletions databricks/koalas/series.py
@@ -1795,7 +1795,7 @@ def add_prefix(self, prefix):
scol_for(sdf, index_column)).alias(index_column)
for index_column in internal.index_columns] + internal.data_columns)
kdf._internal = internal.copy(sdf=sdf)
return Series(kdf._internal.copy(scol=self._scol), anchor=kdf)
return _col(kdf)

def add_suffix(self, suffix):
"""
@@ -1845,7 +1845,7 @@ def add_suffix(self, suffix):
F.lit(suffix)).alias(index_column)
for index_column in internal.index_columns] + internal.data_columns)
kdf._internal = internal.copy(sdf=sdf)
return Series(kdf._internal.copy(scol=self._scol), anchor=kdf)
return _col(kdf)

def corr(self, other, method='pearson'):
"""
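
In `add_prefix`/`add_suffix`, the manual `Series` construction is replaced by the existing `_col` helper, which, under the assumption that it returns a frame's first and only column, yields the same Series anchored to the new frame. A usage sketch (hypothetical data):

```python
import databricks.koalas as ks

kser = ks.Series([1, 2, 3])
kser.add_prefix('item_')  # index labels become 'item_0', 'item_1', 'item_2'
```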
