From 5ff2ddd7f87a2705e247b6cc3b349583c7c75dbc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20W=C3=B6llert?=
Date: Mon, 15 Jul 2019 14:59:07 +0200
Subject: [PATCH] Add `wrangler_to_spark_transformer`. Refactor and simplify
 `func_to_spark_transformer`. Restrict pipeline stages to be of type
 `Transformer`, `PySparkWrangler` or callable. Add type annotations and doc
 strings.

---
 src/pywrangler/pyspark/pipeline.py | 260 +++++++++++++++++++----------
 1 file changed, 170 insertions(+), 90 deletions(-)

diff --git a/src/pywrangler/pyspark/pipeline.py b/src/pywrangler/pyspark/pipeline.py
index cfdff54..218d3ae 100644
--- a/src/pywrangler/pyspark/pipeline.py
+++ b/src/pywrangler/pyspark/pipeline.py
@@ -5,21 +5,24 @@
 import inspect
 import re
 from collections import OrderedDict
+import copy
 
 import pandas as pd
 from pyspark.ml import PipelineModel, Transformer
 from pyspark.ml.param.shared import Param, Params
 from pyspark.sql import DataFrame
 
+from pywrangler.pyspark.base import PySparkWrangler
 from pywrangler.util._pprint import pretty_time_duration as fmt_time
 from pywrangler.util._pprint import textwrap_docstring, truncate
 
-from typing import Dict, Callable, Optional, Sequence, Union, Any
+from typing import Dict, Callable, Optional, Sequence, Union, Any, cast
+from collections.abc import KeysView
 
-TYPE_STAGE = Union[Transformer, Any]
-
-REGEX_STAGE = re.compile(r".*?\*\((\d+)\).*")
+# types
+TYPE_PARAM_DICT = Dict[str, Union[Callable, Param]]
 
+# printing templates
 DESC_H = "| ({idx}) - {identifier}, {cols} columns, stage {stage}, {cached}"
 DESC_B = "| {text:76} |"
 DESC_L = "+" + "-" * 78 + "+"
@@ -34,12 +37,24 @@
 
 ERR_TYPE_ACCESS = "Value has incorrect type '{}' (integer or string allowed)."
 
+REGEX_STAGE = re.compile(r".*?\*\((\d+)\).*")
+
 
 def _create_getter_setter(name: str) -> Dict[str, Callable]:
     """Helper function to create getter and setter methods for parameters of
     `Transformer` class for given parameter name.
 
+    Parameters
+    ----------
+    name: str
+        The name of the parameter.
+
+    Returns
+    -------
+    param_dict: Dict[str, Callable]
+        Dictionary containing the getter/setter methods for a single parameter.
+
     """
 
     def setter(self, value):
@@ -52,28 +67,22 @@ def getter(self):
             "set{name}".format(name=name): setter}
 
 
-def func_to_spark_transformer(func: Callable):
-    """Convert a native python function into a pyspark `Transformer`
-    instance.
+def _create_param_dict(parameters: KeysView) -> TYPE_PARAM_DICT:
+    """Create getter/setter methods and Param attributes for given parameters
+    to comply with the `pyspark.ml.Transformer` API.
 
-    Temporarely creates a new sublcass of type `Transformer` during
-    runtime while ensuring that all keyword arguments of the input
-    function are mapped to corresponding `Param` values with
-    required getter and setter methods for the resulting `Transformer`
-    class.
+    Parameters
+    ----------
+    parameters: KeysView
+        Contains the names of the parameters.
 
-    Returns an instance of the temporarely create `Transformer` subclass.
+    Returns
+    -------
+    param_dict: Dict[str, Union[Callable, Param]]
+        Dictionary containing the parameter setter/getter and Param attributes.
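+
+    Examples
+    --------
+    A minimal, hypothetical sketch; the parameter name `threshold` is
+    illustrative only::
+
+        param_dict = _create_param_dict({"threshold": 0.5}.keys())
+        # -> contains the "setParams"/"getParams" callables, a `Param`
+        #    attribute under "threshold" and its generated getter/setter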
""" - class_name = func.__name__ - class_bases = (Transformer,) - class_dict = {} - - # overwrite transform method while taking care of kwargs - def _transform(self, df): - return func(df, **self.getParams()) - def setParams(self, **kwargs): return self._set(**kwargs) @@ -82,30 +91,132 @@ def getParams(self): kwargs = {key.name: value for key, value in params} return kwargs - class_dict["_transform"] = _transform - class_dict["setParams"] = setParams - class_dict["getParams"] = getParams - class_dict["__doc__"] = func.__doc__ - - # get keyword arguments - signature = inspect.signature(func) - parameters = signature.parameters.values() - parameters = {x.name: x.default for x in parameters - if not x.default == inspect._empty} + param_dict = {"setParams": setParams, + "getParams": getParams} # create setter/getter and Param instances - for parameter in parameters.keys(): - class_dict.update(_create_getter_setter(parameter)) - class_dict[parameter] = Param(Params._dummy(), parameter, "") + for parameter in parameters: + param_dict.update(_create_getter_setter(parameter)) + param_dict[parameter] = Param(Params._dummy(), parameter, "") + + return param_dict - # create class - transformer_class = type(class_name, class_bases, class_dict) + +def _instantiate_transformer(cls_name: str, + cls_dict: Dict[str, Any], + cls_params: Dict[str, Any]) -> Transformer: + """Create subclass of `pyspark.ml.Transformer` during runtime with name + `cls_name` and methods/attributes `cls_dict`. Create instance of it and + configure it with given parameters `cls_params`. + + Parameters + ---------- + cls_name: str + Name of the class. + cls_dict: Dict[str, Any] + All methods/attributes of the class. + cls_params: Dict[str, Any] + All parameters to be set for a new instance of this class. + + Returns + ------- + transformer_instance: Transformer + + """ + + transformer_class = type(cls_name, (Transformer,), cls_dict) transformer_instance = transformer_class() - transformer_instance._set(**parameters) + transformer_instance._set(**cls_params) return transformer_instance +def wrangler_to_spark_transformer(wrangler: PySparkWrangler) -> Transformer: + """Convert a `PySparkWrangler` into a pyspark `Transformer`. + + Creates a deep copy of the original wrangler instance to leave it + unchanged. The original API is lost and the pyspark `Transformer` API + applies. + + Temporarily creates a new sublcass of type `Transformer` during + runtime while ensuring that all keyword arguments of the wrangler + are mapped to corresponding `Param` values with required getter and setter + methods for the resulting `Transformer` class. + + Returns an instance of the temporarily create `Transformer` subclass. + + Parameters + ---------- + wrangler: PySparkWrangler + Instance of a pyspark wrangler. + + Returns + ------- + transformer: pyspark.ml.Transformer + + """ + + def _transform(self, df): + self._wrangler.set_params(**self.getParams()) + return self._wrangler.transform(df) + + cls_name = wrangler.__class__.__name__ + cls_dict = {"_wranlger": copy.deepcopy(wrangler), + "_transform": _transform, + "__doc__": wrangler.__doc__} + + # get parameters + params = wrangler.get_params() + params_dict = _create_param_dict(params.keys()) + cls_dict.update(params_dict) + + return _instantiate_transformer(cls_name, cls_dict, params) + + +def func_to_spark_transformer(func: Callable) -> Transformer: + """Convert a native python function into a pyspark `Transformer` + instance. 
+def func_to_spark_transformer(func: Callable) -> Transformer:
+    """Convert a native python function into a pyspark `Transformer`
+    instance. Expects the first parameter to be positional, representing
+    the input dataframe.
+
+    Temporarily creates a new subclass of type `Transformer` during
+    runtime while ensuring that all keyword arguments of the input
+    function are mapped to corresponding `Param` values with
+    required getter and setter methods for the resulting `Transformer`
+    class.
+
+    Returns an instance of the temporarily created `Transformer` subclass.
+
+    Parameters
+    ----------
+    func: Callable
+        Native python function.
+
+    Returns
+    -------
+    transformer: pyspark.ml.Transformer
+
+    """
+
+    def _transform(self, df):
+        return self._func(df, **self.getParams())
+
+    cls_name = func.__name__
+    cls_dict = {"_func": func,
+                "_transform": _transform,
+                "__doc__": func.__doc__}
+
+    # get parameters
+    signature = inspect.signature(func)
+    params = signature.parameters.values()
+    params = {x.name: x.default for x in params
+              if x.default is not inspect.Parameter.empty}
+
+    params_dict = _create_param_dict(params.keys())
+    cls_dict.update(params_dict)
+
+    return _instantiate_transformer(cls_name, cls_dict, params)
+
+
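+# Illustrative usage sketch (not part of this module): any function whose
+# first positional argument is the input dataframe converts directly;
+# `add_one` is a hypothetical example:
+#
+#     def add_one(df, column="value"):
+#         return df.withColumn(column, df[column] + 1)
+#
+#     transformer = func_to_spark_transformer(add_one)
+#     transformer.getParams()  # -> {"column": "value"}
+
+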
 class Pipeline(PipelineModel):
     """Represents a compiled pipeline with transformers and fitted models.
 
@@ -141,7 +252,7 @@ class Pipeline(PipelineModel):
 
     """
 
-    def __init__(self, stages: Sequence, doc: Optional[str]=None):
+    def __init__(self, stages: Sequence, doc: Optional[str] = None):
         """Instantiate pipeline. Convert functions into `Transformer`
         instances if necessary.
 
@@ -163,8 +274,7 @@ def __init__(self, stages: Sequence, doc: Optional[str]=None):
         self._stage_profiles = OrderedDict()
 
         for stage in self.stages:
-            identifier = self._create_stage_identifier(stage)
-            self._stage_mapping[identifier] = stage
+            self._stage_mapping[stage.uid] = stage
 
         # overwrite class doc string if pipe doc is explicitly provided
         if doc:
@@ -238,7 +348,7 @@ def describe(self) -> None:
         # print description
         print(DESC_L, header, DESC_L, *docs, sep="\n")
 
-    def profile(self, verbose: bool=False) -> None:
+    def profile(self, verbose: bool = False) -> None:
         """Profiles each stage of the pipeline with and without caching
         enabled. Total, partial and cache times are reported. Partial time is
         computed as the current total time minus the previous total time. If
@@ -294,7 +404,7 @@ def _unpersist_dataframes(self, verbose=None):
 
             df.unpersist(blocking=True)
 
-    def _profile_without_caching(self, verbose: bool=None) -> None:
+    def _profile_without_caching(self, verbose: bool = None) -> None:
         """Make profile without caching enabled. Store total and partial
         time, counts and execution plan stage.
 
@@ -330,7 +440,7 @@ def _profile_without_caching(self, verbose: bool=None) -> None:
 
             temp_total_time = prof["total_time"]
 
-    def _profile_with_caching(self, verbose: bool=None) -> None:
+    def _profile_with_caching(self, verbose: bool = None) -> None:
         """Make profile with caching enabled for each stage. The input
         dataframe will not be cached.
 
@@ -402,7 +512,7 @@ def _profile_report(self) -> None:
         lines.append(PROF_L)
         print(*lines, sep="\n")
 
-    def __getitem__(self, value: Union[str, int]) -> TYPE_STAGE:
+    def __getitem__(self, value: Union[str, int]) -> Transformer:
         """Get stage by index location or label access.
 
         Index location requires integer value. Label access requires string
@@ -467,7 +577,7 @@ def _is_transformed(self) -> None:
                 "`df` needs to be supplied.")
 
     @staticmethod
-    def _is_cached(stage: TYPE_STAGE) -> bool:
+    def _is_cached(stage: Transformer) -> bool:
         """Check if given stage has caching enabled or not via `getIsCached`.
 
         Parameters
@@ -511,13 +621,13 @@ def _identify_stage(self, identifier: str) -> str:
             raise ValueError(
                 "Stage with identifier `{identifier}` not found. "
                 "Possible identifiers are {options}."
-            .format(identifier=identifier,
-                    options=self._stage_mapping.keys()))
+                .format(identifier=identifier,
+                        options=self._stage_mapping.keys()))
 
         if len(stages) > 1:
             raise ValueError(
                 "Identifier is ambiguous. More than one stage identified: {}"
-            .format(stages))
+                .format(stages))
 
         return stages[0]
 
@@ -583,35 +693,11 @@ def _profile_stage(self, df: DataFrame) -> Dict[str, Union[str, float]]:
                 "stage": stage}
 
     @staticmethod
-    def _create_stage_identifier(stage: TYPE_STAGE) -> str:
-        """Given different types of stages, create a unique identifier for
-        each stage. Valid pyspark `Transformer` have an uid. Other objects
-        will use class name and id.
-
-        Parameters
-        ----------
-        stage: pyspark.ml.Transformer
-            A stage for which a uniqe identifier is returned.
-
-        Returns
-        -------
-        identifier: str
-
-        """
-
-        try:
-            return stage.uid
-        except AttributeError:
-            if inspect.isclass(stage):
-                return "{}_{}".format(stage.__name__, id(stage))
-            else:
-                return "{}_{}".format(stage.__class__.__name__, id(stage))
-
-    @staticmethod
-    def _check_convert_transformer(stage: Any) -> TYPE_STAGE:
-        """Ensure given stage is suitable for pipeline usage while checking
-        for `transform` attribute. If not and stage is a function, convert
-        into `Transformer` instance.
+    def _check_convert_transformer(stage: Any) -> Transformer:
+        """Ensure given stage is suitable for pipeline usage while allowing
+        only instances of type `Transformer`, `PySparkWrangler` and native
+        python functions. Objects which are not already of type `Transformer`
+        will be converted into one.
 
         Parameters
         ----------
@@ -623,23 +709,17 @@ def _check_convert_transformer(stage: Any) -> TYPE_STAGE:
         converted: pyspark.ml.Transformer
             Object with a `transform` method.
 
-        ToDo: Add conversion for pywrangler.pyspark instances
-        ToDo: Allow only Transformer, Wrangler, functions
-
         """
 
-        if hasattr(stage, "transform"):
-            if callable(stage.transform):
-                return stage
-            else:
-                raise ValueError(
-                    "Transform method of stage {} is not callable."
-                    .format(stage))
-
+        if isinstance(stage, Transformer):
+            return stage
+        elif isinstance(stage, PySparkWrangler):
+            return wrangler_to_spark_transformer(stage)
         elif inspect.isfunction(stage):
            return func_to_spark_transformer(stage)
         else:
             raise ValueError(
-                "Stage '{}' needs to implement `transform` method or "
-                "has to be a function.".format(stage))
+                "Stage needs to be a `Transformer`, `PySparkWrangler` "
+                "or a native python function. However, '{}' was given."
+                .format(type(stage)))
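
A minimal usage sketch of the API after this patch; the names `MyWrangler`,
`add_one` and `df` are illustrative and not part of the change:

    pipe = Pipeline(stages=[MyWrangler(), add_one])  # both auto-converted
    pipe[0]                      # stage access by index location
    pipe["add_one"]              # stage access by uid label
    result = pipe.transform(df)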