From e694826c81ca811f5e09475a357a428c388696d0 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Sun, 18 Aug 2019 21:04:58 -0700
Subject: [PATCH] Always use `column_index`. (#648)

Now that we have introduced multi-index columns, we should use
`column_index` rather than `data_columns` so that single-index and
multi-index columns are accessed consistently.

We still have a ways to go, so this change includes a workaround to keep
working with the `data_columns`-based approach; I will address the
remaining places in separate PRs.
---
 databricks/koalas/frame.py    | 86 ++++++++++++++---------------------
 databricks/koalas/indexes.py  |  2 +-
 databricks/koalas/indexing.py | 43 ++++++++++--------
 databricks/koalas/internal.py | 54 ++++++++++++++++------
 databricks/koalas/series.py   |  4 +-
 databricks/koalas/utils.py    | 14 +++++-
 6 files changed, 114 insertions(+), 89 deletions(-)

diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py
index 3a6cb6026f..687afe3a2f 100644
--- a/databricks/koalas/frame.py
+++ b/databricks/koalas/frame.py
@@ -49,7 +49,7 @@
 from databricks.koalas.internal import _InternalFrame, IndexMap
 from databricks.koalas.missing.frame import _MissingPandasLikeDataFrame
 from databricks.koalas.ml import corr
-from databricks.koalas.utils import scol_for
+from databricks.koalas.utils import column_index_level, scol_for
 from databricks.koalas.typedef import as_spark_type
 
 # These regular expression patterns are complied and defined here to avoid to compile the same
@@ -330,7 +330,7 @@ def _reduce_for_stat_function(self, sfun, name, axis=None, numeric_only=False):
         if axis in ('index', 0, None):
             exprs = []
             num_args = len(signature(sfun).parameters)
-            for col in self._internal.data_columns:
+            for col, idx in zip(self._internal.data_columns, self._internal.column_index):
                 col_sdf = self._internal.scol_for(col)
                 col_type = self._internal.spark_type_for(col)
 
@@ -351,12 +351,12 @@ def _reduce_for_stat_function(self, sfun, name, axis=None, numeric_only=False):
                     assert num_args == 2
                     # Pass in both the column and its data type if sfun accepts two args
                     col_sdf = sfun(col_sdf, col_type)
-                exprs.append(col_sdf.alias(col))
+                exprs.append(col_sdf.alias(str(idx) if len(idx) > 1 else idx[0]))
 
             sdf = self._sdf.select(*exprs)
             pdf = sdf.toPandas()
-            if self._internal.column_index is not None:
+            if self._internal.column_index_level > 1:
                 pdf.columns = pd.MultiIndex.from_tuples(self._internal.column_index)
 
             assert len(pdf) == 1, (sdf, pdf)
@@ -1923,7 +1923,7 @@ def rename(index):
                 index_map=index_map,
                 column_index=None)
 
-        if self._internal.column_index is not None:
+        if self._internal.column_index_level > 1:
             column_depth = len(self._internal.column_index[0])
             if col_level >= column_depth:
                 raise IndexError('Too many levels: Index has only {} levels, not {}'
@@ -3788,10 +3788,10 @@ def pivot(self, index=None, columns=None, values=None):
     @property
     def columns(self):
         """The column labels of the DataFrame."""
-        if self._internal.column_index is not None:
+        if self._internal.column_index_level > 1:
             columns = pd.MultiIndex.from_tuples(self._internal.column_index)
         else:
-            columns = pd.Index(self._internal.data_columns)
+            columns = pd.Index([idx[0] for idx in self._internal.column_index])
         if self._internal.column_index_names is not None:
             columns.names = self._internal.column_index_names
         return columns
@@ -5566,7 +5566,9 @@ def _cum(self, func, skipna: bool):
 
         sdf = self._sdf.select(
             self._internal.index_scols + [c._scol for c in applied])
-        internal = self._internal.copy(sdf=sdf, data_columns=[c.name for c in applied])
+        # FIXME(ueshin): no need to specify `column_index`.
+        internal = self._internal.copy(sdf=sdf, data_columns=[c.name for c in applied],
+                                       column_index=self._internal.column_index)
         return DataFrame(internal)
 
     # TODO: implements 'keep' parameters
@@ -6336,41 +6338,39 @@ def _get_from_multiindex_column(self, key):
         if len(columns) == 0:
             raise KeyError(k)
         recursive = False
-        if all(len(idx) == 0 or idx[0] == '' for _, idx in columns):
-            # If idx is empty or the head is '', drill down recursively.
+        if all(len(idx) > 0 and idx[0] == '' for _, idx in columns):
+            # If the head is '', drill down recursively.
             recursive = True
             for i, (col, idx) in enumerate(columns):
                 columns[i] = (col, tuple([str(key), *idx[1:]]))
 
+        column_index_names = None
         if self._internal.column_index_names is not None:
             # Manage column index names
-            column_index_level = set(len(idx) for _, idx in columns)
-            assert len(column_index_level) == 1
-            column_index_level = list(column_index_level)[0]
-            column_index_names = self._internal.column_index_names[-column_index_level:]
-        if all(len(idx) == 1 for _, idx in columns):
-            # If len(idx) == 1, then the result is not MultiIndex anymore
-            sdf = self._sdf.select(self._internal.index_scols +
-                                   [self._internal.scol_for(col).alias(idx[0])
-                                    for col, idx in columns])
-            kdf_or_ser = DataFrame(self._internal.copy(
-                sdf=sdf,
-                data_columns=[idx[0] for _, idx in columns],
-                column_index=None,
-                column_index_names=column_index_names))
+            level = column_index_level([idx for _, idx in columns])
+            column_index_names = self._internal.column_index_names[-level:]
+
+        if all(len(idx) == 0 for _, idx in columns):
+            try:
+                cols = set(col for col, _ in columns)
+                assert len(cols) == 1
+                kdf_or_ser = \
+                    Series(self._internal.copy(scol=self._internal.scol_for(list(cols)[0])),
+                           anchor=self)
+            except AnalysisException:
+                raise KeyError(key)
         else:
-            # Otherwise, the result is still MultiIndex and need to manage column_index.
-            sdf = self._sdf.select(self._internal.index_scols +
-                                   [self._internal.scol_for(col) for col, _ in columns])
             kdf_or_ser = DataFrame(self._internal.copy(
-                sdf=sdf,
                 data_columns=[col for col, _ in columns],
                 column_index=[idx for _, idx in columns],
                 column_index_names=column_index_names))
+
         if recursive:
-            kdf_or_ser = kdf_or_ser._pd_getitem(str(key))
+            kdf_or_ser = kdf_or_ser._get_from_multiindex_column((str(key),))
         if isinstance(kdf_or_ser, Series):
-            kdf_or_ser.name = str(key)
+            name = str(key) if len(key) > 1 else key[0]
+            if kdf_or_ser.name != name:
+                kdf_or_ser.name = name
         return kdf_or_ser
 
     def _pd_getitem(self, key):
@@ -6378,19 +6378,9 @@ def _pd_getitem(self, key):
         if key is None:
             raise KeyError("none key")
         if isinstance(key, str):
-            if self._internal.column_index is not None:
-                return self._get_from_multiindex_column((key,))
-            else:
-                try:
-                    return Series(self._internal.copy(scol=self._internal.scol_for(key)),
-                                  anchor=self)
-                except AnalysisException:
-                    raise KeyError(key)
+            return self._get_from_multiindex_column((key,))
         if isinstance(key, tuple):
-            if self._internal.column_index is not None:
-                return self._get_from_multiindex_column(key)
-            else:
-                raise NotImplementedError(key)
+            return self._get_from_multiindex_column(key)
         elif np.isscalar(key):
             raise NotImplementedError(key)
         elif isinstance(key, slice):
@@ -6478,7 +6468,6 @@ def assign_columns(kdf, this_columns, that_columns):
         self._internal = kdf._internal
 
     def __getattr__(self, key: str) -> Any:
-        from databricks.koalas.series import Series
         if key.startswith("__") or key.startswith("_pandas_") or key.startswith("_spark_"):
             raise AttributeError(key)
         if hasattr(_MissingPandasLikeDataFrame, key):
@@ -6488,16 +6477,11 @@ def __getattr__(self, key: str) -> Any:
             else:
                 return partial(property_or_func, self)
 
-        if self._internal.column_index is not None:
-            try:
-                return self._get_from_multiindex_column((key,))
-            except KeyError:
-                raise AttributeError(
-                    "'%s' object has no attribute '%s'" % (self.__class__.__name__, key))
-        if key not in self._internal.data_columns:
+        try:
+            return self._get_from_multiindex_column((key,))
+        except KeyError:
             raise AttributeError(
                 "'%s' object has no attribute '%s'" % (self.__class__.__name__, key))
-        return Series(self._internal.copy(scol=self._internal.scol_for(key)), anchor=self)
 
     def __len__(self):
         return self._sdf.count()
diff --git a/databricks/koalas/indexes.py b/databricks/koalas/indexes.py
index 67f684f4f8..9cf41ed1eb 100644
--- a/databricks/koalas/indexes.py
+++ b/databricks/koalas/indexes.py
@@ -113,7 +113,7 @@ def to_pandas(self) -> pd.Index:
         internal = self._kdf._internal.copy(
             sdf=sdf,
             index_map=[(sdf.schema[0].name, self._kdf._internal.index_names[0])],
-            data_columns=[])
+            data_columns=[], column_index=[], column_index_names=None)
         return DataFrame(internal).to_pandas().index
 
     toPandas = to_pandas
diff --git a/databricks/koalas/indexing.py b/databricks/koalas/indexing.py
index 213937753e..1cdffb3b6a 100644
--- a/databricks/koalas/indexing.py
+++ b/databricks/koalas/indexing.py
@@ -27,6 +27,7 @@
 from pyspark.sql.utils import AnalysisException
 
 from databricks.koalas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError
+from databricks.koalas.utils import column_index_level
 
 
 def _make_col(c):
@@ -145,7 +146,7 @@ def __getitem__(self, key):
             raise ValueError("'.at' only supports indices with level 1 right now")
 
         if self._ks is None:
-            if self._kdf._internal.column_index is not None:
+            if self._kdf._internal.column_index_level > 1:
                 column = dict(zip(self._kdf._internal.column_index,
                                   self._kdf._internal.data_columns)).get(key[1], None)
                 if column is None:
@@ -416,10 +417,8 @@ def raiseNotImplemented(description):
         # make cols_sel a 1-tuple of string if a single string
         column_index = self._kdf._internal.column_index
         if isinstance(cols_sel, str):
-            if column_index is not None:
-                return self[rows_sel, [cols_sel]]._get_from_multiindex_column((cols_sel,))
-            else:
-                cols_sel = _make_col(cols_sel)
+            kdf = DataFrame(self._kdf._internal.copy(sdf=sdf))
+            return kdf._get_from_multiindex_column((cols_sel,))
         elif isinstance(cols_sel, Series):
             cols_sel = _make_col(cols_sel)
         elif isinstance(cols_sel, slice) and cols_sel != slice(None):
@@ -431,17 +430,24 @@ def raiseNotImplemented(description):
             columns = self._kdf._internal.data_scols
         elif isinstance(cols_sel, spark.Column):
             columns = [cols_sel]
+            column_index = None
+        elif all(isinstance(key, Series) for key in cols_sel):
+            columns = [_make_col(key) for key in cols_sel]
+            column_index = None
         else:
-            if column_index is not None:
-                column_to_index = list(zip(self._kdf._internal.data_columns,
-                                           self._kdf._internal.column_index))
-                columns, column_index = zip(*[(_make_col(column), idx)
-                                              for key in cols_sel
-                                              for column, idx in column_to_index
-                                              if idx[0] == key])
-                columns, column_index = list(columns), list(column_index)
-            else:
-                columns = [_make_col(c) for c in cols_sel]
+            column_to_index = list(zip(self._kdf._internal.data_columns,
+                                       self._kdf._internal.column_index))
+            columns = []
+            column_index = []
+            for key in cols_sel:
+                found = False
+                for column, idx in column_to_index:
+                    if idx[0] == key:
+                        columns.append(_make_col(column))
+                        column_index.append(idx)
+                        found = True
+                if not found:
+                    raise KeyError("['{}'] not in index".format(key))
 
         try:
             kdf = DataFrame(sdf.select(self._kdf._internal.index_scols + columns))
@@ -686,11 +692,12 @@ def raiseNotImplemented(description):
                     .format([col._jc.toString() for col in columns]))
 
         column_index = self._kdf._internal.column_index
-        if column_index is not None:
-            if cols_sel is not None and isinstance(cols_sel, (Series, int)):
+        if cols_sel is not None:
+            if isinstance(cols_sel, (Series, int)):
                 column_index = None
             else:
-                column_index = pd.MultiIndex.from_tuples(column_index)[cols_sel].tolist()
+                column_index = \
+                    pd.MultiIndex.from_tuples(self._kdf._internal.column_index)[cols_sel].tolist()
 
         kdf._internal = kdf._internal.copy(
             data_columns=kdf._internal.data_columns[-len(columns):],
diff --git a/databricks/koalas/internal.py b/databricks/koalas/internal.py
index 0f8a18eb6a..6a316cd3e5 100644
--- a/databricks/koalas/internal.py
+++ b/databricks/koalas/internal.py
@@ -33,7 +33,7 @@
 from databricks import koalas as ks  # For running doctests and reference resolution in PyCharm.
 from databricks.koalas.typedef import infer_pd_series_spark_type
-from databricks.koalas.utils import default_session, lazy_property, scol_for
+from databricks.koalas.utils import column_index_level, default_session, lazy_property, scol_for
 
 
 IndexMap = Tuple[str, Optional[str]]
 
@@ -386,16 +386,20 @@ def __init__(self, sdf: spark.DataFrame,
         if scol is not None:
             self._data_columns = sdf.select(scol).columns
             column_index = None
+            column_index_names = None
         elif data_columns is None:
             index_columns = set(index_column for index_column, _ in self._index_map)
             self._data_columns = [column for column in sdf.columns
                                   if column not in index_columns]
         else:
             self._data_columns = data_columns
-        assert column_index is None or (len(column_index) == len(self._data_columns) and
-                                        all(isinstance(i, tuple) for i in column_index) and
-                                        len(set(len(i) for i in column_index)) <= 1)
-        self._column_index = column_index
+        if column_index is None:
+            self._column_index = [(col,) for col in self._data_columns]
+        else:
+            assert (len(column_index) == len(self._data_columns) and
+                    all(isinstance(i, tuple) for i in column_index) and
+                    len(set(len(i) for i in column_index)) <= 1)
+            self._column_index = column_index
 
         if column_index_names is not None and not is_list_like(column_index_names):
             raise ValueError('Column_index_names should be list-like or None for a MultiIndex')
@@ -550,10 +554,15 @@ def scol(self) -> Optional[spark.Column]:
         return self._scol
 
     @property
-    def column_index(self) -> Optional[List[Tuple[str]]]:
+    def column_index(self) -> List[Tuple[str]]:
         """ Return the managed column index. """
         return self._column_index
 
+    @lazy_property
+    def column_index_level(self) -> int:
+        """ Return the level of the column index. """
+        return column_index_level(self._column_index)
+
     @property
     def column_index_names(self) -> Optional[List[str]]:
         """ Return names of the index levels. """
@@ -562,7 +571,16 @@ def column_index_names(self) -> Optional[List[str]]:
     @lazy_property
     def spark_df(self) -> spark.DataFrame:
         """ Return as Spark DataFrame. """
-        return self._sdf.select(self.scols)
+        index_columns = set(self.index_columns)
+        data_columns = []
+        for column, idx in zip(self._data_columns, self.column_index):
+            if column not in index_columns:
+                scol = self.scol_for(column)
+                name = str(idx) if len(idx) > 1 else idx[0]
+                if column != name:
+                    scol = scol.alias(name)
+                data_columns.append(scol)
+        return self._sdf.select(self.index_scols + data_columns)
 
     @lazy_property
     def pandas_df(self):
@@ -580,19 +598,18 @@ def pandas_df(self):
             drop = index_field not in self.data_columns
             pdf = pdf.set_index(index_field, drop=drop, append=append)
             append = True
-        pdf = pdf[self.data_columns]
+        pdf = pdf[[str(name) if len(name) > 1 else name[0] for name in self.column_index]]
 
-        if self._column_index is not None:
+        if self.column_index_level > 1:
             pdf.columns = pd.MultiIndex.from_tuples(self._column_index)
+        else:
+            pdf.columns = [idx[0] for idx in self._column_index]
         if self._column_index_names is not None:
             pdf.columns.names = self._column_index_names
 
         index_names = self.index_names
         if len(index_names) > 0:
-            if isinstance(pdf.index, pd.MultiIndex):
-                pdf.index.names = index_names
-            else:
-                pdf.index.name = index_names[0]
+            pdf.index.names = index_names
         return pdf
 
     def copy(self, sdf: Union[spark.DataFrame, _NoValueType] = _NoValue,
@@ -617,10 +634,17 @@ def copy(self, sdf: Union[spark.DataFrame, _NoValueType] = _NoValue,
             index_map = self._index_map
         if scol is _NoValue:
             scol = self._scol
+        # FIXME(ueshin): this if-clause should be removed.
+        if column_index is _NoValue:
+            if data_columns is not _NoValue:
+                column_index = None
+            else:
+                column_index = self._column_index
         if data_columns is _NoValue:
             data_columns = self._data_columns
-        if column_index is _NoValue:
-            column_index = self._column_index
+        # FIXME(ueshin): this if-clause should be used instead of the above.
+        # if column_index is _NoValue:
+        #     column_index = self._column_index
         if column_index_names is _NoValue:
             column_index_names = self._column_index_names
         return _InternalFrame(sdf, index_map=index_map, scol=scol, data_columns=data_columns,
diff --git a/databricks/koalas/series.py b/databricks/koalas/series.py
index 6cc8478cb1..4fe7cf0a26 100644
--- a/databricks/koalas/series.py
+++ b/databricks/koalas/series.py
@@ -1795,7 +1795,7 @@ def add_prefix(self, prefix):
                                  scol_for(sdf, index_column)).alias(index_column)
                           for index_column in internal.index_columns] + internal.data_columns)
         kdf._internal = internal.copy(sdf=sdf)
-        return Series(kdf._internal.copy(scol=self._scol), anchor=kdf)
+        return _col(kdf)
 
     def add_suffix(self, suffix):
         """
@@ -1845,7 +1845,7 @@ def add_suffix(self, suffix):
                                  F.lit(suffix)).alias(index_column)
                           for index_column in internal.index_columns] + internal.data_columns)
         kdf._internal = internal.copy(sdf=sdf)
-        return Series(kdf._internal.copy(scol=self._scol), anchor=kdf)
+        return _col(kdf)
 
     def corr(self, other, method='pearson'):
         """
diff --git a/databricks/koalas/utils.py b/databricks/koalas/utils.py
index dff592b137..0831d00904 100644
--- a/databricks/koalas/utils.py
+++ b/databricks/koalas/utils.py
@@ -18,9 +18,9 @@
 """
 
 import functools
-import os
 from collections import OrderedDict
-from typing import Callable, Dict, Union
+import os
+from typing import Callable, Dict, List, Tuple, Union
 
 from pyspark import sql as spark
 from pyspark.sql import functions as F
@@ -351,3 +351,13 @@ def _lazy_property(self):
 def scol_for(sdf: spark.DataFrame, column_name: str) -> spark.Column:
     """ Return Spark Column for the given column name. """
     return sdf['`{}`'.format(column_name)]
+
+
+def column_index_level(column_index: List[Tuple[str]]) -> int:
+    """ Return the level of the column index. """
+    if len(column_index) == 0:
+        return 0
+    else:
+        levels = set(len(idx) for idx in column_index)
+        assert len(levels) == 1, levels
+        return list(levels)[0]
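
Reviewer note (commentary only, not part of the patch to apply): the sketch below is a minimal, self-contained illustration of the representation this change standardizes on. `column_index_level` mirrors the helper added to databricks/koalas/utils.py above; the sample labels `single` and `multi` are made up for illustration.

    from typing import List, Tuple

    import pandas as pd


    def column_index_level(column_index: List[Tuple[str, ...]]) -> int:
        """Return the level of the column index (mirrors databricks/koalas/utils.py)."""
        if len(column_index) == 0:
            return 0
        levels = set(len(idx) for idx in column_index)
        assert len(levels) == 1, levels
        return list(levels)[0]


    # Single-level column labels are stored as 1-tuples ...
    single = [('a',), ('b',)]
    assert column_index_level(single) == 1
    print(pd.Index([idx[0] for idx in single]))   # Index(['a', 'b'], dtype='object')

    # ... and multi-level labels as n-tuples, so one code path serves both.
    multi = [('x', 'a'), ('x', 'b'), ('y', 'c')]
    assert column_index_level(multi) == 2
    print(pd.MultiIndex.from_tuples(multi))

This is why `_InternalFrame.__init__` can now default `column_index` to `[(col,) for col in self._data_columns]`, and why accessors such as `DataFrame.columns` only build a `pd.MultiIndex` when `column_index_level > 1`.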