# Setup

In [6]:
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

In [7]:
from pyspark import keyword_only
from pyspark.ml.pipeline import Pipeline, PipelineModel, Transformer
from pyspark.ml.param.shared import *
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

In [8]:
from pyspark.ml.feature import Imputer

In [9]:
# Display the active SparkSession.
# NOTE(review): `spark` (and `sc` used below) are never created in this notebook;
# presumably the PySpark kernel/shell predefines them — confirm before Restart & Run All.
spark

In [10]:
# Toy input: string columns x1/x2 and float columns x3-x5, each with some nulls,
# used to demonstrate the imputers below.
df = (
    sc.parallelize([
        (1, 'a', 'A', 56., 175., 10.),
        (2, 'a', 'B', 66., None, 92.),
        (3, 'b', 'B', None, 182., 876.),
        (4, 'c', None, 71., 171., None),
        (5, None, 'B', 48., 173., None),
    ])
    .toDF(['id', 'x1', 'x2', 'x3', 'x4', 'x5'])
)

In [11]:
# Show the raw input, including its nulls, before any imputation.
df.show()

+---+----+----+----+-----+-----+
| id|  x1|  x2|  x3|   x4|   x5|
+---+----+----+----+-----+-----+
|  1|   a|   A|56.0|175.0| 10.0|
|  2|   a|   B|66.0| null| 92.0|
|  3|   b|   B|null|182.0|876.0|
|  4|   c|null|71.0|171.0| null|
|  5|null|   B|48.0|173.0| null|
+---+----+----+----+-----+-----+



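# Impute Categorical Columns with Mode

Fill nulls in the string columns `x1` and `x2` with each column's most frequent value, writing the results to `x1_im` and `x2_im`. Note: `ImputeCategoricalWithMode` is not defined in the cells shown here; it must be defined or imported before these cells are run.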

In [200]:
# Configure mode imputation for the string columns, writing results to new *_im columns.
# NOTE(review): ImputeCategoricalWithMode is not defined or imported anywhere visible in
# this notebook (hidden kernel state?) — define it before this cell so Run All works.
impute_mode = (
    ImputeCategoricalWithMode()
    .setInputCols(['x1', 'x2'])
    .setOutputCols(['x1_im', 'x2_im'])
)

In [201]:
# Fit on df to learn the fill value for each input column; the returned model is
# applied with .transform() in the next cell.
impute_model = impute_mode.fit(df)

In [202]:
# Apply the fitted model: x1_im/x2_im are x1/x2 with nulls replaced (judging by the
# output, by each column's most frequent value).
# NOTE(review): the saved output lists id 4 twice although df has ids 1-5, so it looks
# stale (from an earlier df) — re-run this cell to refresh it.
impute_model.transform(df).show()

+---+----+----+-----+-----+
| id|  x1|  x2|x1_im|x2_im|
+---+----+----+-----+-----+
|  1|   a|   A|    a|    A|
|  2|   a|   B|    a|    B|
|  3|   b|   B|    b|    B|
|  4|   c|null|    c|    B|
|  4|null|   B|    a|    B|
+---+----+----+-----+-----+



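# Impute Pipeline: Constant-Value Transformer

A custom, pipeline-compatible `Transformer` that fills nulls with a fixed value (a scalar or a per-column dict) using `DataFrame.na.fill`.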

In [16]:
class HasConstValue(Params):
    """Params mixin providing ``const_value``: the value used to fill nulls.

    The value may be a scalar (string or number) or a dict mapping column
    names to per-column fill values.
    """

    const_value = Param(
        Params._dummy(),
        "const_value",
        "string, double or dict",
    )

    def __init__(self):
        super(HasConstValue, self).__init__()

    def setConstValue(self, value):
        """Set ``const_value``; returns this instance so calls can be chained."""
        return self._set(const_value=value)

    def getConstValue(self):
        """Return ``const_value`` (raises if it was never set and has no default)."""
        # Params resolve by name as well as by Param object.
        return self.getOrDefault("const_value")

In [18]:
class ConstantImputer(Transformer, HasConstValue, DefaultParamsReadable, DefaultParamsWritable):
    """Pipeline stage that fills nulls with a constant via ``DataFrame.na.fill``.

    ``const_value`` may be:
      * a scalar -- filled into every column whose type matches it (e.g. ``0``
        fills numeric columns, ``'x'`` fills string columns);
      * a dict of column name -> value -- only the listed columns are filled.

    It can be configured either with ``setConstValue`` or as a keyword argument,
    e.g. ``ConstantImputer(const_value=0)``.
    """

    @keyword_only
    def __init__(self, const_value=None):
        super(ConstantImputer, self).__init__()
        # keyword_only records only the kwargs actually passed, so a bare
        # ConstantImputer() leaves const_value unset (needed for DefaultParamsReadable).
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, const_value=None):
        """Set params by keyword; returns this instance for chaining."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        """Return ``dataset`` with nulls replaced by ``const_value``.

        Raises:
            ValueError: if ``const_value`` was never set.
        """
        # Without this check getOrDefault fails with an opaque KeyError on the Param.
        if not self.isDefined(self.const_value):
            raise ValueError(
                "ConstantImputer: const_value must be set (setConstValue or "
                "ConstantImputer(const_value=...)) before transform()")
        return dataset.na.fill(self.getConstValue())

In [19]:
# Numeric fill value: only the numeric columns (x3-x5) are filled; string nulls stay.
(ConstantImputer()
 .setConstValue(0)
 .transform(df)
 .show())

+---+----+----+----+-----+-----+
| id|  x1|  x2|  x3|   x4|   x5|
+---+----+----+----+-----+-----+
|  1|   a|   A|56.0|175.0| 10.0|
|  2|   a|   B|66.0|  0.0| 92.0|
|  3|   b|   B| 0.0|182.0|876.0|
|  4|   c|null|71.0|171.0|  0.0|
|  5|null|   B|48.0|173.0|  0.0|
+---+----+----+----+-----+-----+



In [20]:
# String fill value: only the string columns (x1, x2) are filled; numeric nulls stay.
(ConstantImputer()
 .setConstValue('x')
 .transform(df)
 .show())

+---+---+---+----+-----+-----+
| id| x1| x2|  x3|   x4|   x5|
+---+---+---+----+-----+-----+
|  1|  a|  A|56.0|175.0| 10.0|
|  2|  a|  B|66.0| null| 92.0|
|  3|  b|  B|null|182.0|876.0|
|  4|  c|  x|71.0|171.0| null|
|  5|  x|  B|48.0|173.0| null|
+---+---+---+----+-----+-----+



In [21]:
# Per-column fill values; x5 is not in the dict, so its nulls are left untouched.
(ConstantImputer()
 .setConstValue({'x1': 'x', 'x2': 'C', 'x3': 0, 'x4': 100})
 .transform(df)
 .show())

+---+---+---+----+-----+-----+
| id| x1| x2|  x3|   x4|   x5|
+---+---+---+----+-----+-----+
|  1|  a|  A|56.0|175.0| 10.0|
|  2|  a|  B|66.0|100.0| 92.0|
|  3|  b|  B| 0.0|182.0|876.0|
|  4|  c|  C|71.0|171.0| null|
|  5|  x|  B|48.0|173.0| null|
+---+---+---+----+-----+-----+

