Showing 14 changed files with 554 additions and 32 deletions.
@@ -0,0 +1,100 @@
"""This module contains implementations of the interval identifier wrangler. | ||
""" | ||
|
||
from pyspark.sql import DataFrame, Window | ||
from pyspark.sql import functions as F | ||
|
||
from pywrangler.wranglers.interfaces import IntervalIdentifier | ||
from pywrangler.wranglers.spark.base import SparkSingleNoFit | ||
|
||
|
||
class VectorizedCumSum(SparkSingleNoFit, IntervalIdentifier): | ||
"""Sophisticated approach avoiding python UDFs. However multiple windows | ||
are necessary. | ||
First, get enumeration of all intervals (valid and invalid). Every | ||
time a start or end marker is encountered, increase interval id by one. | ||
The end marker is shifted by one to include the end marker in the | ||
current interval. This is realized via the cumulative sum of boolean | ||
series of start markers and shifted end markers. | ||
Second, separate valid from invalid intervals by ensuring the presence | ||
of both start and end markers per interval id. | ||
Third, numerate valid intervals starting with 1 and set invalid | ||
intervals to 0. | ||
""" | ||
|
||
def validate_input(self, df: DataFrame): | ||
"""Checks input data frame in regard to column names. | ||
Parameters | ||
---------- | ||
df: pyspark.sql.Dataframe | ||
Dataframe to be validated. | ||
""" | ||
|
||
self.validate_columns(df, self.marker_column) | ||
self.validate_columns(df, self.order_columns) | ||
self.validate_columns(df, self.groupby_columns) | ||
|
||
if self.order_columns is None: | ||
raise ValueError("Please define an order column. Pyspark " | ||
"dataframes have no implicit order unlike pandas " | ||
"dataframes.") | ||
|
||
    def transform(self, df: DataFrame) -> DataFrame:
        """Extract interval ids from the given dataframe.

        Parameters
        ----------
        df: pyspark.sql.DataFrame

        Returns
        -------
        result: pyspark.sql.DataFrame
            Same columns as the original dataframe plus the new interval id
            column.
        """

        # check input
        self.validate_input(df)

        # define window specs
        orderby = self.prepare_orderby(self.order_columns, self.ascending)
        groupby = self.groupby_columns or []

        w_lag = Window.partitionBy(list(groupby)).orderBy(orderby)
        w_id = Window.partitionBy(list(groupby) + [self.target_column_name])

        # get boolean series with start and end markers
        marker_col = F.col(self.marker_column)
        bool_start = (marker_col == self.marker_start).cast("integer")
        bool_end = (marker_col == self.marker_end).cast("integer")
        bool_start_end = bool_start + bool_end

        # shifting the end marker by one row allows the cumulative sum to
        # include the end marker in the interval it closes
        bool_end_shift = F.lag(bool_end, default=1).over(w_lag).cast("integer")
        bool_start_end_shift = bool_start + bool_end_shift

        # get increasing ids for intervals (valid and invalid) via cumsum
        ser_id = F.sum(bool_start_end_shift).over(w_lag)
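        # worked example (illustrative markers only):
        #   marker          :  start  noise  end  noise  start
        #   start + shifted :    2      0     0     1      1
        #   cumsum (ser_id) :    2      2     2     3      4
        # the end row keeps the id of the interval it closes; the rows
        # after it fall into new, potentially invalid, ids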

        # separate valid from invalid intervals: an id is valid if its rows
        # contain both a start AND an end marker
        bool_valid = F.sum(bool_start_end).over(w_id) == 2
        valid_ids = F.when(bool_valid, ser_id).otherwise(0)

        # renumber valid ids from 1 to x and set invalid ids to 0
        valid_ids_shift = F.lag(valid_ids, default=0).over(w_lag)
        valid_ids_diff = valid_ids_shift - valid_ids
        valid_ids_increase = (valid_ids_diff < 0).cast("integer")

        renumerate = F.sum(valid_ids_increase).over(w_lag)
        renumerate_adjusted = F.when(bool_valid, renumerate).otherwise(0)
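        # worked example continued, assuming valid_ids = [2, 2, 2, 0, 0]:
        #   lagged by one (default 0) : [0, 2, 2, 2, 0]
        #   (diff < 0) as integer     : [1, 0, 0, 0, 0]
        #   cumsum (renumerate)       : [1, 1, 1, 1, 1]
        #   masked by bool_valid      : [1, 1, 1, 0, 0]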

        # ser_id needs to be written to the target column temporarily
        # because w_id (and hence renumerate_adjusted) partitions by it
        return df.withColumn(self.target_column_name, ser_id) \
                 .withColumn(self.target_column_name, renumerate_adjusted)
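To make the three steps concrete, here is a minimal, self-contained sketch of the same cumulative sum technique on a toy dataframe. It is illustrative only and not part of this commit; the column names ("order", "marker", "iid") and marker values are made up.

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

# toy data: one valid interval (rows 1-3) and one dangling start (row 5)
df = spark.createDataFrame(
    [(1, "start"), (2, "noise"), (3, "end"), (4, "noise"), (5, "start")],
    ["order", "marker"])

# a single global ordering; the wrangler would additionally partition by
# its groupby columns
w = Window.orderBy("order")
bool_start = (F.col("marker") == "start").cast("integer")
bool_end = (F.col("marker") == "end").cast("integer")

# step 1: shift end markers by one row and take the cumulative sum
df = df.withColumn("end_shift", F.lag(bool_end, default=1).over(w))
df = df.withColumn("raw_id", F.sum(bool_start + F.col("end_shift")).over(w))

# step 2: an id is valid if its rows contain both a start and an end marker
w_id = Window.partitionBy("raw_id")
df = df.withColumn("valid", F.sum(bool_start + bool_end).over(w_id) == 2)
df = df.withColumn("valid_id",
                   F.when(F.col("valid"), F.col("raw_id")).otherwise(0))

# step 3: renumber valid ids starting from 1; invalid ids become 0
diff = F.lag("valid_id", default=0).over(w) - F.col("valid_id")
df = df.withColumn("increase", (diff < 0).cast("integer"))
df = df.withColumn("iid",
                   F.when(F.col("valid"),
                          F.sum("increase").over(w)).otherwise(0))

df.orderBy("order").show()
# rows 1-3 form the single valid interval (iid 1); rows 4 and 5 get iid 0

Unlike the wrangler code above, intermediate results are materialized via withColumn at each step, which avoids nesting one window function inside another.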
@@ -0,0 +1,54 @@
"""This module contains helper functions for testing. | ||
""" | ||
|
||
import pandas as pd | ||
from pyspark.sql import DataFrame | ||
|
||
from pywrangler.util.types import TYPE_COLUMNS | ||
|
||
try: | ||
from pandas.testing import assert_frame_equal | ||
except ImportError: | ||
from pandas.util.testing import assert_frame_equal | ||
|
||
|
||
def assert_spark_pandas_equality(df_spark: DataFrame, | ||
df_pandas: pd.DataFrame, | ||
orderby: TYPE_COLUMNS = None): | ||
"""Compare a spark and pandas dataframe in regard to content equality. | ||
Spark dataframes do not have a specific index or column order due to their | ||
distributed nature. In contrast, a test for equality for pandas dataframes | ||
respects index and column order. Therefore, the test for equality between a | ||
spark and pandas dataframe will ignore index and column order on purpose. | ||
Testing spark dataframes content is most simple while converting to pandas | ||
dataframes and having test data as pandas dataframes, too. | ||
To ensure index order is ignored, both dataframes need be sorted by all or | ||
given columns `orderby`. | ||
Parameters | ||
---------- | ||
df_spark: pyspark.sql.DataFrame | ||
Spark dataframe to be tested for equality. | ||
df_pandas: pd.DataFrame | ||
Pandas dataframe to be tested for equality. | ||
orderby: iterable, optional | ||
Columns to be sorted for correct index order. | ||
Returns | ||
------- | ||
None but asserts if dataframes are not equal. | ||
""" | ||
|
||
orderby = orderby or df_pandas.columns.tolist() | ||
|
||
def prepare_compare(df): | ||
return df.sort_values(orderby).reset_index(drop=True) | ||
|
||
df_spark = prepare_compare(df_spark.toPandas()) | ||
df_pandas = prepare_compare(df_pandas) | ||
|
||
assert_frame_equal(df_spark, df_pandas, check_like=True) |
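A hypothetical pytest usage of this helper, assuming a `spark` session fixture; the test name, fixture, and column names are illustrative and not part of this commit:

import pandas as pd


def test_wrangler_output(spark):
    df_expected = pd.DataFrame({"order": [1, 2, 3], "iid": [1, 1, 1]})

    # stand-in for an actual wrangler result
    df_result = spark.createDataFrame(df_expected)

    # passes even if spark returns rows or columns in a different order
    assert_spark_pandas_equality(df_result, df_expected, orderby=["order"])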
@@ -0,0 +1,14 @@
#!/bin/bash

# Different pandas/pyspark/dask versions are tested separately to avoid
# running irrelevant tests. For example, no spark tests need to be run
# when pandas wranglers are tested against older pandas versions.

# However, code coverage drops due to the many skipped tests. Therefore,
# there is a master version (marked via an environment variable) which
# includes all tests for pandas/pyspark/dask on the newest available
# versions and which is subject to code coverage. Non-master versions are
# not included in code coverage.

if [[ $ENV_STRING == *"master"* ]]; then
    coveralls --verbose
fi