Skip to content

Commit

Permalink
Port basic tests from dask. (#15)
Browse files Browse the repository at this point in the history
* Port basic tests from dask.

* Address comments.
  • Loading branch information
ueshin authored and thunterdb committed Jan 29, 2019
1 parent cd79aaa commit 38dbd37
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 23 deletions.
49 changes: 43 additions & 6 deletions pandorable_sparky/structures.py
Expand Up @@ -20,6 +20,10 @@ class _Frame(object):
# Column-wise maximum, delegated to Spark's F.max via the shared
# _reduce_spark helper.
def max(self):
return _reduce_spark(self, F.max)

def compute(self):
"""Alias of `toPandas()` to mimic dask for easily porting tests."""
# Materializes the (lazy) Spark result into a local pandas object.
return self.toPandas()


class PandasLikeSeries(_Frame):
"""
Expand Down Expand Up @@ -58,14 +62,35 @@ def schema(self):
def shape(self):
    """Return a 1-tuple ``(n_rows,)``, mimicking ``pandas.Series.shape``."""
    # A series is one-dimensional, so its shape is just its length.
    return (len(self),)

@property
def name(self):
# The column's name as rendered by the underlying JVM Column expression.
return self._jc.toString()

@name.setter
def name(self, name):
# Re-select this column under the new alias from its anchor dataframe,
# then adopt the renamed JVM column so `self` is mutated in place.
col = _col(self.to_dataframe().select(self.alias(name)))
anchor_wrap(col, self)
self._jc = col._jc
# Drop the cached pandas schema; it still carries the old name.
self._pandas_schema = None

def rename(self, name, inplace=False):
    """Return this column renamed to ``name``.

    With ``inplace=True`` the current column is mutated (through the
    ``name`` setter) and ``self`` is returned; otherwise a renamed copy
    is produced and ``self`` is left untouched.
    """
    if not inplace:
        return _col(self.to_dataframe().select(self.alias(name)))
    self.name = name
    return self

def to_dataframe(self):
    """Return a single-column dataframe that selects this column.

    Raises:
        ValueError: if the column has no anchor dataframe to select from.
    """
    ref = getattr(self, "_spark_ref_dataframe", None)
    if ref is None:
        raise ValueError("No reference to a dataframe for column {}".format(
            self._pandas_orig_repr()))
    return ref.select(self)

def toPandas(self):
# NOTE(review): diff artifact — both the pre-change and post-change return
# statements appear here; only the first line is reachable. Confirm which
# one is the intended final version.
return self.to_dataframe().toPandas()
return _col(self.to_dataframe().toPandas())

def head(self, n=5):
# First n rows of the anchored dataframe, re-wrapped as a column.
return _col(self.to_dataframe().head(n))

def unique(self):
# Pandas wants a series/array-like object
Expand All @@ -81,7 +106,7 @@ def _pandas_anchor(self) -> DataFrame:
# DANGER: will materialize.
def __iter__(self):
# NOTE(review): debug print left in — consider removing before release.
print("__iter__", self)
# NOTE(review): diff artifact — the two return lines below are the old and
# new implementations; only the first executes.
return _col(self.toPandas()).__iter__()
return self.toPandas().__iter__()

def __len__(self):
# Row count, delegated to the anchored dataframe's __len__.
return len(self.to_dataframe())
Expand All @@ -103,9 +128,14 @@ def __str__(self):
return self._pandas_orig_repr()

def __repr__(self):
# NOTE(review): diff artifact — the lines below mix the removed
# implementation (explicit toPandas + first-column extraction) with the
# new one-liner that relies on head()/toPandas() returning a column.
df = self.to_dataframe().head(max_display_count).toPandas()
c = df[df.columns[0]]
return repr(c)
return repr(self.head(max_display_count).toPandas())

def __dir__(self):
    """Extend the default attribute listing with the struct's field names.

    Field names containing spaces are excluded because they cannot be
    used for attribute access anyway.
    """
    fields = []
    if isinstance(self.schema, StructType):
        fields = [f for f in self.schema.fieldNames() if ' ' not in f]
    return super(Column, self).__dir__() + fields

def _pandas_orig_repr(self):
# TODO: figure out how to reuse the original one.
Expand Down Expand Up @@ -246,6 +276,10 @@ def __iter__(self):
def __len__(self):
# Number of rows; triggers a Spark count job.
return self._spark_count()

def __dir__(self):
    """Add space-free column names to the default attribute listing so
    that attribute access and tab-completion can see them."""
    regular_attrs = super(DataFrame, self).__dir__()
    column_attrs = [f for f in self.schema.fieldNames() if ' ' not in f]
    return regular_attrs + column_attrs

def _repr_html_(self):
# Jupyter rich display: render only the first max_display_count rows.
return self.head(max_display_count).toPandas()._repr_html_()

Expand All @@ -272,7 +306,10 @@ def _rename(frame, names):
if isinstance(frame, Column):
assert isinstance(frame.schema, StructType)
old_names = frame.schema.fieldNames()
assert len(names) == len(old_names)
if len(old_names) != len(names):
raise ValueError(
"Length mismatch: Expected axis has %d elements, new values have %d elements"
% (len(old_names), len(names)))
for (old_name, new_name) in zip(old_names, names):
frame = frame.withColumnRenamed(old_name, new_name)
return frame
Expand Down
76 changes: 60 additions & 16 deletions pandorable_sparky/testing/utils.py
Expand Up @@ -4,8 +4,9 @@
import unittest
from contextlib import contextmanager

import pandas as pd
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql import Column, DataFrame, SparkSession


class PySparkTestCase(unittest.TestCase):
Expand Down Expand Up @@ -140,21 +141,64 @@ def tearDownClass(cls):
super(ReusedSQLTestCase, cls).tearDownClass()
cls.spark.stop()

# NOTE(review): pre-change implementation shown by the diff; superseded by
# the type-dispatching version below.
def assertPandasEqual(self, expected, result):
msg = ("DataFrames are not equal: " +
"\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
"\n\nResult:\n%s\n%s" % (result, result.dtypes))
self.assertTrue(expected.equals(result), msg=msg)

# NOTE(review): pre-change implementation shown by the diff; superseded by
# the type-dispatching version below. Also note `eval` in the inner loop
# shadows the builtin of the same name.
def assertPandasAlmostEqual(self, expected, result):
msg = ("DataFrames are not equal: " +
"\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
"\n\nResult:\n%s\n%s" % (result, result.dtypes))
self.assertEqual(expected.shape, result.shape, msg=msg)
for ecol, rcol in zip(expected.columns, result.columns):
self.assertEqual(str(ecol), str(rcol), msg=msg)
for eval, rval in zip(expected[ecol], result[rcol]):
self.assertAlmostEqual(eval, rval, msg=msg)
def assertPandasEqual(self, left, right):
    """Assert that two pandas objects of the same kind are exactly equal.

    Supports pandas DataFrame, Series and Index pairs; any other
    combination raises ValueError.
    """
    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
        msg = ("DataFrames are not equal: " +
               "\n\nLeft:\n%s\n%s" % (left, left.dtypes) +
               "\n\nRight:\n%s\n%s" % (right, right.dtypes))
        self.assertTrue(left.equals(right), msg=msg)
    # BUG FIX: the original tested `left` twice in the two branches below,
    # so a non-Series/non-Index `right` slipped into the element-wise
    # comparison instead of reaching the ValueError.
    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
        msg = ("Series are not equal: " +
               "\n\nLeft:\n%s\n%s" % (left, left.dtype) +
               "\n\nRight:\n%s\n%s" % (right, right.dtype))
        self.assertTrue((left == right).all(), msg=msg)
    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
        msg = ("Indices are not equal: " +
               "\n\nLeft:\n%s\n%s" % (left, left.dtype) +
               "\n\nRight:\n%s\n%s" % (right, right.dtype))
        self.assertTrue((left == right).all(), msg=msg)
    else:
        raise ValueError("Unexpected values: (%s, %s)" % (left, right))

def assertPandasAlmostEqual(self, left, right):
    """Assert that two pandas objects of the same kind are element-wise
    almost equal (via ``unittest.assertAlmostEqual``).

    Supports pandas DataFrame, Series and Index pairs; any other
    combination raises ValueError.
    """
    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
        msg = ("DataFrames are not almost equal: " +
               "\n\nLeft:\n%s\n%s" % (left, left.dtypes) +
               "\n\nRight:\n%s\n%s" % (right, right.dtypes))
        self.assertEqual(left.shape, right.shape, msg=msg)
        for lcol, rcol in zip(left.columns, right.columns):
            self.assertEqual(str(lcol), str(rcol), msg=msg)
            for lval, rval in zip(left[lcol], right[rcol]):
                self.assertAlmostEqual(lval, rval, msg=msg)
    # BUG FIX: the original tested `left` twice in the two branches below,
    # so a non-Series/non-Index `right` was zipped against `left` instead
    # of reaching the ValueError.
    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
        msg = ("Series are not almost equal: " +
               "\n\nLeft:\n%s\n%s" % (left, left.dtype) +
               "\n\nRight:\n%s\n%s" % (right, right.dtype))
        for lval, rval in zip(left, right):
            self.assertAlmostEqual(lval, rval, msg=msg)
    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
        msg = ("Indices are not almost equal: " +
               "\n\nLeft:\n%s\n%s" % (left, left.dtype) +
               "\n\nRight:\n%s\n%s" % (right, right.dtype))
        for lval, rval in zip(left, right):
            self.assertAlmostEqual(lval, rval, msg=msg)
    else:
        raise ValueError("Unexpected values: (%s, %s)" % (left, right))

def assert_eq(self, left, right):
    """Compare two values, converting Spark objects to pandas first.

    pandas results go through assertPandasEqual; anything else falls
    back to plain assertEqual.
    """
    lhs = self._to_pandas(left)
    rhs = self._to_pandas(right)
    if isinstance(lhs, (pd.DataFrame, pd.Series, pd.Index)):
        self.assertPandasEqual(lhs, rhs)
    else:
        self.assertEqual(lhs, rhs)

@staticmethod
def _to_pandas(df):
    """Materialize a Spark DataFrame/Column to pandas; pass through
    anything that is already a local object."""
    return df.toPandas() if isinstance(df, (DataFrame, Column)) else df


class TestUtils(object):
Expand Down
204 changes: 204 additions & 0 deletions pandorable_sparky/tests/test_dataframe.py
@@ -0,0 +1,204 @@
import unittest

import numpy as np
import pandas as pd
import pandorable_sparky
import pyspark
from pyspark.sql import Column, DataFrame

from pandorable_sparky.testing.utils import ReusedSQLTestCase, TestUtils


class DataFrameTest(ReusedSQLTestCase, TestUtils):

@property
def df(self):
# Small two-column Spark DataFrame fixture shared by most tests.
return self.spark.createDataFrame(zip(
[1, 2, 3, 4, 5, 6, 7, 8, 9],
[4, 5, 6, 3, 2, 1, 0, 0, 0]
), schema='a int, b int')

@property
def full(self):
# The same fixture as `df`, materialized to pandas for expected values.
pdf = self.df.toPandas()
# TODO: pdf.index = [0, 1, 3, 5, 6, 8, 9, 9, 9]
return pdf

def test_Dataframe(self):
# Smoke test ported from dask: column arithmetic, column index,
# boolean filtering, attribute access and repr on the shared fixture.
d = self.df
full = self.full

expected = pd.Series([2, 3, 4, 5, 6, 7, 8, 9, 10],
# TODO: index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
name='(a + 1)') # TODO: name='a'

self.assert_eq(d['a'] + 1, expected)

self.assert_eq(d.columns, pd.Index(['a', 'b']))

self.assert_eq(d[d['b'] > 2], full[full['b'] > 2])
# TODO: self.assert_eq(d[['a', 'b']], full[['a', 'b']])
self.assert_eq(d.a, full.a)
# TODO: assert d.b.mean().compute() == full.b.mean()
# TODO: assert np.allclose(d.b.var().compute(), full.b.var())
# TODO: assert np.allclose(d.b.std().compute(), full.b.std())

# TODO: assert d.index._name == d.index._name # this is deterministic

assert repr(d)

def test_head_tail(self):
    """head() must agree with pandas for both dataframes and columns."""
    sdf = self.df
    pdf = self.full

    for n in (2, 3):
        self.assert_eq(sdf.head(n), pdf.head(n))
        self.assert_eq(sdf['a'].head(n), pdf['a'].head(n))

    # TODO: self.assert_eq(d.tail(2), full.tail(2))
    # TODO: self.assert_eq(d.tail(3), full.tail(3))
    # TODO: self.assert_eq(d['a'].tail(2), full['a'].tail(2))
    # TODO: self.assert_eq(d['a'].tail(3), full['a'].tail(3))

@unittest.skip('TODO: support index')
def test_index_head(self):
# Slicing the index should behave like pandas once indexes are supported.
d = self.df
full = self.full

self.assert_eq(d.index[:2], full.index[:2])
self.assert_eq(d.index[:3], full.index[:3])

def test_Series(self):
# Column access must yield Spark Column objects, and stay one under
# arithmetic.
d = self.df
full = self.full

self.assertTrue(isinstance(d.a, Column))
self.assertTrue(isinstance(d.a + 1, Column))
# TODO: self.assert_eq(d + 1, full + 1)

@unittest.skip('TODO: support index')
def test_Index(self):
# Round-trip of string and datetime indexes, once indexes are supported.
for case in [pd.DataFrame(np.random.randn(10, 5), index=list('abcdefghij')),
pd.DataFrame(np.random.randn(10, 5),
index=pd.date_range('2011-01-01', freq='D',
periods=10))]:
ddf = self.spark.createDataFrame(case)
self.assert_eq(ddf.index, case.index)

def test_attributes(self):
    """dir() must surface valid column names and nothing else."""
    sdf = self.df

    self.assertIn('a', dir(sdf))
    self.assertNotIn('foo', dir(sdf))
    self.assertRaises(AttributeError, lambda: sdf.foo)

    # Column names containing spaces are not attribute-accessible.
    with_space = self.spark.createDataFrame(pd.DataFrame({'a b c': [1, 2, 3]}))
    self.assertNotIn('a b c', dir(with_space))
    # Non-string column labels are likewise excluded.
    mixed = self.spark.createDataFrame(pd.DataFrame({'a': [1, 2], 5: [1, 2]}))
    self.assertIn('a', dir(mixed))
    self.assertNotIn(5, dir(mixed))

def test_column_names(self):
# Column names of derived columns currently expose the Spark expression
# string; the TODOs track matching pandas semantics instead.
d = self.df

self.assert_eq(d.columns, pd.Index(['a', 'b']))
# TODO: self.assert_eq(d[['b', 'a']].columns, pd.Index(['b', 'a']))
self.assertEqual(d['a'].name, 'a')
self.assertEqual((d['a'] + 1).name, '(a + 1)') # TODO: 'a'
self.assertEqual((d['a'] + d['b']).name, '(a + b)') # TODO: None

@unittest.skip('TODO: support index')
def test_index_names(self):
# Index name should survive a round-trip through Spark, once supported.
d = self.df

self.assertIsNone(d.index.name)

idx = pd.Index([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], name='x')
df = pd.DataFrame(np.random.randn(10, 5), idx)
ddf = self.spark.createDataFrame(df)
self.assertEqual(ddf.index.name, 'x')

def test_rename_columns(self):
    """Assigning to ``columns`` renames; length mismatches must raise."""
    pdf = pd.DataFrame({'a': [1, 2, 3, 4, 5, 6, 7],
                        'b': [7, 6, 5, 4, 3, 2, 1]})
    sdf = self.spark.createDataFrame(pdf)

    sdf.columns = ['x', 'y']
    pdf.columns = ['x', 'y']
    self.assert_eq(sdf.columns, pd.Index(['x', 'y']))
    self.assert_eq(sdf, pdf)

    msg = "Length mismatch: Expected axis has 2 elements, new values have 4 elements"
    with self.assertRaisesRegex(ValueError, msg):
        sdf.columns = [1, 2, 3, 4]

    # Multi-index columns collapse to the new flat names on rename.
    pdf = pd.DataFrame({('A', '0'): [1, 2, 2, 3], ('B', 1): [1, 2, 3, 4]})
    sdf = self.spark.createDataFrame(pdf)

    pdf.columns = ['x', 'y']
    sdf.columns = ['x', 'y']
    self.assert_eq(sdf.columns, pd.Index(['x', 'y']))
    self.assert_eq(sdf, pdf)

def test_rename_series(self):
# Assigning to `name` must mutate the column, mirroring pandas.
s = pd.Series([1, 2, 3, 4, 5, 6, 7], name='x')
ds = self.spark.createDataFrame(pd.DataFrame(s)).x

s.name = 'renamed'
ds.name = 'renamed'
self.assertEqual(ds.name, 'renamed')
self.assert_eq(ds, s)

# TODO: index
# ind = s.index
# dind = ds.index
# ind.name = 'renamed'
# dind.name = 'renamed'
# self.assertEqual(ind.name, 'renamed')
# self.assert_eq(dind, ind)

def test_rename_series_method(self):
# rename() must return a renamed copy by default (no mutation) and
# mutate in place when inplace=True, mirroring pandas.
# Series name
s = pd.Series([1, 2, 3, 4, 5, 6, 7], name='x')
ds = self.spark.createDataFrame(pd.DataFrame(s)).x

self.assert_eq(ds.rename('y'), s.rename('y'))
self.assertEqual(ds.name, 'x') # no mutation
# self.assert_eq(ds.rename(), s.rename())

ds.rename('z', inplace=True)
s.rename('z', inplace=True)
self.assertEqual(ds.name, 'z')
self.assert_eq(ds, s)

# Series index
s = pd.Series(['a', 'b', 'c', 'd', 'e', 'f', 'g'], name='x')
ds = self.spark.createDataFrame(pd.DataFrame(s)).x

# TODO: index
# res = ds.rename(lambda x: x ** 2)
# self.assert_eq(res, s.rename(lambda x: x ** 2))

# res = ds.rename(s)
# self.assert_eq(res, s.rename(s))

# res = ds.rename(ds)
# self.assert_eq(res, s.rename(s))

# res = ds.rename(lambda x: x**2, inplace=True)
# self.assertIs(res, ds)
# s.rename(lambda x: x**2, inplace=True)
# self.assert_eq(ds, s)


if __name__ == "__main__":
from pandorable_sparky.tests.test_dataframe import *

try:
# Prefer XML test reports (for CI) when xmlrunner is available.
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)

0 comments on commit 38dbd37

Please sign in to comment.