diff --git a/databricks/koalas/base.py b/databricks/koalas/base.py index 4647f7051e..f44058c09e 100644 --- a/databricks/koalas/base.py +++ b/databricks/koalas/base.py @@ -57,6 +57,21 @@ def wrapper(self, *args): args = [arg._scol if isinstance(arg, IndexOpsMixin) else arg for arg in args] scol = f(self._scol, *args) + # check if `f` is a comparison operator + comp_ops = ['eq', 'ne', 'lt', 'le', 'ge', 'gt'] + is_comp_op = any(f == getattr(spark.Column, '__{}__'.format(comp_op)) + for comp_op in comp_ops) + + if is_comp_op: + filler = f == spark.Column.__ne__ + scol = F.when(scol.isNull(), filler).otherwise(scol) + + elif f == spark.Column.__or__: + scol = F.when(self._scol.isNull() | scol.isNull(), False).otherwise(scol) + + elif f == spark.Column.__and__: + scol = F.when(scol.isNull(), False).otherwise(scol) + return self._with_new_scol(scol) else: # Different DataFrame anchors @@ -182,7 +197,7 @@ def __rfloordiv__(self, other): __pow__ = _column_op(spark.Column.__pow__) __rpow__ = _column_op(spark.Column.__rpow__) - # logistic operators + # comparison operators __eq__ = _column_op(spark.Column.__eq__) __ne__ = _column_op(spark.Column.__ne__) __lt__ = _column_op(spark.Column.__lt__) diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index 8b7a715d1f..2094d9713f 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -749,11 +749,11 @@ def eq(self, other): ... index=['a', 'b', 'c', 'd'], columns=['a', 'b']) >>> df.eq(1) - a b - a True True - b False None - c False True - d False None + a b + a True True + b False False + c False True + d False False """ return self == other @@ -770,9 +770,9 @@ def gt(self, other): >>> df.gt(2) a b a False False - b False None + b False False c True False - d True None + d True False """ return self > other @@ -785,11 +785,11 @@ def ge(self, other): ... index=['a', 'b', 'c', 'd'], columns=['a', 'b']) >>> df.ge(1) - a b - a True True - b True None - c True True - d True None + a b + a True True + b True False + c True True + d True False """ return self >= other @@ -804,9 +804,9 @@ def lt(self, other): >>> df.lt(1) a b a False False - b False None + b False False c False False - d False None + d False False """ return self < other @@ -819,11 +819,11 @@ def le(self, other): ... index=['a', 'b', 'c', 'd'], columns=['a', 'b']) >>> df.le(2) - a b - a True True - b True None - c False True - d False None + a b + a True True + b True False + c False True + d False False """ return self <= other @@ -838,9 +838,9 @@ def ne(self, other): >>> df.ne(1) a b a False False - b True None + b True True c True False - d True None + d True True """ return self != other diff --git a/databricks/koalas/series.py b/databricks/koalas/series.py index 53200c8c86..6cc74a9d1c 100644 --- a/databricks/koalas/series.py +++ b/databricks/koalas/series.py @@ -539,11 +539,11 @@ def eq(self, other): Name: a, dtype: bool >>> df.b.eq(1) - a True - b None - c True - d None - Name: b, dtype: object + a True + b False + c True + d False + Name: b, dtype: bool """ return (self == other).rename(self.name) @@ -566,10 +566,10 @@ def gt(self, other): >>> df.b.gt(1) a False - b None + b False c False - d None - Name: b, dtype: object + d False + Name: b, dtype: bool """ return (self > other).rename(self.name) @@ -590,10 +590,10 @@ def ge(self, other): >>> df.b.ge(2) a False - b None + b False c False - d None - Name: b, dtype: object + d False + Name: b, dtype: bool """ return (self >= other).rename(self.name) @@ -613,11 +613,11 @@ def lt(self, other): Name: a, dtype: bool >>> df.b.lt(2) - a True - b None - c True - d None - Name: b, dtype: object + a True + b False + c True + d False + Name: b, dtype: bool """ return (self < other).rename(self.name) @@ -637,11 +637,11 @@ def le(self, other): Name: a, dtype: bool >>> df.b.le(2) - a True - b None - c True - d None - Name: b, dtype: object + a True + b False + c True + d False + Name: b, dtype: bool """ return (self <= other).rename(self.name) @@ -662,10 +662,10 @@ def ne(self, other): >>> df.b.ne(1) a False - b None + b True c False - d None - Name: b, dtype: object + d True + Name: b, dtype: bool """ return (self != other).rename(self.name) diff --git a/databricks/koalas/tests/test_series.py b/databricks/koalas/tests/test_series.py index e5fa79e596..a02641bfa4 100644 --- a/databricks/koalas/tests/test_series.py +++ b/databricks/koalas/tests/test_series.py @@ -177,6 +177,26 @@ def test_values_property(self): with self.assertRaises(NotImplementedError, msg=msg): kser.values + def test_or(self): + pdf = pd.DataFrame({ + 'left': [True, False, True, False, np.nan, np.nan, True, False, np.nan], + 'right': [True, False, False, True, True, False, np.nan, np.nan, np.nan] + }) + kdf = ks.from_pandas(pdf) + + self.assert_eq(pdf['left'] | pdf['right'], + kdf['left'] | kdf['right']) + + def test_and(self): + pdf = pd.DataFrame({ + 'left': [True, False, True, False, np.nan, np.nan, True, False, np.nan], + 'right': [True, False, False, True, True, False, np.nan, np.nan, np.nan] + }) + kdf = ks.from_pandas(pdf) + + self.assert_eq(pdf['left'] & pdf['right'], + kdf['left'] & kdf['right']) + def test_to_numpy(self): pser = pd.Series([1, 2, 3, 4, 5, 6, 7], name='x')