# Setup

In [1]:
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

In [2]:
from pyspark import keyword_only
from pyspark.ml.pipeline import Pipeline, PipelineModel, Transformer
from pyspark.ml.param.shared import *
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

In [3]:
from pyspark.ml.feature import Imputer

In [4]:
# Echo the SparkSession to confirm the kernel is connected.
# NOTE(review): `spark` (and `sc`, used in the next cell) are never created in
# this notebook -- presumably injected by the PySpark shell/kernel; confirm.
spark

In [5]:
# Toy frame with nulls scattered across the string columns (x1, x2) and the
# double columns (x3, x4, x5); `id` has no nulls. `sc` is presumably the
# SparkContext supplied by the PySpark kernel (it is not created here).
df = (
    sc.parallelize([
        (1, 'a', 'A', 56., 175., 10.),
        (2, 'a', 'B', 66., None, 92.),
        (3, 'b', 'B', None, 182., 876.),
        (4, 'c', None, 71., 171., None),
        (5, None, 'B', 48., 173., None),
    ])
    .toDF(['id', 'x1', 'x2', 'x3', 'x4', 'x5'])
)

In [6]:
# Raw input; show() prints at most 20 rows, so all 5 rows appear.
df.show(n=20, truncate=True)

+---+----+----+----+-----+-----+
| id|  x1|  x2|  x3|   x4|   x5|
+---+----+----+----+-----+-----+
|  1|   a|   A|56.0|175.0| 10.0|
|  2|   a|   B|66.0| null| 92.0|
|  3|   b|   B|null|182.0|876.0|
|  4|   c|null|71.0|171.0| null|
|  5|null|   B|48.0|173.0| null|
+---+----+----+----+-----+-----+



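# Impute with Constant Value

`ConstantImputer` replaces nulls through `DataFrame.na.fill`. A scalar replacement only reaches columns of a compatible type: a number fills the numeric columns and a string fills the string columns. Setting `inputCols` restricts the fill to those columns. A dict `const_value` gives a per-column replacement instead.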

In [7]:
class HasConstValue(Params):
    """
    Mixin that adds a ``const_value`` param: the replacement for missing values.

    According to the param's doc string, the value is a string, a double or a dict.
    No type converter and no default are declared, so the value is stored as
    given and must be set before it is read.
    """

    const_value = Param(
        Params._dummy(),
        "const_value",
        "string, double or dict",
    )

    def __init__(self):
        super(HasConstValue, self).__init__()

    def setConstValue(self, value):
        """Store ``value`` as ``const_value`` and return ``self`` so calls chain."""
        return self._set(const_value=value)

    def getConstValue(self):
        """Return the set (or default) value of ``const_value``."""
        param = self.const_value
        return self.getOrDefault(param)

In [9]:
class ConstantImputer(Transformer, HasInputCols, HasConstValue, DefaultParamsReadable, DefaultParamsWritable):
    """
    Transformer that replaces nulls with a constant via ``DataFrame.na.fill``.

    Params:
        const_value: the replacement; must be set (HasConstValue declares no default).
        inputCols: optional list of column names to fill.

    Behaviour:
        * ``inputCols`` defined: ``const_value`` must be a scalar. It is
          applied to each listed column.
        * ``inputCols`` not defined (or None): ``const_value`` is passed to
          ``na.fill`` unchanged. A scalar fills every column of a compatible
          type; a dict fills column by column.

    Raises:
        KeyError: ``const_value`` was never set (from ``getConstValue``).
        ValueError: a dict ``const_value`` is combined with ``inputCols``.
    """

    def setInputCols(self, value):
        """Set ``inputCols`` and return ``self`` for chaining."""
        # Defined locally so this class does not depend on HasInputCols
        # providing a setter. NOTE(review): newer PySpark shared-param mixins
        # appear to expose only the getter -- confirm for the target version.
        return self._set(inputCols=value)

    def _transform(self, dataset):
        const_val = self.getConstValue()
        # Test for inputCols explicitly instead of wrapping the whole body in
        # `try/except KeyError`. That also caught any KeyError from na.fill
        # and silently retried with a fill over *all* columns.
        cols = self.getInputCols() if self.isDefined(self.inputCols) else None
        if cols is None:
            return dataset.na.fill(const_val)
        if isinstance(const_val, dict):
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError('Multiple fields can only be filled with single value.')
        return dataset.na.fill({col: const_val for col in cols})

In [10]:
# Scalar 0 with no inputCols: only the numeric columns (x3-x5) are filled;
# the string columns x1/x2 keep their nulls, as the output below shows.
(ConstantImputer()
    .setConstValue(value=0)
    .transform(df)
    .show())

+---+----+----+----+-----+-----+
| id|  x1|  x2|  x3|   x4|   x5|
+---+----+----+----+-----+-----+
|  1|   a|   A|56.0|175.0| 10.0|
|  2|   a|   B|66.0|  0.0| 92.0|
|  3|   b|   B| 0.0|182.0|876.0|
|  4|   c|null|71.0|171.0|  0.0|
|  5|null|   B|48.0|173.0|  0.0|
+---+----+----+----+-----+-----+



In [11]:
# Scalar 'x' with no inputCols: only the string columns (x1, x2) are filled;
# the numeric nulls are left untouched.
(ConstantImputer()
    .setConstValue(value='x')
    .transform(df)
    .show())

+---+---+---+----+-----+-----+
| id| x1| x2|  x3|   x4|   x5|
+---+---+---+----+-----+-----+
|  1|  a|  A|56.0|175.0| 10.0|
|  2|  a|  B|66.0| null| 92.0|
|  3|  b|  B|null|182.0|876.0|
|  4|  c|  x|71.0|171.0| null|
|  5|  x|  B|48.0|173.0| null|
+---+---+---+----+-----+-----+



In [12]:
# inputCols restricts the fill to x1; the null in x2 survives.
(ConstantImputer()
    .setInputCols(['x1'])
    .setConstValue(value='x')
    .transform(df)
    .show())

+---+---+----+----+-----+-----+
| id| x1|  x2|  x3|   x4|   x5|
+---+---+----+----+-----+-----+
|  1|  a|   A|56.0|175.0| 10.0|
|  2|  a|   B|66.0| null| 92.0|
|  3|  b|   B|null|182.0|876.0|
|  4|  c|null|71.0|171.0| null|
|  5|  x|   B|48.0|173.0| null|
+---+---+----+----+-----+-----+



In [13]:
# A dict const_value cannot be combined with inputCols. _transform rejects it
# as soon as transform() is called, so show the error here instead of leaving
# a commented-out cell. Exception is caught broadly because the imputer's
# exception type may be a plain Exception, depending on its version.
try:
    ConstantImputer().setInputCols(['x1']).setConstValue({'s': 1}).transform(df).show()
except Exception as err:
    print('Rejected:', err)

In [14]:
# Dict const_value with no inputCols: per-column replacements. x5 is not in
# the dict, so its nulls remain.
(ConstantImputer()
    .setConstValue(value={'x1': 'x', 'x2': 'C', 'x3': 0, 'x4': 100})
    .transform(df)
    .show())

+---+---+---+----+-----+-----+
| id| x1| x2|  x3|   x4|   x5|
+---+---+---+----+-----+-----+
|  1|  a|  A|56.0|175.0| 10.0|
|  2|  a|  B|66.0|100.0| 92.0|
|  3|  b|  B| 0.0|182.0|876.0|
|  4|  c|  C|71.0|171.0| null|
|  5|  x|  B|48.0|173.0| null|
+---+---+---+----+-----+-----+

