In [0]:
from pyspark.ml import Transformer
from pyspark.sql.functions import when

In [0]:
from pyspark.ml import Transformer
from pyspark.sql.functions import when

class Winsorizer(Transformer):
    """
    Transformer that caps values in each of the specified inputCols at given quantiles.

    For every ``(inputCol, outputCol)`` pair, the ``lowerQuantile`` and
    ``upperQuantile`` of ``inputCol`` are computed with
    ``DataFrame.approxQuantile`` on the dataset passed to ``transform``.
    Values below the lower bound are replaced by the lower bound, values above
    the upper bound are replaced by the upper bound, and every other value
    (including nulls) passes through unchanged.

    NOTE: this is a stateless Transformer, not an Estimator/Model pair. The
    bounds are recomputed on *every* dataset that is transformed. Applying it
    to a test set therefore caps the test data at the test set's own
    quantiles, not at bounds learned from training data.

    :param inputCols: names of the numeric columns to winsorize.
    :param outputCols: names of the columns that receive the capped values.
        Must have the same length as ``inputCols``. An output name may equal
        its input name to overwrite that column.
    :param lowerQuantile: quantile in [0, 1] used as the floor.
    :param upperQuantile: quantile in [0, 1] used as the ceiling. Must be
        >= ``lowerQuantile``.
    :param relativeError: relative target precision forwarded to
        ``approxQuantile``. The default 0.0 computes exact quantiles, which
        can be very expensive on large data.
    :raises ValueError: if the column lists differ in length, or if the
        quantiles / relativeError are out of range.
    """

    def __init__(self, inputCols, outputCols, lowerQuantile=0.05, upperQuantile=0.95,
                 relativeError=0.0):
        super(Winsorizer, self).__init__()
        inputCols = list(inputCols)
        outputCols = list(outputCols)
        # Raise instead of assert: assertions are stripped under `python -O`.
        if len(inputCols) != len(outputCols):
            raise ValueError("inputCols and outputCols must match lengths")
        if not 0.0 <= lowerQuantile <= upperQuantile <= 1.0:
            # With lower > upper the when/otherwise chain below would silently
            # produce nonsensical output, so reject it up front.
            raise ValueError(
                "quantiles must satisfy 0 <= lowerQuantile <= upperQuantile <= 1, "
                "got lowerQuantile={!r}, upperQuantile={!r}".format(
                    lowerQuantile, upperQuantile
                )
            )
        if relativeError < 0.0:
            raise ValueError(
                "relativeError must be >= 0, got {!r}".format(relativeError)
            )
        self.inputCols = inputCols
        self.outputCols = outputCols
        self.lowerQuantile = lowerQuantile
        self.upperQuantile = upperQuantile
        self.relativeError = relativeError

    def _transform(self, dataset):
        df = dataset
        probabilities = [self.lowerQuantile, self.upperQuantile]
        # Bounds come from the dataset currently being transformed (not from a
        # separate fit step). One approxQuantile job runs per column.
        for inp, out in zip(self.inputCols, self.outputCols):
            bounds = df.approxQuantile(inp, probabilities, self.relativeError)
            if not bounds:
                # approxQuantile returns [] when the column has no non-null
                # values (all-null column or empty dataset). There is nothing
                # to cap, so copy the column through unchanged.
                df = df.withColumn(out, df[inp])
                continue
            low, high = bounds
            col = df[inp]
            # Nulls fail both comparisons and fall through to otherwise().
            df = df.withColumn(
                out,
                when(col < low, low).when(col > high, high).otherwise(col),
            )
        return df