# PySpark Machine Learning - Implement Linear Regression

In [41]:
!pip install pyspark




Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [42]:
from pyspark.sql import SparkSession
from pyspark.ml.stat import Correlation
import pyspark.sql.functions as F

In [43]:
spark = SparkSession.builder.getOrCreate()

In [44]:
df = spark.read.csv("/content/drive/MyDrive/insurance.csv",inferSchema=True,header=True)

In [45]:
df.show()

+---+------+----+--------+------+---------+--------+
|age|   sex| bmi|children|smoker|   region|expenses|
+---+------+----+--------+------+---------+--------+
| 19|female|27.9|       0|   yes|southwest|16884.92|
| 18|  male|33.8|       1|    no|southeast| 1725.55|
| 28|  male|33.0|       3|    no|southeast| 4449.46|
| 33|  male|22.7|       0|    no|northwest|21984.47|
| 32|  male|28.9|       0|    no|northwest| 3866.86|
| 31|female|25.7|       0|    no|southeast| 3756.62|
| 46|female|33.4|       1|    no|southeast| 8240.59|
| 37|female|27.7|       3|    no|northwest| 7281.51|
| 37|  male|29.8|       2|    no|northeast| 6406.41|
| 60|female|25.8|       0|    no|northwest|28923.14|
| 25|  male|26.2|       0|    no|northeast| 2721.32|
| 62|female|26.3|       0|   yes|southeast|27808.73|
| 23|  male|34.4|       0|    no|southwest| 1826.84|
| 56|female|39.8|       0|    no|southeast|11090.72|
| 27|  male|42.1|       0|   yes|southeast|39611.76|
| 19|  male|24.6|       1|    no|southwest| 18

### Performing Exploratory Data Analysis

In [46]:
df.describe().show()

+-------+------------------+------+------------------+-----------------+------+---------+------------------+
|summary|               age|   sex|               bmi|         children|smoker|   region|          expenses|
+-------+------------------+------+------------------+-----------------+------+---------+------------------+
|  count|              1338|  1338|              1338|             1338|  1338|     1338|              1338|
|   mean| 39.20702541106129|  null|30.665470852017993|  1.0949177877429|  null|     null|13270.422414050803|
| stddev|14.049960379216147|  null|  6.09838219000336|1.205492739781914|  null|     null|12110.011239706473|
|    min|                18|female|              16.0|                0|    no|northeast|           1121.87|
|    max|                64|  male|              53.1|                5|   yes|southwest|          63770.43|
+-------+------------------+------+------------------+-----------------+------+---------+------------------+



### Using StringIndexer for dealing with Categorical data

 StringIndexer: It is used to convert a string column into numerical form. It allocates unique values to each of the categories present in the respective column.

In [47]:
from pyspark.ml.feature import StringIndexer

In [48]:
indexer=StringIndexer(inputCols=["sex","smoker","region"],outputCols=['sex_cat','smoker_cat','region_cat'])
indexed=indexer.fit(df).transform(df)

In [49]:
indexed.show()

+---+------+----+--------+------+---------+--------+-------+----------+----------+
|age|   sex| bmi|children|smoker|   region|expenses|sex_cat|smoker_cat|region_cat|
+---+------+----+--------+------+---------+--------+-------+----------+----------+
| 19|female|27.9|       0|   yes|southwest|16884.92|    1.0|       1.0|       2.0|
| 18|  male|33.8|       1|    no|southeast| 1725.55|    0.0|       0.0|       0.0|
| 28|  male|33.0|       3|    no|southeast| 4449.46|    0.0|       0.0|       0.0|
| 33|  male|22.7|       0|    no|northwest|21984.47|    0.0|       0.0|       1.0|
| 32|  male|28.9|       0|    no|northwest| 3866.86|    0.0|       0.0|       1.0|
| 31|female|25.7|       0|    no|southeast| 3756.62|    1.0|       0.0|       0.0|
| 46|female|33.4|       1|    no|southeast| 8240.59|    1.0|       0.0|       0.0|
| 37|female|27.7|       3|    no|northwest| 7281.51|    1.0|       0.0|       1.0|
| 37|  male|29.8|       2|    no|northeast| 6406.41|    0.0|       0.0|       3.0|
| 60

### Feature Engineering

In [50]:
from pyspark.ml.linalg import Vector
from pyspark.ml.feature import VectorAssembler

In [51]:
assembler=VectorAssembler(inputCols=['age','bmi','children','sex_cat','smoker_cat','region_cat'],outputCol='features')

In [52]:
assembler



VectorAssembler_5f3cbbc0b736

In [87]:
assembler.transform(indexed).show()

