In [0]:
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler

# Start (or reuse) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName("lr-example")
    .getOrCreate()
)

# Load the customers CSV: the first row is treated as the header and
# column types are inferred from the data.
full_data = spark.read.csv(
    "/FileStore/tables/Ecommerce_Customers.csv",
    header=True,
    inferSchema=True,
)

# Print the inferred schema, then preview the first rows.
full_data.printSchema()
full_data.show()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)

+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|               Email|             Address|          Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|Yearly Amount Spent|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|mstephenson@ferna...|835 Frank TunnelW...|          Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616|  4.0826206329529615|  587.9510539684005|
|

In [0]:
# Display the DataFrame's column names as a Python list (bare expression,
# so the notebook renders the value as the cell output).
full_data.columns

Out[2]: ['Email',
 'Address',
 'Avatar',
 'Avg Session Length',
 'Time on App',
 'Time on Website',
 'Length of Membership',
 'Yearly Amount Spent']

In [0]:
# Pack the four numeric columns into a single vector column named
# 'features', which is the input column the regression cell reads.
assembler = VectorAssembler(
    outputCol="features",
    inputCols=[
        "Avg Session Length",
        "Time on App",
        "Time on Website",
        "Length of Membership",
    ],
)

# Append the assembled 'features' column to the loaded data and preview it.
output = assembler.transform(full_data)
output.show()

