Implement Series.where #922

Merged: 8 commits, Oct 28, 2019
1 change: 0 additions & 1 deletion databricks/koalas/missing/series.py
@@ -110,7 +110,6 @@ class _MissingPandasLikeSeries(object):
unstack = unsupported_function('unstack')
update = unsupported_function('update')
view = unsupported_function('view')
where = unsupported_function('where')

# Deprecated functions
as_blocks = unsupported_function('as_blocks', deprecated=True)
80 changes: 80 additions & 0 deletions databricks/koalas/series.py
@@ -3605,6 +3605,86 @@ def replace(self, to_replace=None, value=None, regex=False) -> 'Series':

return self._with_new_scol(current)

def where(self, cond, other=np.nan):
Member:
@itholic it seems pandas shares the same implementation internally. After this PR is merged, can you move this into the _Frame class and implement DataFrame.where as well?

Contributor Author:
Okay, I'll work on it right after this PR is merged.
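For reference, a follow-up DataFrame.where would need to match the pandas semantics shown below; this is a plain pandas sketch of that behavior, not koalas code from this PR:

import pandas as pd

pdf = pd.DataFrame({'A': [0, 1, 2, 3], 'B': [4, 5, 6, 7]})
# where() keeps entries for which the condition is True and replaces the rest
# with `other`; here even values are kept and odd values are negated:
# A -> [0, -1, 2, -3], B -> [4, -5, 6, -7]
print(pdf.where(pdf % 2 == 0, -pdf))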

"""
Replace values where the condition is False.

Parameters
----------
cond : boolean Series
Where cond is True, keep the original value. Where False,
replace with corresponding value from other.
other : scalar, Series
Entries where cond is False are replaced with corresponding value from other.

Returns
-------
Series

Examples
--------

>>> from databricks.koalas.config import set_option, reset_option
>>> set_option("compute.ops_on_diff_frames", True)
>>> s1 = ks.Series([0, 1, 2, 3, 4])
>>> s2 = ks.Series([100, 200, 300, 400, 500])
>>> s1.where(s1 > 0).sort_index()
0 NaN
1 1.0
2 2.0
3 3.0
4 4.0
Name: 0, dtype: float64

>>> s1.where(s1 > 1, 10).sort_index()
0 10
1 10
2 2
3 3
4 4
Name: 0, dtype: int64

>>> s1.where(s1 > 1, s1 + 100).sort_index()
0 100
1 101
2 2
3 3
4 4
Name: 0, dtype: int64

>>> s1.where(s1 > 1, s2).sort_index()
0 100
1 200
2 2
3 3
4 4
Name: 0, dtype: int64

>>> reset_option("compute.ops_on_diff_frames")
"""
kdf = self.to_frame()
kdf['__tmp_cond_col__'] = cond
kdf['__tmp_other_col__'] = other
sdf = kdf._sdf
# The logic above makes the Spark DataFrame look like the following:
# +-----------------+---+----------------+-----------------+
# |__index_level_0__| 0|__tmp_cond_col__|__tmp_other_col__|
# +-----------------+---+----------------+-----------------+
# | 0| 0| false| 100|
# | 1| 1| false| 200|
# | 3| 3| true| 400|
# | 2| 2| true| 300|
# | 4| 4| true| 500|
# +-----------------+---+----------------+-----------------+
data_col_name = self._internal.column_name_for(self._internal.column_index[0])
index_column = self._internal.index_columns[0]
Member:
@itholic, I think this doesn't support multi-level index cases. Can you fix this please?

Member:
There can be multiple index_columns, so we cannot use only the first one.
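A minimal sketch of the direction this suggests, carrying every index column through the select rather than only the first (variable names follow the surrounding code; condition is the expression built just below; this is not the final fix):

# Hypothetical revision suggested by the review: keep all index levels.
index_columns = self._internal.index_columns
sdf = sdf.select(index_columns + [condition])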

condition = F.when(sdf['__tmp_cond_col__'], sdf[data_col_name]) \
.otherwise(sdf['__tmp_other_col__']).alias(data_col_name)
sdf = sdf.select(index_column, condition)
result = _col(ks.DataFrame(_InternalFrame(sdf=sdf, index_map=self._internal.index_map)))

return result

def xs(self, key, level=None):
"""
Return cross-section from the Series.
28 changes: 28 additions & 0 deletions databricks/koalas/tests/test_series.py
@@ -31,10 +31,21 @@
from databricks.koalas.testing.utils import ReusedSQLTestCase, SQLTestUtils
from databricks.koalas.exceptions import PandasNotImplementedError
from databricks.koalas.missing.series import _MissingPandasLikeSeries
from databricks.koalas.config import set_option, reset_option


class SeriesTest(ReusedSQLTestCase, SQLTestUtils):

@classmethod
def setUpClass(cls):
super(SeriesTest, cls).setUpClass()
set_option("compute.ops_on_diff_frames", True)

@classmethod
def tearDownClass(cls):
Member:
@itholic, disable this. compute.ops_on_diff_frames is disabled by default because it is expensive. We should move these test cases into OpsOnDiffFramesEnabledTest.

reset_option("compute.ops_on_diff_frames")
super(SeriesTest, cls).tearDownClass()
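A rough sketch of the split the reviewer describes, keeping the option enabled only for a dedicated test class (OpsOnDiffFramesEnabledTest is named in the comment; its base classes and placement here are assumed):

class OpsOnDiffFramesEnabledTest(ReusedSQLTestCase, SQLTestUtils):

    @classmethod
    def setUpClass(cls):
        super(OpsOnDiffFramesEnabledTest, cls).setUpClass()
        set_option("compute.ops_on_diff_frames", True)

    @classmethod
    def tearDownClass(cls):
        reset_option("compute.ops_on_diff_frames")
        super(OpsOnDiffFramesEnabledTest, cls).tearDownClass()

    def test_where(self):
        # The cross-frame cases from SeriesTest.test_where would move here.
        pser1 = pd.Series([0, 1, 2, 3, 4], name=0)
        pser2 = pd.Series([100, 200, 300, 400, 500], name=0)
        kser1 = ks.from_pandas(pser1)
        kser2 = ks.from_pandas(pser2)
        self.assert_eq(repr(pser1.where(pser2 > 100)),
                       repr(kser1.where(kser2 > 100).sort_index()))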

@property
def pser(self):
return pd.Series([1, 2, 3, 4, 5, 6, 7], name='x')
@@ -742,6 +753,23 @@ def test_duplicates(self):
self.assert_eq(pser.drop_duplicates().sort_values(),
kser.drop_duplicates().sort_values())

def test_where(self):
pser1 = pd.Series([0, 1, 2, 3, 4], name=0)
Member:
Can you add a test for when compute.ops_on_diff_frames is off? I think we can still use a scalar value for other, such as an int.

pser2 = pd.Series([100, 200, 300, 400, 500], name=0)
kser1 = ks.from_pandas(pser1)
kser2 = ks.from_pandas(pser2)

self.assert_eq(repr(pser1.where(pser2 > 100)),
repr(kser1.where(kser2 > 100).sort_index()))

pser1 = pd.Series([-1, -2, -3, -4, -5], name=0)
pser2 = pd.Series([-100, -200, -300, -400, -500], name=0)
kser1 = ks.from_pandas(pser1)
kser2 = ks.from_pandas(pser2)

self.assert_eq(repr(pser1.where(pser2 < -250)),
repr(kser1.where(kser2 < -250).sort_index()))
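A possible shape for the additional case the reviewer asks for, using a scalar other so compute.ops_on_diff_frames can stay off (the test name is illustrative only):

    def test_where_scalar_other(self):
        # cond is derived from the same Series, so no cross-frame operation is involved.
        pser = pd.Series([0, 1, 2, 3, 4], name=0)
        kser = ks.from_pandas(pser)
        self.assert_eq(repr(pser.where(pser > 1, 10)),
                       repr(kser.where(kser > 1, 10).sort_index()))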

def test_truncate(self):
pser1 = pd.Series([10, 20, 30, 40, 50, 60, 70], index=[1, 2, 3, 4, 5, 6, 7])
kser1 = ks.Series(pser1)
1 change: 1 addition & 0 deletions docs/source/reference/series.rst
@@ -151,6 +151,7 @@ Reindexing / Selection / Label manipulation
Series.rename
Series.reset_index
Series.sample
Series.where
Series.truncate

Missing data handling