Add wrangler_to_spark_transformer. Refactor and simplify
`func_to_spark_transformer`. Restrict pipeline stages to be of type
`Transformer`, `PySparkWrangler` or callable. Add type annotations and
doc strings.
Franz Wöllert committed Jul 15, 2019
1 parent dfe73a7 commit 5ff2ddd
Showing 1 changed file with 170 additions and 90 deletions.
260 changes: 170 additions & 90 deletions src/pywrangler/pyspark/pipeline.py
@@ -5,21 +5,24 @@
import inspect
import re
from collections import OrderedDict
import copy

import pandas as pd
from pyspark.ml import PipelineModel, Transformer
from pyspark.ml.param.shared import Param, Params
from pyspark.sql import DataFrame

from pywrangler.pyspark.base import PySparkWrangler
from pywrangler.util._pprint import pretty_time_duration as fmt_time
from pywrangler.util._pprint import textwrap_docstring, truncate

from typing import Dict, Callable, Optional, Sequence, Union, Any
from typing import Dict, Callable, Optional, Sequence, Union, Any, cast
from collections.abc import KeysView

TYPE_STAGE = Union[Transformer, Any]

REGEX_STAGE = re.compile(r".*?\*\((\d+)\).*")
# types
TYPE_PARAM_DICT = Dict[str, Union[Callable, Param]]

# printing templates
DESC_H = "| ({idx}) - {identifier}, {cols} columns, stage {stage}, {cached}"
DESC_B = "| {text:76} |"
DESC_L = "+" + "-" * 78 + "+"
@@ -34,12 +37,24 @@

ERR_TYPE_ACCESS = "Value has incorrect type '{}' (integer or string allowed)."

REGEX_STAGE = re.compile(r".*?\*\((\d+)\).*")


def _create_getter_setter(name: str) -> Dict[str, Callable]:
"""Helper function to create getter and setter methods
for parameters of `Transformer` class for given parameter
name.
Parameters
----------
name: str
The name of the parameter.
Returns
-------
param_dict: Dict[str, Callable]
Dictionary containing the getter/setter methods for single parameter.
"""

def setter(self, value):
@@ -52,28 +67,22 @@ def getter(self):
"set{name}".format(name=name): setter}

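A minimal sketch of how the generated pair plugs into a dynamically created class, assuming the collapsed method bodies delegate to `Params._set` and `Params.getOrDefault` (just as the `setParams`/`getParams` further down do); the parameter name `threshold` is made up for illustration:

from pyspark.ml.param.shared import Param, Params

# Generated method names follow the "get{name}"/"set{name}" pattern.
methods = _create_getter_setter("threshold")

# Attach the methods plus a matching Param attribute to a new class,
# the same shape _create_param_dict produces for each parameter.
attrs = dict(methods)
attrs["threshold"] = Param(Params._dummy(), "threshold", "")
Demo = type("Demo", (Params,), attrs)

demo = Demo()
demo.setthreshold(0.5)             # stores the value in the param map
assert demo.getthreshold() == 0.5  # reads it back via the Param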

def func_to_spark_transformer(func: Callable):
"""Convert a native python function into a pyspark `Transformer`
instance.
def _create_param_dict(parameters: KeysView) -> TYPE_PARAM_DICT:
"""Create getter/setter methods and Param attributes for given parameters
to comply with the `pyspark.ml.Transformer` API.
Temporarily creates a new subclass of type `Transformer` during
runtime while ensuring that all keyword arguments of the input
function are mapped to corresponding `Param` values with
required getter and setter methods for the resulting `Transformer`
class.
Parameters
----------
parameters: KeysView
Contains the names of the parameters.
Returns an instance of the temporarily created `Transformer` subclass.
Returns
-------
param_dict: Dict[str, Union[Callable, Param]]
Dictionary containing the parameter setter/getter and Param attributes.
"""

class_name = func.__name__
class_bases = (Transformer,)
class_dict = {}

# overwrite transform method while taking care of kwargs
def _transform(self, df):
return func(df, **self.getParams())

def setParams(self, **kwargs):
return self._set(**kwargs)

@@ -82,30 +91,132 @@ def getParams(self):
kwargs = {key.name: value for key, value in params}
return kwargs

class_dict["_transform"] = _transform
class_dict["setParams"] = setParams
class_dict["getParams"] = getParams
class_dict["__doc__"] = func.__doc__

# get keyword arguments
signature = inspect.signature(func)
parameters = signature.parameters.values()
parameters = {x.name: x.default for x in parameters
if not x.default == inspect._empty}
param_dict = {"setParams": setParams,
"getParams": getParams}

# create setter/getter and Param instances
for parameter in parameters.keys():
class_dict.update(_create_getter_setter(parameter))
class_dict[parameter] = Param(Params._dummy(), parameter, "")
for parameter in parameters:
param_dict.update(_create_getter_setter(parameter))
param_dict[parameter] = Param(Params._dummy(), parameter, "")

return param_dict

# create class
transformer_class = type(class_name, class_bases, class_dict)

def _instantiate_transformer(cls_name: str,
cls_dict: Dict[str, Any],
cls_params: Dict[str, Any]) -> Transformer:
"""Create subclass of `pyspark.ml.Transformer` during runtime with name
`cls_name` and methods/attributes `cls_dict`. Create instance of it and
configure it with given parameters `cls_params`.
Parameters
----------
cls_name: str
Name of the class.
cls_dict: Dict[str, Any]
All methods/attributes of the class.
cls_params: Dict[str, Any]
All parameters to be set for a new instance of this class.
Returns
-------
transformer_instance: Transformer
"""

transformer_class = type(cls_name, (Transformer,), cls_dict)
transformer_instance = transformer_class()
transformer_instance._set(**parameters)
transformer_instance._set(**cls_params)

return transformer_instance

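For illustration, a stage can be hand-rolled the same way; the name `Limiter` and its `_transform` body are invented for this sketch:

# Build a one-off Transformer subclass at runtime and instantiate it.
cls_dict = {"_transform": lambda self, df: df.limit(10),
            "__doc__": "Keeps only the first ten rows."}
limiter = _instantiate_transformer("Limiter", cls_dict, {})

# limiter behaves like any pyspark.ml.Transformer: limiter.transform(df)
# dispatches to the lambda above, and limiter.uid starts with "Limiter_".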

def wrangler_to_spark_transformer(wrangler: PySparkWrangler) -> Transformer:
"""Convert a `PySparkWrangler` into a pyspark `Transformer`.
Creates a deep copy of the original wrangler instance to leave it
unchanged. The wrangler's own API is hidden; only the pyspark
`Transformer` API applies to the result.
Temporarily creates a new subclass of type `Transformer` during
runtime while ensuring that all keyword arguments of the wrangler
are mapped to corresponding `Param` values with required getter and setter
methods for the resulting `Transformer` class.
Returns an instance of the temporarily created `Transformer` subclass.
Parameters
----------
wrangler: PySparkWrangler
Instance of a pyspark wrangler.
Returns
-------
transformer: pyspark.ml.Transformer
"""

def _transform(self, df):
self._wrangler.set_params(**self.getParams())
return self._wrangler.transform(df)

cls_name = wrangler.__class__.__name__
cls_dict = {"_wranlger": copy.deepcopy(wrangler),
"_transform": _transform,
"__doc__": wrangler.__doc__}

# get parameters
params = wrangler.get_params()
params_dict = _create_param_dict(params.keys())
cls_dict.update(params_dict)

return _instantiate_transformer(cls_name, cls_dict, params)

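A hypothetical usage sketch, assuming `AddColumnWrangler` is a `PySparkWrangler` exposing a single keyword argument `value` via `get_params`/`set_params`:

wrangler = AddColumnWrangler(value=1)  # hypothetical wrangler
stage = wrangler_to_spark_transformer(wrangler)

isinstance(stage, Transformer)  # True; only the Transformer API remains
stage.getParams()               # {'value': 1}
stage.setParams(value=2)        # routed through the generated Param setters
# stage.transform(df) pushes the current params into the deep-copied
# wrangler via set_params() and then calls its transform(df).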

def func_to_spark_transformer(func: Callable) -> Transformer:
"""Convert a native python function into a pyspark `Transformer`
instance. The first parameter is expected to be positional and to
represent the input dataframe.
Temporarily creates a new subclass of type `Transformer` during
runtime while ensuring that all keyword arguments of the input
function are mapped to corresponding `Param` values with
required getter and setter methods for the resulting `Transformer`
class.
Returns an instance of the temporarily created `Transformer` subclass.
Parameters
----------
func: Callable
Native python function.
Returns
-------
transformer: pyspark.ml.Transformer
"""

def _transform(self, df):
return self._func(df, **self.getParams())

cls_name = func.__name__
cls_dict = {"_func": func,
"_transform": _transform,
"__doc__": func.__doc__}

# get parameters
signature = inspect.signature(func)
params = signature.parameters.values()
params = {x.name: x.default for x in params
if x.default is not inspect.Parameter.empty}

params_dict = _create_param_dict(params.keys())
cls_dict.update(params_dict)

return _instantiate_transformer(cls_name, cls_dict, params)

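A usage sketch with a made-up function; only keyword arguments with defaults become `Param` values, while the first positional argument is reserved for the input dataframe:

def filter_scores(df, threshold=0.5):
    """Keep rows whose `score` column exceeds `threshold`."""
    return df.filter(df["score"] > threshold)

stage = func_to_spark_transformer(filter_scores)
stage.getParams()               # {'threshold': 0.5}
stage.setParams(threshold=0.9)
# stage.transform(df) is now equivalent to filter_scores(df, threshold=0.9),
# and stage.uid starts with "filter_scores_".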

class Pipeline(PipelineModel):
"""Represents a compiled pipeline with transformers and fitted models.
@@ -141,7 +252,7 @@ class Pipeline(PipelineModel):
"""

def __init__(self, stages: Sequence, doc: Optional[str]=None):
def __init__(self, stages: Sequence, doc: Optional[str] = None):
"""Instantiate pipeline. Convert functions into `Transformer`
instances if necessary.
@@ -163,8 +274,7 @@ def __init__(self, stages: Sequence, doc: Optional[str]=None):
self._stage_profiles = OrderedDict()

for stage in self.stages:
identifier = self._create_stage_identifier(stage)
self._stage_mapping[identifier] = stage
self._stage_mapping[stage.uid] = stage

# overwrite class doc string if pipe doc is explicitly provided
if doc:
@@ -238,7 +348,7 @@ def describe(self) -> None:
# print description
print(DESC_L, header, DESC_L, *docs, sep="\n")

def profile(self, verbose: bool=False) -> None:
def profile(self, verbose: bool = False) -> None:
"""Profiles each stage of the pipeline with and without caching
enabled. Total, partial and cache times are reported. Partial time is
computed as the current total time minus the previous total time. If
@@ -294,7 +404,7 @@ def _unpersist_dataframes(self, verbose=None):

df.unpersist(blocking=True)

def _profile_without_caching(self, verbose: bool=None) -> None:
def _profile_without_caching(self, verbose: bool = None) -> None:
"""Make profile without caching enabled. Store total and partial time,
counts and execution plan stage.
@@ -330,7 +440,7 @@ def _profile_without_caching(self, verbose: bool=None) -> None:

temp_total_time = prof["total_time"]

def _profile_with_caching(self, verbose: bool=None) -> None:
def _profile_with_caching(self, verbose: bool = None) -> None:
"""Make profile with caching enabled for each stage. The input
dataframe will not be cached.
@@ -402,7 +512,7 @@ def _profile_report(self) -> None:
lines.append(PROF_L)
print(*lines, sep="\n")

def __getitem__(self, value: Union[str, int]) -> TYPE_STAGE:
def __getitem__(self, value: Union[str, int]) -> Transformer:
"""Get stage by index location or label access.
Index location requires integer value. Label access requires string
@@ -467,7 +577,7 @@ def _is_transformed(self) -> None:
"`df` needs to be supplied.")

@staticmethod
def _is_cached(stage: TYPE_STAGE) -> bool:
def _is_cached(stage: Transformer) -> bool:
"""Check if given stage has caching enabled or not via `getIsCached`.
Parameters
@@ -511,13 +621,13 @@ def _identify_stage(self, identifier: str) -> str:
raise ValueError(
"Stage with identifier `{identifier}` not found. "
"Possible identifiers are {options}."
.format(identifier=identifier,
options=self._stage_mapping.keys()))
.format(identifier=identifier,
options=self._stage_mapping.keys()))

if len(stages) > 1:
raise ValueError(
"Identifier is ambiguous. More than one stage identified: {}"
.format(stages))
.format(stages))

return stages[0]

@@ -583,35 +693,11 @@ def _profile_stage(self, df: DataFrame) -> Dict[str, Union[str, float]]:
"stage": stage}

@staticmethod
def _create_stage_identifier(stage: TYPE_STAGE) -> str:
"""Given different types of stages, create a unique identifier for
each stage. Valid pyspark `Transformer`s have a uid. Other objects
will use class name and id.
Parameters
----------
stage: pyspark.ml.Transformer
A stage for which a unique identifier is returned.
Returns
-------
identifier: str
"""

try:
return stage.uid
except AttributeError:
if inspect.isclass(stage):
return "{}_{}".format(stage.__name__, id(stage))
else:
return "{}_{}".format(stage.__class__.__name__, id(stage))

@staticmethod
def _check_convert_transformer(stage: Any) -> TYPE_STAGE:
"""Ensure given stage is suitable for pipeline usage while checking
for `transform` attribute. If not and stage is a function, convert
into `Transformer` instance.
def _check_convert_transformer(stage: Any) -> Transformer:
"""Ensure given stage is suitable for pipeline usage while allowing
only instances of type `Transformer`, `Wrangler` and native python
functions. Objects which are not of type `Transformer` will be
converted into it.
Parameters
----------
@@ -623,23 +709,17 @@ def _check_convert_transformer(stage: Any) -> TYPE_STAGE:
converted: pyspark.ml.Transformer
Object with a `transform` method.
ToDo: Add conversion for pywrangler.pyspark instances
ToDo: Allow only Transformer, Wrangler, functions
"""

if hasattr(stage, "transform"):
if callable(stage.transform):
return stage
else:
raise ValueError(
"Transform method of stage {} is not callable."
.format(stage))

if isinstance(stage, Transformer):
return stage
elif isinstance(stage, PySparkWrangler):
return wrangler_to_spark_transformer(stage)
elif inspect.isfunction(stage):
return func_to_spark_transformer(stage)

else:
raise ValueError(
"Stage '{}' needs to implement `transform` method or "
"has to be a function.".format(stage))
"Stage needs to be a `Transformer`, `PySparkWrangler` "
"or a native python function. However, '{}' was given."
.format(type(stage)))

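Putting the commit together, an illustrative sketch of the three accepted stage types; `add_one` and `MyWrangler` are hypothetical, and label access assumes `_identify_stage` matches on a uid prefix, as its error messages suggest:

from pyspark.ml.feature import SQLTransformer

def add_one(df, column="value"):
    return df.withColumn(column, df[column] + 1)

pipe = Pipeline(stages=[
    add_one,                                             # native function
    MyWrangler(),                                        # PySparkWrangler
    SQLTransformer(statement="SELECT * FROM __THIS__"),  # Transformer
])

pipe[0]          # stage access by index location
pipe["add_one"]  # stage access by label (uid of the converted function)
# Anything else, e.g. Pipeline(stages=[42]), raises the ValueError above.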