
Commit 31b4133

Merge 1e3bbc9 into 0cddd33
mansenfranzen committed Jun 22, 2019
2 parents 0cddd33 + 1e3bbc9 commit 31b4133
Showing 14 changed files with 554 additions and 32 deletions.
49 changes: 25 additions & 24 deletions .travis.yml
@@ -8,30 +8,31 @@ python:
- '3.7'

env:
- ENV_STRING=master
- ENV_STRING=pandas0.24.1
- ENV_STRING=pandas0.24.0

- ENV_STRING=pandas0.23.4
- ENV_STRING=pandas0.23.3
- ENV_STRING=pandas0.23.2
- ENV_STRING=pandas0.23.1
- ENV_STRING=pandas0.23.0

- ENV_STRING=pandas0.22.0

- ENV_STRING=pandas0.21.1
- ENV_STRING=pandas0.21.0

- ENV_STRING=pandas0.20.3
- ENV_STRING=pandas0.20.2
- ENV_STRING=pandas0.20.1
- ENV_STRING=pandas0.20.0

- ENV_STRING=pandas0.19.2
- ENV_STRING=pandas0.19.1
- ENV_STRING=pandas0.19.0

- ENV_STRING=pyspark2.4.0
# - ENV_STRING=pandas0.24.0
#
# - ENV_STRING=pandas0.23.4
# - ENV_STRING=pandas0.23.3
# - ENV_STRING=pandas0.23.2
# - ENV_STRING=pandas0.23.1
# - ENV_STRING=pandas0.23.0
#
# - ENV_STRING=pandas0.22.0
#
# - ENV_STRING=pandas0.21.1
# - ENV_STRING=pandas0.21.0
#
# - ENV_STRING=pandas0.20.3
# - ENV_STRING=pandas0.20.2
# - ENV_STRING=pandas0.20.1
# - ENV_STRING=pandas0.20.0
#
# - ENV_STRING=pandas0.19.2
# - ENV_STRING=pandas0.19.1
# - ENV_STRING=pandas0.19.0
#
# - ENV_STRING=pyspark2.4.0
- ENV_STRING=pyspark2.3.1

- ENV_STRING=dask1.1.5
@@ -77,6 +78,6 @@ script:
- tox -e $(echo py$TRAVIS_PYTHON_VERSION-$ENV_STRING | tr -d .)

after_success:
- coveralls --verbose
- source tests/travis_coveralls_master.sh

cache: pip
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -7,6 +7,7 @@ Version 0.1.0

This is the initial release of pywrangler.

- Add ``VectorizedCumSum`` pyspark implementation for ``IntervalIdentifier`` wrangler (`#7 <https://github.com/mansenfranzen/pywrangler/pull/7>`_).
- Add benchmark utilities for pandas, spark and dask wranglers (`#5 <https://github.com/mansenfranzen/pywrangler/pull/5>`_).
- Add sequential ``NaiveIterator`` and vectorized ``VectorizedCumSum`` pandas implementations for ``IntervalIdentifier`` wrangler (`#2 <https://github.com/mansenfranzen/pywrangler/pull/2>`_).
- Add ``PandasWrangler`` (`#2 <https://github.com/mansenfranzen/pywrangler/pull/2>`_).
2 changes: 1 addition & 1 deletion src/pywrangler/wranglers/pandas/interval_identifier.py
@@ -172,7 +172,7 @@ class VectorizedCumSum(_BaseIntervalIdentifier):
def _transform(self, series: pd.Series) -> List[int]:
"""First, get enumeration of all intervals (valid and invalid). Every
time a start or end marker is encountered, increase interval id by one.
However, shift the end marker by one to include the end marker in the
The end marker is shifted by one to include the end marker in the
current interval. This is realized via the cumulative sum of boolean
series of start markers and shifted end markers.
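
To illustrate the cumulative-sum idea described in the docstring above, here is a minimal pandas sketch on a toy series; the marker values "start", "end" and "noise" are hypothetical and not taken from the wrangler's actual configuration:

import pandas as pd

# toy marker series: "start" opens an interval, "end" closes it, "noise" is neither
series = pd.Series(["start", "noise", "end", "noise", "start", "end"])

bool_start = series.eq("start").astype(int)

# shift the end markers by one so that the end row still belongs to its interval
bool_end_shift = series.eq("end").shift(1).fillna(False).astype(int)

# every start or shifted end marker increases the raw interval id by one
raw_ids = (bool_start + bool_end_shift).cumsum()
print(raw_ids.tolist())  # [1, 1, 1, 2, 3, 3]

Ids 1 and 3 contain both a start and an end marker and are therefore valid intervals; id 2 contains neither and would be set to 0 in the subsequent steps.
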
51 changes: 49 additions & 2 deletions src/pywrangler/wranglers/spark/base.py
@@ -2,10 +2,18 @@
"""

from typing import Iterable, Union

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.column import Column

from pywrangler.util.sanitizer import ensure_tuple
from pywrangler.util.types import TYPE_ASCENDING, TYPE_COLUMNS
from pywrangler.wranglers.base import BaseWrangler

TYPE_OPT_COLUMN = Union[None, Iterable[Column]]


class SparkWrangler(BaseWrangler):
"""Contains methods common to all spark based wranglers.
@@ -16,6 +24,45 @@ class SparkWrangler(BaseWrangler):
def computation_engine(self):
return "spark"

@staticmethod
def validate_columns(df: DataFrame, columns: TYPE_COLUMNS):
"""Check that columns exist in dataframe and raise error if otherwise.
Parameters
----------
df: pyspark.sql.DataFrame
Dataframe to check against.
columns: Tuple[str]
Columns to be validated.
"""

if not columns:
return

columns = ensure_tuple(columns)

for column in columns:
if column not in df.columns:
raise ValueError('Column with name `{}` does not exist. '
'Please check parameter settings.'
.format(column))

@staticmethod
def prepare_orderby(order_columns: TYPE_COLUMNS,
ascending: TYPE_ASCENDING) -> TYPE_OPT_COLUMN:
"""Convenient function to return orderby columns in correct
ascending/descending order.
"""

if order_columns is None:
return []

zipped = zip(order_columns, ascending)
return [column if ascending else F.desc(column)
for column, ascending in zipped]


class SparkSingleNoFit(SparkWrangler):
"""Mixin class defining `fit` and `fit_transform` for all wranglers with
@@ -31,7 +78,7 @@ def fit(self, df: DataFrame):
Parameters
----------
df: pd.DataFrame
df: pyspark.sql.DataFrame
"""

@@ -46,7 +93,7 @@ def fit_transform(self, df: DataFrame) -> DataFrame:
Returns
-------
result: pd.DataFrame
result: pyspark.sql.DataFrame
"""

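
For illustration, a short sketch of how the prepare_orderby helper added above could feed a window specification; the column names are hypothetical:

from pyspark.sql import Window

from pywrangler.wranglers.spark.base import SparkWrangler

# hypothetical columns: sort by "timestamp" ascending and "priority" descending
orderby = SparkWrangler.prepare_orderby(("timestamp", "priority"), (True, False))
# -> ["timestamp", Column<'priority DESC'>]

# the resulting list can be passed directly to orderBy, as done in transform()
w = Window.partitionBy("user_id").orderBy(orderby)
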
100 changes: 100 additions & 0 deletions src/pywrangler/wranglers/spark/interval_identifier.py
@@ -0,0 +1,100 @@
"""This module contains implementations of the interval identifier wrangler.
"""

from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as F

from pywrangler.wranglers.interfaces import IntervalIdentifier
from pywrangler.wranglers.spark.base import SparkSingleNoFit


class VectorizedCumSum(SparkSingleNoFit, IntervalIdentifier):
"""Sophisticated approach avoiding python UDFs. However multiple windows
are necessary.
First, get enumeration of all intervals (valid and invalid). Every
time a start or end marker is encountered, increase interval id by one.
The end marker is shifted by one to include the end marker in the
current interval. This is realized via the cumulative sum of boolean
series of start markers and shifted end markers.
Second, separate valid from invalid intervals by ensuring the presence
of both start and end markers per interval id.
Third, numerate valid intervals starting with 1 and set invalid
intervals to 0.
"""

def validate_input(self, df: DataFrame):
"""Checks input data frame in regard to column names.
Parameters
----------
df: pyspark.sql.Dataframe
Dataframe to be validated.
"""

self.validate_columns(df, self.marker_column)
self.validate_columns(df, self.order_columns)
self.validate_columns(df, self.groupby_columns)

if self.order_columns is None:
raise ValueError("Please define an order column. Pyspark "
"dataframes have no implicit order unlike pandas "
"dataframes.")

def transform(self, df: DataFrame) -> DataFrame:
"""Extract interval ids from given dataframe.
Parameters
----------
df: pyspark.sql.Dataframe
Returns
-------
result: pyspark.sql.Dataframe
Same columns as original dataframe plus the new interval id column.
"""

# check input
self.validate_input(df)

# define window specs
orderby = self.prepare_orderby(self.order_columns, self.ascending)
groupby = self.groupby_columns or []

w_lag = Window.partitionBy(list(groupby)).orderBy(orderby)
w_id = Window.partitionBy(list(groupby) + [self.target_column_name])

# get boolean series with start and end markers
marker_col = F.col(self.marker_column)
bool_start = (marker_col == self.marker_start).cast("integer")
bool_end = (marker_col == self.marker_end).cast("integer")
bool_start_end = bool_start + bool_end

# shifting the close marker allows cumulative sum to include the end
bool_end_shift = F.lag(bool_end, default=1).over(w_lag).cast("integer")
bool_start_end_shift = bool_start + bool_end_shift

# get increasing ids for intervals (in/valid) with cumsum
ser_id = F.sum(bool_start_end_shift).over(w_lag)

# separate valid vs invalid: ids with start AND end marker are valid
bool_valid = F.sum(bool_start_end).over(w_id) == 2
valid_ids = F.when(bool_valid, ser_id).otherwise(0)

# re-numerate ids from 1 to x and fill invalid with 0
valid_ids_shift = F.lag(valid_ids, default=0).over(w_lag)
valid_ids_diff = valid_ids_shift - valid_ids
valid_ids_increase = (valid_ids_diff < 0).cast("integer")

renumerate = F.sum(valid_ids_increase).over(w_lag)
renumerate_adjusted = F.when(bool_valid, renumerate).otherwise(0)

# ser_id needs be created temporarily for renumerate_adjusted
return df.withColumn(self.target_column_name, ser_id) \
.withColumn(self.target_column_name, renumerate_adjusted)
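
A hedged usage sketch of the new spark wrangler; the constructor keyword names are inferred from the attributes referenced in transform and may differ from the actual IntervalIdentifier signature:

import pandas as pd
from pyspark.sql import SparkSession

from pywrangler.wranglers.spark.interval_identifier import VectorizedCumSum

spark = SparkSession.builder.getOrCreate()

# toy data: marker 1 opens an interval, marker 2 closes it, 0 is noise
df = spark.createDataFrame(pd.DataFrame({
    "order": [1, 2, 3, 4, 5, 6],
    "marker": [1, 0, 2, 0, 1, 2]
}))

# assumption: parameter names mirror the attributes used in transform above
wrangler = VectorizedCumSum(marker_column="marker",
                            marker_start=1,
                            marker_end=2,
                            order_columns=("order",),
                            ascending=(True,))

wrangler.transform(df).show()
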
54 changes: 54 additions & 0 deletions src/pywrangler/wranglers/spark/testing.py
@@ -0,0 +1,54 @@
"""This module contains helper functions for testing.
"""

import pandas as pd
from pyspark.sql import DataFrame

from pywrangler.util.types import TYPE_COLUMNS

try:
from pandas.testing import assert_frame_equal
except ImportError:
from pandas.util.testing import assert_frame_equal


def assert_spark_pandas_equality(df_spark: DataFrame,
df_pandas: pd.DataFrame,
orderby: TYPE_COLUMNS = None):
"""Compare a spark and pandas dataframe in regard to content equality.
Spark dataframes do not have a specific index or column order due to their
distributed nature. In contrast, a test for equality for pandas dataframes
respects index and column order. Therefore, the test for equality between a
spark and pandas dataframe will ignore index and column order on purpose.
Testing spark dataframes content is most simple while converting to pandas
dataframes and having test data as pandas dataframes, too.
To ensure index order is ignored, both dataframes need be sorted by all or
given columns `orderby`.
Parameters
----------
df_spark: pyspark.sql.DataFrame
Spark dataframe to be tested for equality.
df_pandas: pd.DataFrame
Pandas dataframe to be tested for equality.
orderby: iterable, optional
Columns to be sorted for correct index order.
Returns
-------
None but asserts if dataframes are not equal.
"""

orderby = orderby or df_pandas.columns.tolist()

def prepare_compare(df):
return df.sort_values(orderby).reset_index(drop=True)

df_spark = prepare_compare(df_spark.toPandas())
df_pandas = prepare_compare(df_pandas)

assert_frame_equal(df_spark, df_pandas, check_like=True)
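
A brief usage sketch of the new testing helper; the SparkSession creation is illustrative:

import pandas as pd
from pyspark.sql import SparkSession

from pywrangler.wranglers.spark.testing import assert_spark_pandas_equality

spark = SparkSession.builder.getOrCreate()

df_pandas = pd.DataFrame({"col1": [2, 1], "col2": [4, 3]})
df_spark = spark.createDataFrame(df_pandas)

# passes: row order and index are ignored, only content is compared
assert_spark_pandas_equality(df_spark, df_pandas, orderby=["col1"])
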
8 changes: 6 additions & 2 deletions tests/conftest.py
@@ -37,9 +37,13 @@ def pytest_collection_modifyitems(config, items):
"""

for skip_item in ("pyspark", "dask"):
tox_env = os.environ.get("PYWRANGLER_TEST_ENV", "").lower()

# if master version, all tests are run, no skipping required
if "master" in tox_env:
return

tox_env = os.environ.get("PYWRANGLER_TEST_ENV", "").lower()
for skip_item in ("pyspark", "dask"):
run_env = skip_item in tox_env
run_cmd = config.getoption("--{}".format(skip_item))

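
The remainder of the collection hook is not shown in this diff; the following is a hedged sketch of the presumed overall control flow after this change, with the skipping details assumed rather than taken from the repository:

import os

import pytest


def pytest_collection_modifyitems(config, items):
    tox_env = os.environ.get("PYWRANGLER_TEST_ENV", "").lower()

    # if master version, all tests are run, no skipping required
    if "master" in tox_env:
        return

    for skip_item in ("pyspark", "dask"):
        run_env = skip_item in tox_env
        run_cmd = config.getoption("--{}".format(skip_item))

        # assumption: tests carrying the corresponding marker are skipped
        # unless enabled via the tox env or the command line option
        if not (run_env or run_cmd):
            marker = pytest.mark.skip(reason="requires --{}".format(skip_item))
            for item in items:
                if skip_item in item.keywords:
                    item.add_marker(marker)
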
14 changes: 14 additions & 0 deletions tests/travis_coveralls_master.sh
@@ -0,0 +1,14 @@
#!/bin/bash

# Different pandas/pyspark/dask versions are tested separately to avoid
# irrelevant tests to be run. For example, no spark tests need to be run
# when pandas wranglers are tested on older pandas versions.

# However, code coverage drops due to many skipped tests. Therefore, there is a
# master version (marked via env variables) which includes all tests for
# pandas/pyspark/dask for the newest available versions which is subject to
# code coverage. Non master versions will not be included in code coverage.

if [[ $ENV_STRING == *"master"* ]]; then
coveralls --verbose
fi
2 changes: 1 addition & 1 deletion tests/travis_java_install.sh
@@ -5,7 +5,7 @@
# selected as language. Ubuntu 14.04 does not work due to missing python 3.7
# support on TravisCI which does have Java 8 as default.

if [[ $ENV_STRING == *"spark"* ]]; then
if [[ $ENV_STRING == *"spark"* ]] || [[ $ENV_STRING == *"master"* ]]; then
# show current JAVA_HOME and java version
echo "Current JAVA_HOME: $JAVA_HOME"
echo "Current java -version:"
2 changes: 0 additions & 2 deletions tests/wranglers/pandas/test_base.py
@@ -58,8 +58,6 @@ def test_pandas_wrangler_validate_columns_raises():
with pytest.raises(ValueError):
PandasWrangler.validate_columns(df, ("col3", "col1"))

# TODO: Update type hints for varialbe length tuples


def test_pandas_wrangler_validate_columns_not_raises():

18 changes: 18 additions & 0 deletions tests/wranglers/spark/test_base.py
@@ -4,6 +4,7 @@
"""

import pytest
import pandas as pd

pytestmark = pytest.mark.pyspark # noqa: E402
pyspark = pytest.importorskip("pyspark") # noqa: E402
@@ -15,3 +16,20 @@ def test_spark_base_wrangler_engine():
wrangler = SparkWrangler()

assert wrangler.computation_engine == "spark"


def test_spark_wrangler_validate_columns_raises(spark):

data = {"col1": [1, 2], "col2": [3, 4]}
df = spark.createDataFrame(pd.DataFrame(data))

with pytest.raises(ValueError):
SparkWrangler.validate_columns(df, ("col3", "col1"))


def test_spark_wrangler_validate_columns_not_raises(spark):

data = {"col1": [1, 2], "col2": [3, 4]}
df = spark.createDataFrame(pd.DataFrame(data))

SparkWrangler.validate_columns(df, ("col1", "col2"))