# Setup

In [1]:
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

In [2]:
from pyspark.ml.pipeline import Estimator, Model, Pipeline, PipelineModel
from pyspark.ml.param.shared import *
import pyspark.sql.functions as F

In [3]:
spark, sc, sql

(<pyspark.sql.session.SparkSession at 0x109b16090>,
 <pyspark.context.SparkContext at 0x109a24e50>,
 <bound method SparkSession.sql of <pyspark.sql.session.SparkSession object at 0x109b16090>>)

# Sample Dataset

In [4]:
df = sc.parallelize([
        (1, 'a', 'A', None), 
        (2, 'a', 'B', 30.3), 
        (3, 'b', 'B', 27.8), 
        (4, 'c', None, 31.2), 
        (5, None, 'B', 32.5)]) \
    .toDF(["id", "x1", 'x2', 'score'])

In [5]:
df.show()

+---+----+----+-----+
| id|  x1|  x2|score|
+---+----+----+-----+
|  1|   a|   A| null|
|  2|   a|   B| 30.3|
|  3|   b|   B| 27.8|
|  4|   c|null| 31.2|
|  5|null|   B| 32.5|
+---+----+----+-----+



# String Disassembler

In [10]:
class HasOutputColsPrefix(Params):

    output_prefix = Param(Params._dummy(), "output_prefix", 
                         "prefix for every output column name",
                         typeConverter=TypeConverters.toString)

    def __init__(self):
        super(HasOutputColsPrefix, self).__init__()
        self._setDefault(output_prefix='is')
    
    def setOutputColsPrefix(self, value):
        return self._set(output_prefix=value)

    def getOutputColsPrefix(self):
        return self.getOrDefault(self.output_prefix)

In [13]:
class HasFieldValues(Params):

    values = Param(Params._dummy(), "values", 
                         "all possible values for a field",
                         typeConverter=TypeConverters.toList)
    
    fields = Param(Params._dummy(), "fields", 
                         "new fields",
                         typeConverter=TypeConverters.toList)

    def __init__(self):
        super(HasFieldValues, self).__init__()
        
    def setValues(self, value):
        return self._set(values=value)

    def getValues(self):
        return self.getOrDefault(self.values)

    def setFields(self, value):
        return self._set(fields=value)

    def getFields(self):
        return self.getOrDefault(self.fields)

In [14]:
class FillMode(Params):

    fill_mode = Param(Params._dummy(), "fill_mode", 
                         "should disassembler fill mode first",
                         typeConverter=TypeConverters.toBoolean)

    def __init__(self):
        super(FillMode, self).__init__()
        self._setDefault(fill_mode=False)

    def setFillMode(self, value):
        return self._set(fill_mode=value)

    def getFillMode(self):
        return self.getOrDefault(self.fill_mode)

In [27]:
class StringDisassembleModel(Model, HasInputCol, 
                             HasFieldValues, FillMode):
    
    def getMode(self):
        values = self.getValues()
        return None if len(values) == 0 else values[0]
    
    def _transform(self, dataset):
        x = self.getInputCol()
        fields = self.getFields()
        values = self.getValues()
        fill_mode = self.getFillMode()
        mode = self.getMode()
        
        new_df = dataset
        for f, v in zip(fields, values):
            null_cond = None
            if fill_mode:
                null_cond = (F.lit(mode) == F.lit(v)).cast('double')
            else:
                null_cond = 0
        
            new_df = new_df.withColumn(f, F.when(F.isnull(x), null_cond) \
                    .otherwise((F.col(x) == F.lit(v)).cast('double')))
        
        return new_df

In [28]:
class StringDisassembler(Estimator, HasInputCol, 
                         HasOutputColsPrefix, FillMode):
    
    def get_values(self, dataset):
        x = self.getInputCol()
        # values ordered by count (mode)
        values = dataset \
            .where('{} is not null'.format(x)) \
            .groupBy(x) \
            .agg(F.count('*').alias('count')) \
            .orderBy(F.desc('count')) \
            .rdd.map(lambda r: r[x]) \
            .collect()
            
        return values
    
    def get_fields(self, values):
        x = self.getInputCol()
        prefix = self.getOutputColsPrefix()
        return ['{}_{}_{}'.format(prefix, x, v) for v in values]
    
    def _fit(self, dataset):
        x = self.getInputCol()
        values = self.get_values(dataset)
        fields = self.get_fields(values)
        model = StringDisassembleModel() \
            .setInputCol(x) \
            .setFields(fields) \
            .setValues(values) \
            .setFillMode(self.getFillMode())

        return model

# Try StringDisassembler

In [29]:
df.show()

+---+----+----+-----+
| id|  x1|  x2|score|
+---+----+----+-----+
|  1|   a|   A| null|
|  2|   a|   B| 30.3|
|  3|   b|   B| 27.8|
|  4|   c|null| 31.2|
|  5|null|   B| 32.5|
+---+----+----+-----+



In [30]:
sd_model = StringDisassembler() \
    .setInputCol('x1') \
    .fit(df)

In [31]:
sd_model.getMode(), sd_model.getFields(), sd_model.getValues()

(u'a', [u'is_x1_a', u'is_x1_c', u'is_x1_b'], [u'a', u'c', u'b'])

In [32]:
sd_model.transform(df).show()

+---+----+----+-----+-------+-------+-------+
| id|  x1|  x2|score|is_x1_a|is_x1_c|is_x1_b|
+---+----+----+-----+-------+-------+-------+
|  1|   a|   A| null|    1.0|    0.0|    0.0|
|  2|   a|   B| 30.3|    1.0|    0.0|    0.0|
|  3|   b|   B| 27.8|    0.0|    0.0|    1.0|
|  4|   c|null| 31.2|    0.0|    1.0|    0.0|
|  5|null|   B| 32.5|    0.0|    0.0|    0.0|
+---+----+----+-----+-------+-------+-------+



In [33]:
sd_model2 = StringDisassembler() \
    .setInputCol('x2') \
    .setFillMode(True) \
    .fit(df)

In [34]:
sd_model2.transform(df).show()

+---+----+----+-----+-------+-------+
| id|  x1|  x2|score|is_x2_B|is_x2_A|
+---+----+----+-----+-------+-------+
|  1|   a|   A| null|    0.0|    1.0|
|  2|   a|   B| 30.3|    1.0|    0.0|
|  3|   b|   B| 27.8|    1.0|    0.0|
|  4|   c|null| 31.2|    1.0|    0.0|
|  5|null|   B| 32.5|    1.0|    0.0|
+---+----+----+-----+-------+-------+