+---+------+----+--------+------+---------+--------+-------+----------+----------+--------------------+
|age|   sex| bmi|children|smoker|   region|expenses|sex_cat|smoker_cat|region_cat|            features|
+---+------+----+--------+------+---------+--------+-------+----------+----------+--------------------+
| 19|female|27.9|       0|   yes|southwest|16884.92|    1.0|       1.0|       2.0|[19.0,27.9,0.0,1....|
| 18|  male|33.8|       1|    no|southeast| 1725.55|    0.0|       0.0|       0.0|[18.0,33.8,1.0,0....|
| 28|  male|33.0|       3|    no|southeast| 4449.46|    0.0|       0.0|       0.0|[28.0,33.0,3.0,0....|
| 33|  male|22.7|       0|    no|northwest|21984.47|    0.0|       0.0|       1.0|[33.0,22.7,0.0,0....|
| 32|  male|28.9|       0|    no|northwest| 3866.86|    0.0|       0.0|       1.0|[32.0,28.9,0.0,0....|
| 31|female|25.7|       0|    no|southeast| 3756.62|    1.0|       0.0|       0.0|[31.0,25.7,0.0,1....|
| 46|female|33.4|       1|    no|southeast| 8240.59|    1.0|    

In [94]:
selected_cols = assembler.transform(indexed).select("features","expenses")
selected_cols.show(5)


+--------------------+--------+
|            features|expenses|
+--------------------+--------+
|[19.0,27.9,0.0,1....|16884.92|
|[18.0,33.8,1.0,0....| 1725.55|
|[28.0,33.0,3.0,0....| 4449.46|
|[33.0,22.7,0.0,0....|21984.47|
|[32.0,28.9,0.0,0....| 3866.86|
+--------------------+--------+
only showing top 5 rows



### Splitting the Dataset

In [102]:
final_data=assembler.transform(indexed).select("features","expenses")

In [103]:
train_data,test_data=final_data.randomSplit([0.7,0.3])

In [104]:
train_data.describe().show()

+-------+------------------+
|summary|          expenses|
+-------+------------------+
|  count|               931|
|   mean|13120.614607948444|
| stddev|11980.636329352174|
|    min|           1121.87|
|    max|          63770.43|
+-------+------------------+



In [109]:
test_data.describe().show()

+-------+------------------+
|summary|          expenses|
+-------+------------------+
|  count|               407|
|   mean|13613.103169533179|
| stddev|  12409.0168813331|
|    min|           1137.47|
|    max|           55135.4|
+-------+------------------+



### Linear Regression

In [110]:
from pyspark.ml.regression import LinearRegression

In [111]:
#creating an object of class LinearRegression
#object takes features and label as input arguments
ship_lr=LinearRegression(featuresCol='features',labelCol='expenses')

In [112]:
#pass train_data to train_model
trained_ship_model=ship_lr.fit(train_data)

In [113]:
#evaluating model trained for Rsquared error
ship_results=trained_ship_model.evaluate(train_data)

In [114]:
print('Rsquared Error :',ship_results.r2)

Rsquared Error : 0.7492539639895415


In [115]:
print(ship_results.meanSquaredError)

35952336.04952776


### Predictions from the Model

In [117]:
#testing model on unlabeled data
#create unlabeled data from test_data
#testing model on unlabeled data
unlabeled_data=test_data.select('features')
unlabeled_data.show(5)

+--------------------+
|            features|
+--------------------+
|(6,[0,1],[18.0,34...|
|(6,[0,1],[22.0,33...|
|(6,[0,1],[23.0,41...|
|(6,[0,1],[24.0,32...|
|(6,[0,1],[26.0,35...|
+--------------------+
only showing top 5 rows



### PySpark Linear Regression Predict

In [118]:
predictions=trained_ship_model.transform(unlabeled_data)
predictions.show()

+--------------------+-------------------+
|            features|         prediction|
+--------------------+-------------------+
|(6,[0,1],[18.0,34...| 3817.2938701845815|
|(6,[0,1],[22.0,33...|  4580.281017474421|
|(6,[0,1],[23.0,41...|  7403.494146046261|
|(6,[0,1],[24.0,32...|  4483.144347164909|
|(6,[0,1],[26.0,35...|  6045.259189230765|
|(6,[0,1],[27.0,32...|  5422.334561330692|
|(6,[0,1],[29.0,27...|   4367.93740311231|
|(6,[0,1],[33.0,30...|  6088.185038311018|
|(6,[0,1],[34.0,34...|  7571.233483810447|
|(6,[0,1],[42.0,24...|   6512.60446103619|
|(6,[0,1],[48.0,40...| 12826.291816678826|
|(6,[0,1],[52.0,34...| 11834.301402802412|
|(6,[0,1],[59.0,26...|  11047.60149569498|
|(6,[0,1],[62.0,39...| 16071.103124938589|
|[18.0,16.0,0.0,0....|-1481.1075870262976|
|[18.0,21.6,0.0,0....| 24028.256108076323|
|[18.0,23.1,0.0,0....|  784.4089010246862|
|[18.0,23.8,0.0,0....| 1007.7696815367526|
|[18.0,25.1,0.0,1....|  1251.389781482043|
|[18.0,26.1,0.0,0....|  1741.669388933551|
+----------