In [22]:
# wrangling
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, DateType, FloatType
from pyspark.sql.functions import col,when

# modelling
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import StringIndexer

spark=SparkSession.builder.appName('pyspark_test').getOrCreate()

df = spark.read.option('header','true').csv('ggplot2_diamonds.csv',inferSchema=True)

In [2]:
df.printSchema()

root
 |-- carat: string (nullable = true)
 |-- cut: string (nullable = true)
 |-- color: string (nullable = true)
 |-- clarity: string (nullable = true)
 |-- depth: string (nullable = true)
 |-- table: double (nullable = true)
 |-- price: string (nullable = true)
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)
 |-- z: double (nullable = true)



In [5]:

float_cols = ['carat','depth','table','price']

for col in float_cols:
    df = df.withColumn(col,df[col].cast(FloatType()))


In [6]:
str_cols = ['color','cut','clarity']

for col in str_cols:
    df = df.withColumn(col,df[col].cast(StringType()))

df.printSchema()

root
 |-- carat: float (nullable = true)
 |-- cut: string (nullable = true)
 |-- color: string (nullable = true)
 |-- clarity: string (nullable = true)
 |-- depth: float (nullable = true)
 |-- table: float (nullable = true)
 |-- price: float (nullable = true)
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)
 |-- z: double (nullable = true)



In [10]:
df = df.select([when(col(c)=="NA",None).otherwise(col(c)).alias(c) for c in df.columns]).dropna()
df.show(5)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|336.0|3.95|3.98|2.47|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|338.0| 4.0|4.05|2.39|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
only showing top 5 rows



In [13]:
df = df.drop(*['x','y','z'])
df.show(5)

+-----+---------+-----+-------+-----+-----+-----+
|carat|      cut|color|clarity|depth|table|price|
+-----+---------+-----+-------+-----+-----+-----+
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|336.0|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|338.0|
+-----+---------+-----+-------+-----+-----+-----+
only showing top 5 rows



## One-hot Encoding

To create one-hot (or dummy) encoding in `pyspark` we actually must before transform the string column to be a numeric column (with `StringIndexer`), so that the `OneHotEncoder` can be used.

In [45]:
indexer=StringIndexer(
    inputCols=str_cols,
    outputCols=[f"{x}_idxd" for x in str_cols]
    # outputCols=[x for x in str_cols]
    )
df_preprocd = indexer.fit(df).transform(df)
df_preprocd = df_preprocd.drop(*str_cols)
df_preprocd.show()

+-----+-----+-----+-----+----------+--------+------------+
|carat|depth|table|price|color_idxd|cut_idxd|clarity_idxd|
+-----+-----+-----+-----+----------+--------+------------+
| 0.23| 56.9| 65.0|327.0|       1.0|     3.0|         3.0|
| 0.29| 62.4| 58.0|334.0|       5.0|     1.0|         1.0|
| 0.24| 62.8| 57.0|336.0|       6.0|     2.0|         4.0|
| 0.24| 62.3| 57.0|336.0|       5.0|     2.0|         5.0|
| 0.23| 59.4| 61.0|338.0|       3.0|     2.0|         3.0|
|  0.3| 64.0| 55.0|339.0|       6.0|     3.0|         0.0|
| 0.23| 62.8| 56.0|340.0|       6.0|     0.0|         3.0|
| 0.22| 60.4| 61.0|342.0|       2.0|     1.0|         0.0|
| 0.31| 62.2| 54.0|344.0|       6.0|     0.0|         2.0|
|  0.2| 60.2| 62.0|345.0|       1.0|     1.0|         2.0|
| 0.32| 60.9| 58.0|345.0|       1.0|     1.0|         7.0|
|  0.3| 62.0| 54.0|348.0|       5.0|     0.0|         2.0|
|  0.3| 63.4| 54.0|351.0|       6.0|     3.0|         0.0|
|  0.3| 63.8| 56.0|351.0|       6.0|     3.0|         0.

In [46]:
df_preprocd = OneHotEncoder(
    inputCols=[f"{x}_idxd" for x in str_cols], 
    outputCols=[f"{x}_onehot" for x in str_cols]).fit(df_preprocd).transform(df_preprocd)

df_preprocd = df_preprocd.drop(*[f"{x}_idxd" for x in str_cols])
df_preprocd.show()