+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+--------------------+
|               Email|             Address|          Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|Yearly Amount Spent|            features|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+--------------------+
|mstephenson@ferna...|835 Frank TunnelW...|          Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616|  4.0826206329529615|  587.9510539684005|[34.4972677251122...|
|   hduke@hotmail.com|4547 Archer Commo...|       DarkGreen| 31.92627202636016|11.109460728682564|37.268958868297744|    2.66403418213262|  392.2049334443264|[31.9262720263601...|
|    pallen@yahoo.com|24645 Valerie Uni...|          Bisque|33.000914755642675|11.330278057777512|37

In [0]:
# Keep only the assembled feature vector and the regression label.
final_data = output.select(["features", "Yearly Amount Spent"])
final_data.show()

+--------------------+-------------------+
|            features|Yearly Amount Spent|
+--------------------+-------------------+
|[34.4972677251122...|  587.9510539684005|
|[31.9262720263601...|  392.2049334443264|
|[33.0009147556426...| 487.54750486747207|
|[34.3055566297555...|  581.8523440352177|
|[33.3306725236463...|  599.4060920457634|
|[33.8710378793419...|   637.102447915074|
|[32.0215955013870...|  521.5721747578274|
|[32.7391429383803...|  549.9041461052942|
|[33.9877728956856...|  570.2004089636196|
|[31.9365486184489...|  427.1993848953282|
|[33.9925727749537...|  492.6060127179966|
|[33.8793608248049...|  522.3374046069357|
|[29.5324289670579...|  408.6403510726275|
|[33.1903340437226...|  573.4158673313865|
|[32.3879758531538...|  470.4527333009554|
|[30.7377203726281...|  461.7807421962299|
|[32.1253868972878...| 457.84769594494855|
|[32.3388993230671...| 407.70454754954415|
|[32.1878120459321...|  452.3156754800354|
|[32.6178560628234...|   605.061038804892|
+----------

In [0]:
# Split the rows into roughly 70% training / 30% test data.
# The weights are proportions, not exact counts, so the resulting sizes
# are approximate. A fixed seed is passed so that re-running the notebook
# yields the same split (and therefore comparable metrics below) instead
# of a fresh random partition on every execution.
# NOTE(review): repeatability also assumes the input DataFrame is read and
# partitioned the same way on each run — confirm for this cluster setup.
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=42)

# Summary statistics of the label column in each split, as a sanity check
# that both splits have a similar distribution.
train_data.describe().show()
test_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                352|
|   mean| 498.42183311705395|
| stddev|  81.65203160194238|
|    min| 256.67058229005585|
|    max|  765.5184619388373|
+-------+-------------------+

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                148|
|   mean| 501.43603967629986|
| stddev|  73.68086056322619|
|    min| 308.52774655803336|
|    max|  744.2218671047146|
+-------+-------------------+



In [0]:
# Configure a linear regression that reads the assembled vector column
# and predicts the yearly spend column.
lr = LinearRegression(
    labelCol="Yearly Amount Spent",
    featuresCol="features",
)

# Fit on the training split only.
lr_model = lr.fit(train_data)

# Evaluate on the held-out split; the resulting summary object is reused
# by the metric cells that follow.
test_results = lr_model.evaluate(test_data)

# Show the per-row residuals reported by the evaluation summary.
test_results.residuals.show()

+-------------------+
|          residuals|
+-------------------+
|-1.0191709969624867|
|-22.356992564426832|
|  -8.31763536138692|
|  2.205034052451879|
| 17.926546753017703|
|   3.24349009258151|
| -9.784599754699855|
| -2.756855310148808|
|  16.53748594495943|
|  -5.05283618446316|
| -6.737249477775947|
|-2.8650317390760733|
| -9.915936823037953|
|0.38759356652644783|
|-2.3980813589915897|
| -6.524071500862135|
|-2.3857249905226467|
|  4.858481852922864|
| -5.130325059326765|
|-0.5198103952937458|
+-------------------+
only showing top 20 rows



In [0]:
# Root mean squared error from the evaluation of lr_model on test_data.
test_results.rootMeanSquaredError

Out[7]: 9.541057888909586

In [0]:
# Mean absolute error from the evaluation of lr_model on test_data.
test_results.meanAbsoluteError

Out[8]: 7.476151560189321

In [0]:
# Mean squared error from the evaluation of lr_model on test_data.
test_results.meanSquaredError

Out[9]: 91.03178563952386

In [0]:
# R^2 (coefficient of determination) from the evaluation of lr_model on test_data.
test_results.r2

Out[10]: 0.9831178380090195

In [0]:
# Drop the label by keeping only the feature vector, to mimic scoring
# rows whose yearly spend is not known.
unlabeled_data = test_data.select(["features"])
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[30.5743636841713...|
|[31.1239743499119...|
|[31.1280900496166...|
|[31.3091926408918...|
|[31.3123495994443...|
|[31.3584771924370...|
|[31.5261978982398...|
|[31.5761319713222...|
|[31.6098395733896...|
|[31.6253601348306...|
|[31.7242025238451...|
|[31.8186165667690...|
|[31.8279790554652...|
|[31.8293464559211...|
|[31.8627411090001...|
|[31.8745516945853...|
|[31.9120759292006...|
|[31.9480174211613...|
|[31.9673209478824...|
|[32.0047530203648...|
+--------------------+
only showing top 20 rows



In [0]:
# Score the label-free rows with the fitted model; the result carries the
# input 'features' column plus a 'prediction' column.
predictions = lr_model.transform(unlabeled_data)
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[30.5743636841713...|443.08358475502814|
|[31.1239743499119...| 509.3040464041926|
|[31.1280900496166...| 565.5703221084416|
|[31.3091926408918...|430.51568378748175|
|[31.3123495994443...| 445.6648712749229|
|[31.3584771924370...| 491.9324603568939|
|[31.5261978982398...|418.87912594703766|
|[31.5761319713222...| 543.9834392994771|
|[31.6098395733896...|428.00806370614873|
|[31.6253601348306...|381.38973694138735|
|[31.7242025238451...|510.12513676573644|
|[31.8186165667690...| 449.2837051092117|
|[31.8279790554652...| 449.9186843699795|
|[31.8293464559211...|384.76474442144854|
|[31.8627411090001...| 558.6962225330383|
|[31.8745516945853...|398.80931574712963|
|[31.9120759292006...|389.92044129623036|
|[31.9480174211613...|457.06239503997494|
|[31.9673209478824...|  450.880166298979|
|[32.0047530203648...| 464.2657915159232|
+--------------------+------------