# Setup

In [6]:
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

In [7]:
from pyspark import keyword_only
from pyspark.ml.pipeline import Pipeline, PipelineModel, Transformer
from pyspark.ml.param.shared import *
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

In [8]:
from pyspark.ml.feature import Imputer

In [26]:
import pyspark.sql.functions as F

In [9]:
# Display the active SparkSession. NOTE(review): `spark` is not created in any
# cell shown here -- presumably it is injected by the PySpark kernel/shell; confirm.
spark

In [10]:
# Toy DataFrame used by every example below. `id` is always present, while
# each of the string columns (x1, x2) and the float columns (x3, x4, x5)
# contains at least one missing value (None).
#
# Built directly through the SparkSession instead of going through
# `sc.parallelize(...).toDF(...)`: this avoids the RDD round trip and the
# dependency on a separate `sc` variable that no cell in this notebook defines.
df = spark.createDataFrame(
    [
        (1, 'a', 'A', 56., 175., 10.),
        (2, 'a', 'B', 66., None, 92.),
        (3, 'b', 'B', None, 182., 876.),
        (4, 'c', None, 71., 171., None),
        (5, None, 'B', 48., 173., None),
    ],
    ["id", "x1", 'x2', 'x3', 'x4', 'x5'],
)

In [11]:
# Print the raw frame so the null positions can be compared with the imputed
# results further down.
df.show()

+---+----+----+----+-----+-----+
| id|  x1|  x2|  x3|   x4|   x5|
+---+----+----+----+-----+-----+
|  1|   a|   A|56.0|175.0| 10.0|
|  2|   a|   B|66.0| null| 92.0|
|  3|   b|   B|null|182.0|876.0|
|  4|   c|null|71.0|171.0| null|
|  5|null|   B|48.0|173.0| null|
+---+----+----+----+-----+-----+



# Impute with Constant Value

In [12]:
class HasConstValue(Params):
    """Params mixin that contributes a single ``const_value`` param.

    The param is documented as accepting a "string, double or dict".
    No type converter is passed to ``Param``.
    """

    # Declared at class level with the dummy parent, like the shared
    # ``Has*`` mixins this class is combined with.
    const_value = Param(
        parent=Params._dummy(),
        name="const_value",
        doc="string, double or dict",
    )

    def __init__(self):
        # Explicit two-argument super() keeps the class usable on Python 2.
        super(HasConstValue, self).__init__()

    def setConstValue(self, value):
        """Store ``value`` under ``const_value``.

        Returns the result of ``_set``; the notebook chains further calls
        (e.g. ``.transform``) onto it.
        """
        return self._set(const_value=value)

    def getConstValue(self):
        """Return the value of ``const_value`` via ``getOrDefault``."""
        const_value_param = self.const_value
        return self.getOrDefault(const_value_param)

In [34]:
class ConstantImputer(Transformer, HasInputCols, HasOutputCols, HasConstValue,
                      DefaultParamsReadable, DefaultParamsWritable):
    """Transformer that replaces nulls with a constant.

    Behaviour depends on which params are set:

    * ``const_value`` is a dict: it is handed to ``dataset.na.fill`` as a
      column -> replacement mapping; ``inputCols``/``outputCols`` are ignored.
    * ``const_value`` is a scalar and ``inputCols`` is not set: the scalar is
      handed to ``dataset.na.fill`` for the whole frame.
    * ``const_value`` is a scalar and ``inputCols`` is set: each listed column
      is imputed individually. The result is written to the matching entry
      of ``outputCols`` when that param is set, otherwise back onto the input
      column itself.
    """

    def _transform(self, dataset):
        """Return ``dataset`` with nulls replaced according to the params.

        :param dataset: input DataFrame.
        :return: a DataFrame with the imputed values.
        :raises ValueError: if both ``inputCols`` and ``outputCols`` are set
            but have different lengths.
        """
        const_val = self.getConstValue()

        # isinstance (not an exact type check) so dict subclasses such as
        # OrderedDict are also treated as a per-column mapping.
        if isinstance(const_val, dict):
            return dataset.na.fill(const_val)

        # Ask the param map directly instead of catching KeyError from the
        # getter, so unrelated KeyErrors are never mistaken for "not set".
        if not self.isDefined(self.inputCols):
            return dataset.na.fill(const_val)

        in_cols = self.getInputCols()
        if self.isDefined(self.outputCols):
            out_cols = self.getOutputCols()
        else:
            # No explicit targets: overwrite the input columns in place.
            out_cols = in_cols

        # zip() would silently drop the unmatched tail and leave some columns
        # un-imputed, so reject mismatched lists up front.
        if len(in_cols) != len(out_cols):
            raise ValueError(
                "inputCols and outputCols must have the same length, got "
                "{} and {}".format(len(in_cols), len(out_cols)))

        filled_dataset = dataset
        for in_col, out_col in zip(in_cols, out_cols):
            source = F.col(in_col)
            filled_dataset = filled_dataset.withColumn(
                out_col,
                F.when(source.isNull(), const_val).otherwise(source))

        return filled_dataset

