Add a basic selection. (#14)
* Add a basic selection.

* workaround.

* Use `functools.wraps`.

* Address comments.

* Add a comment.

* Rename.
ueshin authored and thunterdb committed Jan 29, 2019
1 parent 38dbd37 commit bf337a2
Showing 8 changed files with 5,424 additions and 2 deletions.
5,149 changes: 5,149 additions & 0 deletions data/sample_stocks.csv

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions pandorable_sparky/exceptions.py
@@ -0,0 +1,29 @@
"""
Exceptions/Errors used in pandorable_sparky.
"""


def code_change_hint(pandas_function, spark_target_function):
    if pandas_function is not None and spark_target_function is not None:
        return "You are trying to use pandas function {}; use Spark function {} instead." \
            .format(pandas_function, spark_target_function)
    elif pandas_function is not None and spark_target_function is None:
        return ("You are trying to use pandas function {}; check out the Spark "
                "user guide to find a relevant function.").format(pandas_function)
    elif pandas_function is None and spark_target_function is not None:
        return "Use Spark function {}.".format(spark_target_function)
    else:  # both are None
        return "Check out the Spark user guide to find a relevant function."


class SparkPandasNotImplementedError(NotImplementedError):
    """Raised when a pandas feature is not implemented; the message suggests a Spark alternative."""

    def __init__(self, pandas_function=None, spark_target_function=None, description=""):
        self.pandas_source = pandas_function
        self.spark_target = spark_target_function
        hint = code_change_hint(pandas_function, spark_target_function)
        if description:
            description += " " + hint
        else:
            description = hint
        super(SparkPandasNotImplementedError, self).__init__(description)
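
A minimal usage sketch of the exception above; the pandas function named below is hypothetical and only illustrates how the hint is composed:

from pandorable_sparky.exceptions import SparkPandasNotImplementedError

# Both a pandas source and a Spark target are given, so the hint combines them.
err = SparkPandasNotImplementedError(
    pandas_function=".iterrows()",
    spark_target_function="collect",
    description="Row iteration is not supported.")
print(str(err))
# Row iteration is not supported. You are trying to use pandas function .iterrows();
# use Spark function collect instead.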
32 changes: 32 additions & 0 deletions pandorable_sparky/groups.py
@@ -1,3 +1,11 @@
"""
A wrapper for GroupedData to behave similar to pandas.
"""
import sys

if sys.version > '3':
basestring = unicode = str


class PandasLikeGroupBy(object):
"""
@@ -32,6 +40,30 @@ def __getattr__(self, key):
        except KeyError as e:
            raise AttributeError(e)

    def aggregate(self, func_or_funcs, *args, **kwargs):
        """Compute aggregates and return the result as a :class:`DataFrame`.

        The available aggregate functions are the built-in aggregation functions, such as
        `avg`, `max`, `min`, `sum`, `count`.

        :param func_or_funcs: a dict mapping from column name (string) to aggregate function
            (string).
        """
        if not isinstance(func_or_funcs, dict) or \
                not all(isinstance(key, basestring) and isinstance(value, basestring)
                        for key, value in func_or_funcs.items()):
            raise ValueError("aggs must be a dict mapping from column name (string) to "
                             "aggregate functions (string).")
        df = self._groupdata.agg(func_or_funcs)

        # Spark names aggregated columns like 'max(High)'; reorder them to match the dict
        # order and restore the original column names.
        reorder = ['%s(%s)' % (value, key) for key, value in iter(func_or_funcs.items())]
        df = df._spark_select(reorder)
        df.columns = [key for key in iter(func_or_funcs.keys())]

        return df

    agg = aggregate

    def count(self):
        return self._groupdata.count()

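A hypothetical sketch of the resulting pandas-like aggregation API, assuming `df` is a monkey-patched Spark DataFrame with the sample stock columns:

# 'Symbol', 'Open' and 'Close' are columns of data/sample_stocks.csv.
aggregated = df.groupby(by='Symbol').agg({'Open': 'first', 'Close': 'last'})
# The result keeps the original column names ('Open', 'Close') instead of
# Spark's generated 'first(Open)' / 'last(Close)'.
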
105 changes: 105 additions & 0 deletions pandorable_sparky/selection.py
@@ -0,0 +1,105 @@
"""
A locator for PandasLikeDataFrame.
"""
from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col
from pyspark.sql.types import BooleanType

from .exceptions import SparkPandasNotImplementedError


def _make_col(c):
    if isinstance(c, Column):
        return c
    elif isinstance(c, str):
        return col(c)
    else:
        raise SparkPandasNotImplementedError(
            description="Can only convert a string to a column type")


def _unfold(key):
    about_cols = "Can only select columns either by name or reference or all of them"

    if (not isinstance(key, tuple)) or (len(key) != 2):
        raise NotImplementedError("Only accepts pairs of candidates")

    rows, cols = key
    # Make cols a 1-tuple if it is a single string or Column.
    if isinstance(cols, (str, Column)):
        cols = (cols,)
    elif isinstance(cols, slice) and cols != slice(None):
        raise SparkPandasNotImplementedError(
            description=about_cols,
            pandas_function="loc",
            spark_target_function="select, where, withColumn")
    elif isinstance(cols, slice) and cols == slice(None):
        cols = ("*",)

    return rows, cols


class SparkDataFrameLocator(object):
    """
    A locator to slice a group of rows and columns by a conditional and/or label(s).

    Allowed inputs are a slice over all indices or a conditional for rows, and string(s) or
    :class:`Column`(s) for columns.
    """

    def __init__(self, df):
        self.df = df

    def __getitem__(self, key):
        about_rows = "Can only slice with all indices or a column that evaluates to Boolean"

        rows, cols = _unfold(key)

        if isinstance(rows, slice) and rows != slice(None):
            raise SparkPandasNotImplementedError(
                description=about_rows,
                pandas_function=".loc[..., ...]",
                spark_target_function="select, where")
        elif isinstance(rows, slice) and rows == slice(None):
            df = self.df
        else:  # not isinstance(rows, slice)
            try:
                # The row selector must evaluate to a Boolean column to be used as a filter.
                assert isinstance(self.df._spark_select(rows).schema.fields[0].dataType,
                                  BooleanType)
                df = self.df._spark_where(rows)
            except Exception:
                raise SparkPandasNotImplementedError(
                    description=about_rows,
                    pandas_function=".loc[..., ...]",
                    spark_target_function="select, where")
        return df._spark_select([_make_col(c) for c in cols])

    def __setitem__(self, key, value):
        if (not isinstance(key, tuple)) or (len(key) != 2):
            raise NotImplementedError("Only accepts pairs of candidates")

        rows, cols = key

        if (not isinstance(rows, slice)) or (rows != slice(None)):
            raise SparkPandasNotImplementedError(
                description="Can only assign a value to the whole dataframe; the row index "
                            "has to be `slice(None)` or `:`",
                pandas_function=".loc[..., ...] = ...",
                spark_target_function="withColumn, select")

        if not isinstance(cols, str):
            raise ValueError("Only column names can be assigned")

        if isinstance(value, Column):
            df = self.df._spark_withColumn(cols, value)
        elif isinstance(value, DataFrame) and len(value.columns) == 1:
            df = self.df._spark_withColumn(cols, col(value.columns[0]))
        elif isinstance(value, DataFrame) and len(value.columns) != 1:
            raise ValueError("Only a dataframe with one column can be assigned")
        else:
            raise ValueError("Only a column or a dataframe with a single column can be assigned")

        # Imported here to avoid a circular import with structures.py.
        from .structures import _reassign_jdf
        _reassign_jdf(self.df, df)
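
A rough usage sketch of the locator, assuming `df` is a monkey-patched Spark DataFrame with the sample stock columns:

subset = df.loc[:, ['Symbol', 'Close']]              # select columns by name
gains = df.loc[df.Close - df.Open > 0, ['Symbol']]   # filter rows by a Boolean column
df.loc[:, 'spread'] = df.High - df.Low               # assign a new column in place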
18 changes: 16 additions & 2 deletions pandorable_sparky/structures.py
@@ -1,9 +1,13 @@
"""
Base classes to be monkey-patched onto DataFrame/Column so they behave similarly to pandas DataFrame/Series.
"""
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql import DataFrame, Column
from pyspark.sql.types import StructType

from .selection import SparkDataFrameLocator
from ._dask_stubs.utils import derived_from
from ._dask_stubs.compatibility import string_types

@@ -163,6 +167,13 @@ def assign(self, **kwargs):
            df = df.withColumn(name, c)
        return df

    @property
    def loc(self):
        return SparkDataFrameLocator(self)

    def copy(self):
        return DataFrame(self._jdf, self.sql_ctx)

    def head(self, n=5):
        l = self.take(n)
        df0 = self.sql_ctx.createDataFrame(l, schema=self.schema)
@@ -201,8 +212,11 @@ def get(self, key, default=None):
        except (KeyError, ValueError, IndexError):
            return default

    def sort_values(self, by):
        return self._spark_sort(by)

    # groupby previously accepted *cols; it now takes a single `by` argument.
    def groupby(self, by):
        gp = self._spark_groupby(by)
        from .groups import PandasLikeGroupBy
        return PandasLikeGroupBy(self, gp, None)

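A hypothetical end-to-end sketch of the pandas-like DataFrame API patched in above; the read options and column names follow data/sample_stocks.csv, and importing the package is assumed to apply the monkey patches (as the tests suggest):

import pandorable_sparky  # assumed to monkey-patch pyspark DataFrame on import
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.read.csv('data/sample_stocks.csv', header=True, inferSchema=True)

sorted_df = df.sort_values(by='Date')        # delegates to Spark's sort
top5 = df.head()                             # n defaults to 5
counts = df.groupby(by='Symbol').count()     # pandas-like groupby
selection = df.loc[:, ['Symbol', 'Close']]   # the new .loc locator
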
32 changes: 32 additions & 0 deletions pandorable_sparky/testing/utils.py
@@ -1,3 +1,4 @@
import functools
import shutil
import sys
import tempfile
@@ -215,3 +216,34 @@ def temp_dir(self):
    def temp_file(self):
        with self.temp_dir() as tmp:
            yield tempfile.mktemp(dir=tmp)


class ComparisonTestBase(ReusedSQLTestCase):
    # Subclasses are expected to override at least one of `df` / `pdf`;
    # each default property derives one from the other.

    @property
    def df(self):
        return self.spark.createDataFrame(self.pdf)

    @property
    def pdf(self):
        return self.df.toPandas()


def compare_both(f=None, almost=True):
    # Supports @compare_both, @compare_both(almost=False) and @compare_both(False).
    if f is None:
        return functools.partial(compare_both, almost=almost)
    elif isinstance(f, bool):
        return functools.partial(compare_both, almost=f)

    @functools.wraps(f)
    def wrapped(self):
        if almost:
            compare = self.assertPandasAlmostEqual
        else:
            compare = self.assertPandasEqual

        # Run the decorated generator twice, once with the pandas DataFrame and once
        # with the Spark DataFrame, and compare the yielded results pairwise.
        for result_pandas, result_spark in zip(f(self, self.pdf), f(self, self.df)):
            compare(result_pandas, result_spark.toPandas())

    return wrapped
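
A hypothetical sketch of how the decorator is meant to be used; the test class and fixture data below are illustrative only:

import pandas as pd

from pandorable_sparky.testing.utils import ComparisonTestBase, compare_both


class SortComparisonTest(ComparisonTestBase):

    @property
    def pdf(self):
        # Hypothetical fixture; `df` is derived from it by ComparisonTestBase.
        return pd.DataFrame({'a': [3, 1, 2], 'b': [4.0, 5.0, 6.0]})

    @compare_both
    def test_sort(self, df):
        # Runs once with the pandas frame and once with the Spark frame;
        # the yielded results are compared with assertPandasAlmostEqual.
        yield df.sort_values(by=['a'])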
58 changes: 58 additions & 0 deletions pandorable_sparky/tests/test_etl.py
@@ -0,0 +1,58 @@
import unittest

import pandas as pd
import pandorable_sparky  # imported for its monkey-patching side effects
import pyspark

from pandorable_sparky.testing.utils import ComparisonTestBase, compare_both


class EtlTest(ComparisonTestBase):

    @property
    def pdf(self):
        return pd.read_csv('data/sample_stocks.csv')

    @compare_both
    def test_etl(self, df):
        df1 = df.loc[:, 'Symbol Date Open High Low Close'.split()]
        yield df1

        df2 = df1.sort_values(by=["Symbol", "Date"])
        yield df2

        df3 = df2.groupby(by="Symbol").agg({
            'Open': 'first',
            'High': 'max',
            'Low': 'min',
            'Close': 'last'
        })
        yield df3

        df4 = df2.copy()

        df4.loc[:, 'signal_1'] = df4.Close - df4.Open
        df4.loc[:, 'signal_2'] = df4.High - df4.Low

        # df4.loc[:, 'signal_3'] = (df4.signal_2 - df4.signal_2.mean()) / df4.signal_2.std()
        yield df4

        df5 = df4.loc[df4.signal_1 > 0, ['Symbol', 'Date']]
        yield df5

        df6 = df4.loc[df4.signal_2 > 0, ['Symbol', 'Date']]
        yield df6

        # df7 = df4.loc[df4.signal_3 > 0, ['Symbol', 'Date']]
        # yield df7


if __name__ == "__main__":
    from pandorable_sparky.tests.test_etl import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
3 changes: 3 additions & 0 deletions pandorable_sparky/utils.py
@@ -1,3 +1,6 @@
"""
Utilities to monkey patch PySpark used in pandorable_sparky.
"""
import pyspark.sql.dataframe as df
import pyspark.sql.column as col
import pyspark.sql.functions as F
