# Spark Linear Regression

In [1]:
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import corr
from pyspark.ml.regression import LinearRegression
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler, StringIndexer

In [2]:
spark = SparkSession.builder.appName('lrex').getOrCreate()

In [3]:
CPATH = "/home/bm/spark/Python-and-Spark-for-Big-Data-master/Spark_for_Machine_Learning/"

### Data

In [4]:
all_data = spark.read.format('libsvm').load(os.path.join(CPATH,'Linear_Regression/sample_linear_regression_data.txt'))
ec = spark.read.csv(os.path.join(CPATH,'Linear_Regression/Ecommerce_Customers.csv'),inferSchema=True,header=True)
csi = spark.read.csv(os.path.join(CPATH,'Linear_Regression/cruise_ship_info.csv'),inferSchema=True,header=True)

### Linear Regression

In [5]:
training = all_data
training.show()

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
| -9.490009878824548|(10,[0,1,2,3,4,5,...|
| 0.2577820163584905|(10,[0,1,2,3,4,5,...|
| -4.438869807456516|(10,[0,1,2,3,4,5,...|
|-19.782762789614537|(10,[0,1,2,3,4,5,...|
| -7.966593841555266|(10,[0,1,2,3,4,5,...|
| -7.896274316726144|(10,[0,1,2,3,4,5,...|
| -8.464803554195287|(10,[0,1,2,3,4,5,...|
| 2.1214592666251364|(10,[0,1,2,3,4,5,...|
| 1.0720117616524107|(10,[0,1,2,3,4,5,...|
|-13.772441561702871|(10,[0,1,2,3,4,5,...|
| -5.082010756207233|(10,[0,1,2,3,4,5,...|
|  7.887786536531237|(10,[0,1,2,3,4,5,...|
| 14.323146365332388|(10,[0,1,2,3,4,5,...|
|-20.057482615789212|(10,[0,1,2,3,4,5,...|
|-0.8995693247765151|(10,[0,1,2,3,4,5,...|
| -19.16829262296376|(10,[0,1,2,3,4,5,...|
|  5.601801561245534|(10,[0,1,2,3,4,5,...|
|-3.2256352187273354|(10,[0,1,2,3,4,5,...|
| 1.5299675726687754|(10,[0,1,2,3,4,5,...|
| -0.250102447941961|(10,[0,1,2,3,4,5,...|
+----------

In [6]:
lr = LinearRegression(featuresCol='features',labelCol='label',predictionCol='prediction')

In [7]:
lrModel = lr.fit(training)

In [8]:
lrModel.coefficients

DenseVector([0.0073, 0.8314, -0.8095, 2.4412, 0.5192, 1.1535, -0.2989, -0.5129, -0.6197, 0.6956])

In [9]:
lrModel.intercept

0.14228558260358093

In [10]:
training_summary = lrModel.summary

In [11]:
training_summary.r2

0.027839179518600154

In [12]:
training_summary.rootMeanSquaredError

10.16309157133015

In [13]:
train_data, test_data = all_data.randomSplit([0.7,0.3])

In [14]:
train_data.describe().show()

+-------+-------------------+
|summary|              label|
+-------+-------------------+
|  count|                327|
|   mean|0.16112484121889056|
| stddev|  10.54650843103701|
|    min|-28.046018037776633|
|    max|  27.78383192005107|
+-------+-------------------+



In [15]:
test_data.describe().show()

+-------+-------------------+
|summary|              label|
+-------+-------------------+
|  count|                174|
|   mean|0.43685906230524896|
| stddev|  9.901284145674824|
|    min|-28.571478869743427|
|    max| 27.111027963108548|
+-------+-------------------+



In [16]:
correct_model = lr.fit(train_data)

In [17]:
test_results = correct_model.evaluate(test_data)

In [18]:
test_results.rootMeanSquaredError

9.800097054076558

In [19]:
test_results.residuals.show()

+-------------------+
|          residuals|
+-------------------+
| -27.74833432435899|
|-20.313572822353883|
| -19.64649043177113|
|-15.685266633909258|
| -16.89003122383911|
|-14.785619493382606|
|-15.779247152551413|
|-14.982784741469787|
|-14.868917615799091|
|-14.757326783207414|
|-12.962866007432646|
| -13.83377980570243|
|-17.779290153536135|
|-13.146578784207103|
|-12.743250871723804|
|-12.037982415304931|
|-12.021134647596236|
|-14.621577217353348|
| -10.50760152299429|
| -9.597646487411263|
+-------------------+
only showing top 20 rows



In [20]:
unlabeled_data = test_data.select('features')

In [21]:
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
+--------------------+
only showing top 20 rows



In [22]:
predictions = correct_model.transform(unlabeled_data)

In [23]:
predictions.show()

+--------------------+--------------------+
|            features|          prediction|
+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...| -0.8231445453844373|
|(10,[0,1,2,3,4,5,...| 0.25609020656467013|
|(10,[0,1,2,3,4,5,...|-0.02082818360058864|
|(10,[0,1,2,3,4,5,...| -2.1183595547552585|
|(10,[0,1,2,3,4,5,...|-0.17536840203690734|
|(10,[0,1,2,3,4,5,...| -2.2408727708269414|
|(10,[0,1,2,3,4,5,...| -0.3721021987256986|
|(10,[0,1,2,3,4,5,...| -0.9687278243247864|
|(10,[0,1,2,3,4,5,...| -0.5684671776321263|
|(10,[0,1,2,3,4,5,...| -0.6185309401048823|
|(10,[0,1,2,3,4,5,...| -2.3481145819836424|
|(10,[0,1,2,3,4,5,...|-0.14235112545027118|
|(10,[0,1,2,3,4,5,...|   4.006848591833263|
|(10,[0,1,2,3,4,5,...|  -0.274015991683654|
|(10,[0,1,2,3,4,5,...| -0.4100847346417269|
|(10,[0,1,2,3,4,5,...| -0.5205933735512572|
|(10,[0,1,2,3,4,5,...| -0.4581455638552611|
|(10,[0,1,2,3,4,5,...|  2.7945042209607758|
|(10,[0,1,2,3,4,5,...| -0.4383181347886419|
|(10,[0,1,2,3,4,5,...| -0.891510

### ECommerce Data

In [24]:
ec.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)



In [25]:
ec.show()

+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|               Email|             Address|          Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|Yearly Amount Spent|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|mstephenson@ferna...|835 Frank TunnelW...|          Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616|  4.0826206329529615|  587.9510539684005|
|   hduke@hotmail.com|4547 Archer Commo...|       DarkGreen| 31.92627202636016|11.109460728682564|37.268958868297744|    2.66403418213262|  392.2049334443264|
|    pallen@yahoo.com|24645 Valerie Uni...|          Bisque|33.000914755642675|11.330278057777512|37.110597442120856|   4.104543202376424| 487.54750486747207|
|riverarebecca@gma...|1414 David Throug...|   

In [26]:
ec.head(1)

[Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005)]

In [27]:
for item in ec.head(1)[0]:
    print(item)

mstephenson@fernandez.com
835 Frank TunnelWrightmouth, MI 82180-9605
Violet
34.49726772511229
12.65565114916675
39.57766801952616
4.0826206329529615
587.9510539684005


In [28]:
ec.columns

['Email',
 'Address',
 'Avatar',
 'Avg Session Length',
 'Time on App',
 'Time on Website',
 'Length of Membership',
 'Yearly Amount Spent']

In [29]:
assembler = VectorAssembler(inputCols=['Avg Session Length','Time on App',
                                       'Time on Website', 'Length of Membership',
                                       'Yearly Amount Spent'],
                            outputCol='features')

In [30]:
output = assembler.transform(ec)

In [31]:
output.select('features').show()

+--------------------+
|            features|
+--------------------+
|[34.4972677251122...|
|[31.9262720263601...|
|[33.0009147556426...|
|[34.3055566297555...|
|[33.3306725236463...|
|[33.8710378793419...|
|[32.0215955013870...|
|[32.7391429383803...|
|[33.9877728956856...|
|[31.9365486184489...|
|[33.9925727749537...|
|[33.8793608248049...|
|[29.5324289670579...|
|[33.1903340437226...|
|[32.3879758531538...|
|[30.7377203726281...|
|[32.1253868972878...|
|[32.3388993230671...|
|[32.1878120459321...|
|[32.6178560628234...|
+--------------------+
only showing top 20 rows



In [32]:
output.head(1)

[Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005, features=DenseVector([34.4973, 12.6557, 39.5777, 4.0826, 587.9511]))]

In [33]:
final_data = output.select('features','Yearly Amount Spent')

In [34]:
final_data.show()

+--------------------+-------------------+
|            features|Yearly Amount Spent|
+--------------------+-------------------+
|[34.4972677251122...|  587.9510539684005|
|[31.9262720263601...|  392.2049334443264|
|[33.0009147556426...| 487.54750486747207|
|[34.3055566297555...|  581.8523440352177|
|[33.3306725236463...|  599.4060920457634|
|[33.8710378793419...|   637.102447915074|
|[32.0215955013870...|  521.5721747578274|
|[32.7391429383803...|  549.9041461052942|
|[33.9877728956856...|  570.2004089636196|
|[31.9365486184489...|  427.1993848953282|
|[33.9925727749537...|  492.6060127179966|
|[33.8793608248049...|  522.3374046069357|
|[29.5324289670579...|  408.6403510726275|
|[33.1903340437226...|  573.4158673313865|
|[32.3879758531538...|  470.4527333009554|
|[30.7377203726281...|  461.7807421962299|
|[32.1253868972878...| 457.84769594494855|
|[32.3388993230671...| 407.70454754954415|
|[32.1878120459321...|  452.3156754800354|
|[32.6178560628234...|   605.061038804892|
+----------

In [35]:
train_data, test_data = final_data.randomSplit([0.7,0.3])

In [36]:
train_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                354|
|   mean| 497.66998606889604|
| stddev|  77.84472800797168|
|    min| 256.67058229005585|
|    max|  765.5184619388373|
+-------+-------------------+



In [37]:
test_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                146|
|   mean| 503.30030178702884|
| stddev|   82.9099905118163|
|    min|  282.4712457199145|
|    max|  725.5848140556806|
+-------+-------------------+



In [38]:
lr = LinearRegression(labelCol='Yearly Amount Spent')

In [39]:
lr_model = lr.fit(train_data)

In [40]:
test_results = lr_model.evaluate(test_data)

In [41]:
test_results.residuals.show()

+--------------------+
|           residuals|
+--------------------+
|-4.54747350886464...|
|4.547473508864641...|
|2.273736754432320...|
|1.136868377216160...|
|-1.70530256582424...|
|8.526512829121202...|
|-7.95807864051312...|
|1.705302565824240...|
|2.273736754432320...|
|-2.84217094304040...|
|1.080024958355352...|
|3.410605131648481...|
|2.273736754432320...|
|1.705302565824240...|
|-3.41060513164848...|
|6.821210263296962...|
|-4.54747350886464...|
|5.684341886080801...|
|5.115907697472721...|
|3.410605131648481...|
+--------------------+
only showing top 20 rows



In [42]:
test_results.rootMeanSquaredError

4.3390232705644116e-13

In [43]:
test_results.r2

1.0

In [44]:
final_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                500|
|   mean|  499.3140382585909|
| stddev|   79.3147815497068|
|    min| 256.67058229005585|
|    max|  765.5184619388373|
+-------+-------------------+



In [45]:
unlabeled_data = test_data.select('features')

In [46]:
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[29.5324289670579...|
|[30.3931845423455...|
|[30.4925366965402...|
|[30.8794843441274...|
|[31.0472221394875...|
|[31.1239743499119...|
|[31.3123495994443...|
|[31.5147378578019...|
|[31.5257524169682...|
|[31.6548096756927...|
|[31.6739155032749...|
|[31.7207699002873...|
|[31.7242025238451...|
|[31.8124825597242...|
|[31.8512531286083...|
|[31.9048571310136...|
|[31.9096268275227...|
|[31.9120759292006...|
|[31.9365486184489...|
|[31.9453957483445...|
+--------------------+
only showing top 20 rows



In [47]:
predictions = lr_model.transform(unlabeled_data)

In [48]:
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[29.5324289670579...|408.64035107262794|
|[30.3931845423455...|319.92886980319315|
|[30.4925366965402...| 282.4712457199143|
|[30.8794843441274...|490.20659998485456|
|[31.0472221394875...|392.49739918902156|
|[31.1239743499119...| 486.9470538397649|
|[31.3123495994443...| 463.5914180279414|
|[31.5147378578019...|489.81248799646124|
|[31.5257524169682...|443.96562680988166|
|[31.6548096756927...| 475.2634237275488|
|[31.6739155032749...| 475.7250679098801|
|[31.7207699002873...| 538.7749334780226|
|[31.7242025238451...|503.38788728796027|
|[31.8124825597242...|392.81034498379705|
|[31.8512531286083...|472.99224666679874|
|[31.9048571310136...|473.94985742281546|
|[31.9096268275227...| 563.4460356732396|
|[31.9120759292006...|387.53471630570766|
|[31.9365486184489...| 427.1993848953277|
|[31.9453957483445...| 657.0199239376516|
+--------------------+------------

### Consulting Project

In [49]:
csi.describe().show()

+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|summary|Ship_name|Cruise_line|               Age|           Tonnage|       passengers|           length|            cabins|passenger_density|             crew|
+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|  count|      158|        158|               158|               158|              158|              158|               158|              158|              158|
|   mean| Infinity|       null|15.689873417721518| 71.28467088607599|18.45740506329114|8.130632911392404| 8.830000000000005|39.90094936708861|7.794177215189873|
| stddev|      NaN|       null| 7.615691058751413|37.229540025907866|9.677094775143416|1.793473548054825|4.4714172221480615| 8.63921711391542|3.503486564627034|
|    min|Adventure|    Azamara|   

In [50]:
csi.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [51]:
for ship in csi.head(5):
    print(f"{ship}\n")

Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55)

Row(Ship_name='Quest', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55)

Row(Ship_name='Celebration', Cruise_line='Carnival', Age=26, Tonnage=47.262, passengers=14.86, length=7.22, cabins=7.43, passenger_density=31.8, crew=6.7)

Row(Ship_name='Conquest', Cruise_line='Carnival', Age=11, Tonnage=110.0, passengers=29.74, length=9.53, cabins=14.88, passenger_density=36.99, crew=19.1)

Row(Ship_name='Destiny', Cruise_line='Carnival', Age=17, Tonnage=101.353, passengers=26.42, length=8.92, cabins=13.21, passenger_density=38.36, crew=10.0)



In [52]:
csi.groupBy("Cruise_line").count().show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



In [53]:
indexer = StringIndexer(inputCol="Cruise_line",outputCol="cruise_cat")
indexed = indexer.fit(csi).transform(csi)
indexed.head(1)

[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_cat=16.0)]

In [54]:
indexed.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'cruise_cat']

In [55]:
assembler = VectorAssembler(inputCols=[ 'Tonnage','passengers','length','cabins','passenger_density','cruise_cat'],
                                      outputCol="features")

In [56]:
output = assembler.transform(indexed)

In [57]:
output.select("features","crew").show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[30.2769999999999...|3.55|
|[30.2769999999999...|3.55|
|[47.262,14.86,7.2...| 6.7|
|[110.0,29.74,9.53...|19.1|
|[101.353,26.42,8....|10.0|
|[70.367,20.52,8.5...| 9.2|
|[70.367,20.52,8.5...| 9.2|
|[70.367,20.56,8.5...| 9.2|
|[70.367,20.52,8.5...| 9.2|
|[110.238999999999...|11.5|
|[110.0,29.74,9.51...|11.6|
|[46.052,14.52,7.2...| 6.6|
|[70.367,20.52,8.5...| 9.2|
|[70.367,20.52,8.5...| 9.2|
|[86.0,21.24,9.63,...| 9.3|
|[110.0,29.74,9.51...|11.6|
|[88.5,21.24,9.63,...|10.3|
|[70.367,20.52,8.5...| 9.2|
|[88.5,21.24,9.63,...| 9.3|
|[70.367,20.52,8.5...| 9.2|
+--------------------+----+
only showing top 20 rows



In [58]:
final_data = output.select("features","crew")

In [59]:
train_data, test_data = final_data.randomSplit([0.7,0.3])

In [60]:
train_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               116|
|   mean| 7.524741379310345|
| stddev|3.4590610591307294|
|    min|              0.59|
|    max|              19.1|
+-------+------------------+



In [61]:
test_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|                42|
|   mean| 8.538333333333336|
| stddev|3.5593140724958037|
|    min|               1.6|
|    max|              21.0|
+-------+------------------+



In [62]:
ship_lr = LinearRegression(labelCol="crew")

In [63]:
trained_ship_model = ship_lr.fit(train_data)

In [64]:
ship_results = trained_ship_model.evaluate(test_data)

In [65]:
ship_results.rootMeanSquaredError

0.7008300787007691

In [66]:
train_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               116|
|   mean| 7.524741379310345|
| stddev|3.4590610591307294|
|    min|              0.59|
|    max|              19.1|
+-------+------------------+



In [67]:
ship_results.r2

0.9602846607972416

In [68]:
ship_results.meanSquaredError

0.4911627992117262

In [69]:
ship_results.meanAbsoluteError

0.5892785215613983

In [70]:
csi.describe().show()

+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|summary|Ship_name|Cruise_line|               Age|           Tonnage|       passengers|           length|            cabins|passenger_density|             crew|
+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|  count|      158|        158|               158|               158|              158|              158|               158|              158|              158|
|   mean| Infinity|       null|15.689873417721518| 71.28467088607599|18.45740506329114|8.130632911392404| 8.830000000000005|39.90094936708861|7.794177215189873|
| stddev|      NaN|       null| 7.615691058751413|37.229540025907866|9.677094775143416|1.793473548054825|4.4714172221480615| 8.63921711391542|3.503486564627034|
|    min|Adventure|    Azamara|   

In [71]:
csi.select(corr("crew","passengers")).show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+



In [72]:
csi.select(corr("crew","cabins")).show()

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+

