In [3]:
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression

In [2]:
spark = SparkSession.builder.master("local").appName("lr_example").getOrCreate()

In [8]:
data = spark.read.csv("Ecommerce-Customers.csv",inferSchema=True,header=True)

In [9]:
data.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)



In [16]:
data.show(n=5)

+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|               Email|             Address|          Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|Yearly Amount Spent|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|mstephenson@ferna...|835 Frank TunnelW...|          Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616|  4.0826206329529615|  587.9510539684005|
|   hduke@hotmail.com|4547 Archer Commo...|       DarkGreen| 31.92627202636016|11.109460728682564|37.268958868297744|    2.66403418213262|  392.2049334443264|
|    pallen@yahoo.com|24645 Valerie Uni...|          Bisque|33.000914755642675|11.330278057777512|37.110597442120856|   4.104543202376424| 487.54750486747207|
|riverarebecca@gma...|1414 David Throug...|   

# Target Variable: Yearly Amount Spent

In [17]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [18]:
data.columns

['Email',
 'Address',
 'Avatar',
 'Avg Session Length',
 'Time on App',
 'Time on Website',
 'Length of Membership',
 'Yearly Amount Spent']

Taking the multiple input columns and creates a single feature output column

In [21]:
assembler = VectorAssembler(inputCols=['Avg Session Length','Time on App','Time on Website','Length of Membership'],outputCol='features')

In [23]:
output = assembler.transform(data)

In [24]:
output.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)
 |-- features: vector (nullable = true)



In [27]:
output.head(1)

[Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005, features=DenseVector([34.4973, 12.6557, 39.5777, 4.0826]))]

In [29]:
final_data = output.select('features','Yearly Amount Spent')

In [30]:
train_data, test_data = final_data.randomSplit([0.7,0.3])

In [32]:
train_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                343|
|   mean| 499.81504301827385|
| stddev|  81.18263527737052|
|    min| 256.67058229005585|
|    max|  765.5184619388373|
+-------+-------------------+



In [33]:
test_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                157|
|   mean| 498.21948645877353|
| stddev|  75.31387872226645|
|    min|  304.1355915788555|
|    max|  725.5848140556806|
+-------+-------------------+



In [34]:
lr = LinearRegression(featuresCol='features',labelCol='Yearly Amount Spent',predictionCol='prediction')

In [35]:
linear_regression_model = lr.fit(train_data)

In [36]:
test_results = linear_regression_model.evaluate(test_data)

In [37]:
test_results.rootMeanSquaredError

10.141234742860156

In [38]:
test_results.residuals.show()

+-------------------+
|          residuals|
+-------------------+
|   9.97190897458222|
|-11.636240287200394|
| 10.414474784258914|
| -4.278582131646033|
| -6.408909399215929|
|-13.198531559521541|
| -8.097009806510869|
|   9.94596637678751|
| 2.8899797051789164|
| 3.5708486449121892|
| -8.023863637019417|
| -18.08813310440496|
|-14.810357668050983|
|-2.4751771245333885|
|-0.9830261075259159|
| -18.09283641651183|
|  -9.60913142216981|
|  8.221044913963908|
| -5.796838945952402|
| -6.424100424530366|
+-------------------+
only showing top 20 rows



In [39]:
test_results.r2

0.9817523603539724

In [40]:
test_results.meanSquaredError

102.8446421097939

In [41]:
test_results.meanAbsoluteError

8.185244277730844