Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add a basic selection. * Add a workaround. * Use `functools.wraps`. * Address comments. * Add a comment. * Rename.
- Loading branch information
Showing
8 changed files
with
5,424 additions
and
2 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
Exceptions/Errors used in pandorable_sparky. | ||
""" | ||
|
||
|
||
def code_change_hint(pandas_function, spark_target_function):
    """Build a hint message pointing the user from a pandas API to a Spark one.

    :param pandas_function: name of the pandas API being used, or None.
    :param spark_target_function: name of the Spark alternative(s), or None.
    :return: a human-readable hint string.
    """
    if pandas_function is None:
        if spark_target_function is None:
            # Neither side known: generic pointer to the docs.
            return "Checkout the spark user guide to find a relevant function"
        return "Use spark function {}".format(spark_target_function)
    if spark_target_function is None:
        return ("You are trying to use pandas function {}, checkout the spark "
                "user guide to find a relevant function").format(pandas_function)
    return "You are trying to use pandas function {}, use spark function {}" \
        .format(pandas_function, spark_target_function)
|
||
|
||
class SparkPandasNotImplementedError(NotImplementedError):
    """Raised when a pandas API has no implementation here; the message
    carries a hint pointing the caller at the Spark equivalent."""

    def __init__(self, pandas_function=None, spark_target_function=None, description=""):
        # Keep the raw names around for programmatic inspection by callers.
        self.pandas_source = pandas_function
        self.spark_target = spark_target_function
        hint = code_change_hint(pandas_function, spark_target_function)
        message = "{} {}".format(description, hint) if description else hint
        super(SparkPandasNotImplementedError, self).__init__(message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
""" | ||
A locator for PandasLikeDataFrame. | ||
""" | ||
from pyspark.sql import Column, DataFrame | ||
from pyspark.sql.functions import col | ||
from pyspark.sql.types import BooleanType | ||
|
||
from .exceptions import SparkPandasNotImplementedError | ||
|
||
|
||
def _make_col(c):
    """Coerce *c* into a Spark :class:`Column`.

    :param c: an existing Column (returned unchanged) or a column name string.
    :raises SparkPandasNotImplementedError: for any other type.
    """
    if isinstance(c, Column):
        return c
    if isinstance(c, str):
        return col(c)
    raise SparkPandasNotImplementedError(
        description="Can only convert a string to a column type")
|
||
|
||
def _unfold(key):
    """Split a ``.loc[rows, cols]`` key into its two parts and normalize cols.

    :param key: a 2-tuple ``(rows, cols)`` as received by ``__getitem__``.
    :return: ``(rows, cols)`` where cols is a tuple — ``("*",)`` for a full
        slice, a 1-tuple for a single name/Column, or the original value
        otherwise (e.g. a list of names is passed through unchanged).
    :raises NotImplementedError: if key is not a pair.
    :raises SparkPandasNotImplementedError: for a partial column slice.
    """
    about_cols = """Can only select columns either by name or reference or all"""

    if (not isinstance(key, tuple)) or (len(key) != 2):
        raise NotImplementedError("Only accepts pairs of candidates")

    rows, cols = key
    # Normalize a single name/Column to a 1-tuple so callers can always iterate.
    if isinstance(cols, (str, Column)):
        cols = (cols,)
    elif isinstance(cols, slice):
        if cols == slice(None):
            cols = ("*",)
        else:
            # BUG FIX: the constructor parameter is `pandas_function`, not
            # `pandas_source` — the old keyword raised TypeError instead of
            # the intended SparkPandasNotImplementedError.
            raise SparkPandasNotImplementedError(
                description=about_cols,
                pandas_function="loc",
                spark_target_function="select, where, withColumn")

    return rows, cols
|
||
|
||
class SparkDataFrameLocator(object):
    """
    A locator to slice a group of rows and columns by conditional and label(s).
    Allowed inputs are a slice with all indices or conditional for rows, and string(s) or
    :class:`Column`(s) for cols.
    """

    def __init__(self, df):
        # The PandasLikeDataFrame this locator slices.
        self.df = df

    def __getitem__(self, key):
        """Select rows by a boolean Column (or the full slice) and columns by
        name/Column.

        :param key: a ``(rows, cols)`` pair; see :func:`_unfold`.
        :return: a new dataframe with the selection applied.
        :raises SparkPandasNotImplementedError: for unsupported specifiers.
        """
        about_rows = """Can only slice with all indices or a column that evaluates to Boolean"""

        rows, cols = _unfold(key)

        if isinstance(rows, slice):
            if rows != slice(None):
                # Partial row slices (e.g. df.loc[1:3]) have no Spark equivalent.
                # BUG FIX: keyword is `pandas_function`, not `pandas_source` —
                # the old keyword made this raise TypeError instead.
                raise SparkPandasNotImplementedError(
                    description=about_rows,
                    pandas_function=".loc[..., ...]",
                    spark_target_function="select, where")
            df = self.df
        else:
            try:
                # The row specifier must evaluate to a single boolean column.
                # Explicit check instead of `assert` so it survives `python -O`;
                # any failure in this try falls through to the same error below.
                if not isinstance(self.df._spark_select(rows).schema.fields[0].dataType,
                                  BooleanType):
                    raise ValueError("rows does not evaluate to a Boolean column")
                df = self.df._spark_where(rows)
            except Exception:
                raise SparkPandasNotImplementedError(
                    description=about_rows,
                    pandas_function=".loc[..., ...]",
                    spark_target_function="select, where")
        return df._spark_select([_make_col(c) for c in cols])

    def __setitem__(self, key, value):
        """Assign a Column (or a single-column dataframe) to a named column.

        :param key: a ``(rows, cols)`` pair; rows must be the full slice and
            cols a column name string.
        :param value: a :class:`Column` or a one-column :class:`DataFrame`.
        :raises SparkPandasNotImplementedError: for a partial row slice.
        :raises ValueError: for unsupported cols/value types.
        """
        if (not isinstance(key, tuple)) or (len(key) != 2):
            raise NotImplementedError("Only accepts pairs of candidates")

        rows, cols = key

        if (not isinstance(rows, slice)) or (rows != slice(None)):
            # BUG FIX: keyword is `pandas_function`, not `pandas_source`.
            raise SparkPandasNotImplementedError(
                description="""Can only assign value to the whole dataframe, the row index
                has to be `slice(None)` or `:`""",
                pandas_function=".loc[..., ...] = ...",
                spark_target_function="withColumn, select")

        if not isinstance(cols, str):
            raise ValueError("""only column names can be assigned""")

        if isinstance(value, Column):
            df = self.df._spark_withColumn(cols, value)
        elif isinstance(value, DataFrame):
            if len(value.columns) == 1:
                df = self.df._spark_withColumn(cols, col(value.columns[0]))
            else:
                raise ValueError("Only a dataframe with one column can be assigned")
        else:
            raise ValueError("Only a column or dataframe with single column can be assigned")

        # Imported here, not at module top — presumably to avoid a circular
        # import with .structures; confirm before hoisting.
        from .structures import _reassign_jdf
        _reassign_jdf(self.df, df)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import unittest | ||
|
||
import pandas as pd | ||
import pandorable_sparky | ||
import pyspark | ||
|
||
from pandorable_sparky.testing.utils import ComparisonTestBase, compare_both | ||
|
||
|
||
class EtlTest(ComparisonTestBase):
    # End-to-end ETL parity test: compare_both runs the generator once against
    # pandas and once against the Spark-backed frame, comparing each yielded
    # result — presumably; confirm against ComparisonTestBase/compare_both.

    @property
    def pdf(self):
        # Source fixture; re-read on each access so every run starts fresh.
        return pd.read_csv('data/sample_stocks.csv')

    @compare_both
    def test_etl(self, df):
        # Project a fixed subset of columns via .loc.
        df1 = df.loc[:, 'Symbol Date Open High Low Close'.split()]
        yield df1

        # Deterministic ordering before grouping.
        df2 = df1.sort_values(by=["Symbol", "Date"])
        yield df2

        # Per-symbol OHLC aggregation.
        df3 = df2.groupby(by="Symbol").agg({
            'Open': 'first',
            'High': 'max',
            'Low': 'min',
            'Close': 'last'
        })
        yield df3

        df4 = df2.copy()

        # Column assignment through .loc (exercises the locator's __setitem__).
        df4.loc[:, 'signal_1'] = df4.Close - df4.Open
        df4.loc[:, 'signal_2'] = df4.High - df4.Low

        # NOTE(review): signal_3 is disabled — presumably mean()/std() are not
        # yet supported on the Spark side; confirm before re-enabling.
        # df4.loc[:, 'signal_3'] = (df4.signal_2 - df4.signal_2.mean()) / df4.signal_2.std()
        yield df4

        # Boolean row selection through .loc (exercises __getitem__).
        df5 = df4.loc[df4.signal_1 > 0, ['Symbol', 'Date']]
        yield df5

        df6 = df4.loc[df4.signal_2 > 0, ['Symbol', 'Date']]
        yield df6

        # Disabled together with signal_3 above.
        # df7 = df4.loc[df4.signal_3 > 0, ['Symbol', 'Date']]
        # yield df7
|
||
|
||
if __name__ == "__main__":
    # Re-import the module's names at top level so unittest.main() can
    # discover the test classes when this file is run directly as a script.
    from pandorable_sparky.tests.test_etl import *

    try:
        # Prefer XML report output (for CI consumption) when xmlrunner is
        # installed; otherwise fall back to the default text runner.
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters