In [22]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col,when
import pyspark.sql.types as tp

from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer

# Single SparkSession for the whole notebook; getOrCreate() hands back an
# already-running session instead of starting a second one (e.g. on re-run).
spark = (
    SparkSession.builder
    .appName('pyspark_test')
    .getOrCreate()
)

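Notice that the `tp.StructField` entries must be listed in the same order as the columns appear in the CSV: with an explicit schema, Spark assigns schema fields to CSV columns by position, not by header name. This means that if we only swapped the specification order of `carat` and `cut`, the dataset would not be read correctly.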

In [25]:
# Explicit schema for the diamonds CSV. Entries are (column name, Spark type)
# and MUST follow the CSV's left-to-right column order, because the fields are
# matched to CSV columns positionally. Every column is nullable.
my_schema = tp.StructType([
    tp.StructField(name=field_name, dataType=field_type, nullable=True)
    for field_name, field_type in [
        ('carat',   tp.DoubleType()),
        ('cut',     tp.StringType()),
        ('color',   tp.StringType()),
        ('clarity', tp.StringType()),
        ('depth',   tp.DoubleType()),
        ('table',   tp.DoubleType()),
        ('price',   tp.DoubleType()),
        ('x',       tp.DoubleType()),
        ('y',       tp.DoubleType()),
        ('z',       tp.DoubleType()),
    ]
])

In [26]:
# Load the raw CSV using the explicit schema; literal 'NA' cells become nulls.
df_raw = spark.read.csv(
    'ggplot2_diamonds.csv',
    sep=',',
    header=True,
    schema=my_schema,
    nullValue='NA',
)

# Keep the raw frame available for inspection (e.g. to see which rows are
# removed) and bind the cleaned frame to `df`, the name the pipeline cells use.
# dropna() with no arguments drops any row containing at least one null.
df = df_raw.dropna()

df.show(15)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|336.0|3.95|3.98|2.47|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|338.0| 4.0|4.05|2.39|
|  0.3|     Good|    J|    SI1| 64.0| 55.0|339.0|4.25|4.28|2.73|
| 0.23|    Ideal|    J|    VS1| 62.8| 56.0|340.0|3.93| 3.9|2.46|
| 0.22|  Premium|    F|    SI1| 60.4| 61.0|342.0|3.88|3.84|2.33|
| 0.31|    Ideal|    J|    SI2| 62.2| 54.0|344.0|4.35|4.37|2.71|
|  0.2|  Premium|    E|    SI2| 60.2| 62.0|345.0|3.79|3.75|2.27|
| 0.32|  Premium|    E|     I1| 60.9| 58.0|345.0|4.38|4.42|2.68|
|  0.3|    Ideal|    I|    SI2| 62.0| 54.0|348.0|4.31|4.34|2.68|
|  0.3|     Good|    J|  

In [27]:
# Categorical columns: string-indexed, then one-hot encoded.
str_cols = 'color cut clarity'.split()
# Numeric columns fed directly into the feature vector.
# NOTE(review): 'price' is included as a feature; if price is the prediction
# target for a later model this is target leakage — confirm before training.
float_cols = 'carat depth table price x y z'.split()

In [28]:
# Map each categorical string column to a numeric label index, written to a
# new '<column>_idxd' column that the one-hot stage reads.
idx_stage = StringIndexer(
    inputCols=str_cols,
    outputCols=[name + "_idxd" for name in str_cols],
)

In [29]:
# One-hot encode each '<column>_idxd' index into a sparse '<column>_onehot'
# vector. dropLast is left at its default (True), so the highest index of each
# column is encoded as an all-zeros vector — e.g. color 'J' (index 6.0) shows
# up as (6,[],[]) in the pipeline output.
ohe_stage = OneHotEncoder(
    inputCols=[name + "_idxd" for name in str_cols],
    outputCols=[name + "_onehot" for name in str_cols],
)


In [30]:
# Concatenate the one-hot vectors (first) and the numeric columns (after) into
# a single 'model_features' vector column.
assemble_stage = VectorAssembler(
    inputCols=[*(name + "_onehot" for name in str_cols), *float_cols],
    outputCol='model_features',
)

In [31]:
# Show the columns the assembler will concatenate, read back from the stage
# itself so this listing cannot drift from the assembler's configuration.
assemble_stage.getInputCols()

['color_onehot',
 'cut_onehot',
 'clarity_onehot',
 'carat',
 'depth',
 'table',
 'price',
 'x',
 'y',
 'z']

In [32]:
# Index -> one-hot encode -> assemble, run in that order.
ml_pipeline = Pipeline(stages=[
    idx_stage,
    ohe_stage,
    assemble_stage
])

# Keep the fitted PipelineModel and its output as named variables so they can
# be reused (applied to other data, inspected, saved) instead of being
# discarded after a single show().
pipeline_model = ml_pipeline.fit(df)
df_features = pipeline_model.transform(df)

df_features.show(15)



+-----+---------+-----+-------+-----+-----+-----+----+----+----+----------+--------+------------+-------------+-------------+--------------+--------------------+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|color_idxd|cut_idxd|clarity_idxd| color_onehot|   cut_onehot|clarity_onehot|      model_features|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+----------+--------+------------+-------------+-------------+--------------+--------------------+
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|       1.0|     3.0|         3.0|(6,[1],[1.0])|(4,[3],[1.0])| (7,[3],[1.0])|(24,[1,9,13,17,18...|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|       5.0|     1.0|         1.0|(6,[5],[1.0])|(4,[1],[1.0])| (7,[1],[1.0])|(24,[5,7,11,17,18...|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|       6.0|     2.0|         4.0|    (6,[],[])|(4,[2],[1.0])| (7,[4],[1.0])|(24,[8,14,17,18,1...|
| 0.24|Very Good|    I|   VV