+-----+-----+-----+-----+-------------+-------------+--------------+
|carat|depth|table|price| color_onehot|   cut_onehot|clarity_onehot|
+-----+-----+-----+-----+-------------+-------------+--------------+
| 0.23| 56.9| 65.0|327.0|(6,[1],[1.0])|(4,[3],[1.0])| (7,[3],[1.0])|
| 0.29| 62.4| 58.0|334.0|(6,[5],[1.0])|(4,[1],[1.0])| (7,[1],[1.0])|
| 0.24| 62.8| 57.0|336.0|    (6,[],[])|(4,[2],[1.0])| (7,[4],[1.0])|
| 0.24| 62.3| 57.0|336.0|(6,[5],[1.0])|(4,[2],[1.0])| (7,[5],[1.0])|
| 0.23| 59.4| 61.0|338.0|(6,[3],[1.0])|(4,[2],[1.0])| (7,[3],[1.0])|
|  0.3| 64.0| 55.0|339.0|    (6,[],[])|(4,[3],[1.0])| (7,[0],[1.0])|
| 0.23| 62.8| 56.0|340.0|    (6,[],[])|(4,[0],[1.0])| (7,[3],[1.0])|
| 0.22| 60.4| 61.0|342.0|(6,[2],[1.0])|(4,[1],[1.0])| (7,[0],[1.0])|
| 0.31| 62.2| 54.0|344.0|    (6,[],[])|(4,[0],[1.0])| (7,[2],[1.0])|
|  0.2| 60.2| 62.0|345.0|(6,[1],[1.0])|(4,[1],[1.0])| (7,[2],[1.0])|
| 0.32| 60.9| 58.0|345.0|(6,[1],[1.0])|(4,[1],[1.0])|     (7,[],[])|
|  0.3| 62.0| 54.0|348.0|(6,[5],[1

In [61]:
x_cols = [x for i,x in enumerate(df_preprocd.columns) if i!=3]
y_col = df_preprocd.columns[3]

print(f"{len(x_cols)} input features: ", x_cols)
print(f"{len(y_col)} target feature: ", y_col)

feature_assembler = VectorAssembler(inputCols=x_cols,outputCol="Independent Features")

output = feature_assembler.transform(df_preprocd)

finalized_data = output.select("Independent Features","price")

6 input features:  ['carat', 'depth', 'table', 'color_onehot', 'cut_onehot', 'clarity_onehot']
5 target feature:  price


In [62]:
train_data,test_data = finalized_data.randomSplit([0.75,0.25])

reg_mod = LinearRegression(featuresCol='Independent Features', labelCol='price')
reg_mod = reg_mod.fit(train_data)

21/10/23 18:54:22 WARN Instrumentation: [55d4efb4] regParam is zero, which might cause numerical instability and overfitting.


In [65]:
print("beta_0 = ", reg_mod.intercept)
print()
print("betas = ", reg_mod.coefficients)


beta_0 =  -6767.218762526329

betas =  [8890.69146916395,-21.16052620291968,-25.98069163061147,1829.2136865696657,2123.6974198736366,2029.0240294657042,1365.6891615753918,2340.81558093979,907.6108490391551,844.2172022020659,781.6305891250361,750.4149729154112,590.0332706819676,3553.9951395185285,4197.658289927469,2608.2137356053468,4503.666623574199,4940.714685195995,5038.463333974973,5378.490954739035]


In [67]:
pred_results = reg_mod.evaluate(test_data)

pred_results.predictions.show()



+--------------------+------+-------------------+
|Independent Features| price|         prediction|
+--------------------+------+-------------------+
|(20,[0,1,2],[0.99...|1789.0|-1081.6903173489527|
|(20,[0,1,2],[2.26...|5733.0| 10555.261124186756|
|(20,[0,1,2],[2.72...|6870.0| 14517.395734187776|
|(20,[0,1,2,3],[0....| 701.0|-3418.4707996315865|
|(20,[0,1,2,3],[0....| 727.0|-3346.8769471774294|
|(20,[0,1,2,3],[0....| 727.0|-3327.2443165424925|
|(20,[0,1,2,3],[0....| 956.0| -1554.866226368771|
|(20,[0,1,2,3],[0....|1651.0| -266.5236221006544|
|(20,[0,1,2,3],[1....|2723.0| 1166.1824873980895|
|(20,[0,1,2,3],[1....|2398.0| 1124.4494790351573|
|(20,[0,1,2,3],[1....|2655.0| 2939.1491515412235|
|(20,[0,1,2,3],[1....|4277.0|  3627.511736143704|
|(20,[0,1,2,3],[2....|6346.0| 10153.337700695865|
|(20,[0,1,2,3],[2....|6817.0| 11567.467984743194|
|(20,[0,1,2,3],[2....|7257.0| 12718.353479403188|
|(20,[0,1,2,3,9],[...|1019.0| -1766.044531940901|
|(20,[0,1,2,3,9],[...|6186.0|  6865.450228109838|




In [69]:
print("R2 = ",pred_results.r2)
print("MAE = ",pred_results.meanAbsoluteError)
print("MSE = ",pred_results.meanSquaredError)

R2 =  0.916647262352352
MAE =  810.7700722108757
MSE =  1347424.5779997462