In [35]:
# Scalar numeric constant, no input columns: as the output below shows, only
# the numeric columns (x3, x4, x5) get filled; nulls in x1/x2 remain.
(
    ConstantImputer()
    .setConstValue(0)
    .transform(df)
    .show()
)

+---+----+----+----+-----+-----+
| id|  x1|  x2|  x3|   x4|   x5|
+---+----+----+----+-----+-----+
|  1|   a|   A|56.0|175.0| 10.0|
|  2|   a|   B|66.0|  0.0| 92.0|
|  3|   b|   B| 0.0|182.0|876.0|
|  4|   c|null|71.0|171.0|  0.0|
|  5|null|   B|48.0|173.0|  0.0|
+---+----+----+----+-----+-----+



In [36]:
# Scalar string constant, no input columns: as the output below shows, only
# the string columns (x1, x2) get filled; nulls in x3/x4/x5 remain.
(
    ConstantImputer()
    .setConstValue('x')
    .transform(df)
    .show()
)

+---+---+---+----+-----+-----+
| id| x1| x2|  x3|   x4|   x5|
+---+---+---+----+-----+-----+
|  1|  a|  A|56.0|175.0| 10.0|
|  2|  a|  B|66.0| null| 92.0|
|  3|  b|  B|null|182.0|876.0|
|  4|  c|  x|71.0|171.0| null|
|  5|  x|  B|48.0|173.0| null|
+---+---+---+----+-----+-----+



In [37]:
# Restrict imputation to x1. With no output columns given, x1 is overwritten
# in place and the null in x2 is left untouched (see output below).
(
    ConstantImputer()
    .setInputCols(['x1'])
    .setConstValue('x')
    .transform(df)
    .show()
)

+---+---+----+----+-----+-----+
| id| x1|  x2|  x3|   x4|   x5|
+---+---+----+----+-----+-----+
|  1|  a|   A|56.0|175.0| 10.0|
|  2|  a|   B|66.0| null| 92.0|
|  3|  b|   B|null|182.0|876.0|
|  4|  c|null|71.0|171.0| null|
|  5|  x|   B|48.0|173.0| null|
+---+---+----+----+-----+-----+



In [38]:
# Separate output columns: x1/x2 keep their nulls and the imputed values are
# written to the new columns x1_im/x2_im (see output below).
(
    ConstantImputer()
    .setInputCols(['x1', 'x2'])
    .setOutputCols(['x1_im', 'x2_im'])
    .setConstValue('x')
    .transform(df)
    .show()
)

+---+----+----+----+-----+-----+-----+-----+
| id|  x1|  x2|  x3|   x4|   x5|x1_im|x2_im|
+---+----+----+----+-----+-----+-----+-----+
|  1|   a|   A|56.0|175.0| 10.0|    a|    A|
|  2|   a|   B|66.0| null| 92.0|    a|    B|
|  3|   b|   B|null|182.0|876.0|    b|    B|
|  4|   c|null|71.0|171.0| null|    c|    x|
|  5|null|   B|48.0|173.0| null|    x|    B|
+---+----+----+----+-----+-----+-----+-----+



In [39]:
# Dict constant: each key names a column and maps to that column's own
# replacement value, so string and numeric columns are filled in one pass.
(
    ConstantImputer()
    .setConstValue({'x1': 'x', 'x2': 'C', 'x3': 0, 'x4': 100, 'x5': 999})
    .transform(df)
    .show()
)

+---+---+---+----+-----+-----+
| id| x1| x2|  x3|   x4|   x5|
+---+---+---+----+-----+-----+
|  1|  a|  A|56.0|175.0| 10.0|
|  2|  a|  B|66.0|100.0| 92.0|
|  3|  b|  B| 0.0|182.0|876.0|
|  4|  c|  C|71.0|171.0|999.0|
|  5|  x|  B|48.0|173.0|999.0|
+---+---+---+----+-----+-----+

