In [1]:
from pyspark.sql import SparkSession

In [2]:
# Start (or reuse) a Spark session registered under the app name 'lr_example'.
spark = (
    SparkSession.builder
    .appName('lr_example')
    .getOrCreate()
)

In [3]:
from pyspark.ml.regression import LinearRegression

In [4]:
# Read the customer CSV: the first line supplies the column names and Spark
# infers each column's type from the data instead of reading all as strings.
data = spark.read.csv(
    'Ecommerce_Customers.csv',
    header=True,
    inferSchema=True,
)

In [5]:
# Print the column names and inferred types of the loaded DataFrame.
data.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)



In [6]:
# Display a tabular preview of the customer data.
data.show()

+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|               Email|             Address|          Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|Yearly Amount Spent|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|mstephenson@ferna...|835 Frank TunnelW...|          Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616|  4.0826206329529615|  587.9510539684005|
|   hduke@hotmail.com|4547 Archer Commo...|       DarkGreen| 31.92627202636016|11.109460728682564|37.268958868297744|    2.66403418213262|  392.2049334443264|
|    pallen@yahoo.com|24645 Valerie Uni...|          Bisque|33.000914755642675|11.330278057777512|37.110597442120856|   4.104543202376424| 487.54750486747207|
|riverarebecca@gma...|1414 David Throug...|   

In [7]:
# head(1) yields a one-element list holding the first Row.
data.head(1)

[Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005)]

In [8]:
# Index into the one-element list from head(1) to get the bare Row object.
data.head(1)[0]

Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005)

In [12]:
# Fetch the first two rows and keep the second; iterating a Row walks its
# field values in column order, so each value is printed on its own line.
second_row = data.head(2)[1]
for field_value in second_row:
    print(field_value)

hduke@hotmail.com
4547 Archer CommonDiazchester, CA 06566-8576
DarkGreen
31.92627202636016
11.109460728682564
37.268958868297744
2.66403418213262
392.2049334443264


In [13]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [14]:
# List every column name in the DataFrame.
data.columns

['Email',
 'Address',
 'Avatar',
 'Avg Session Length',
 'Time on App',
 'Time on Website',
 'Length of Membership',
 'Yearly Amount Spent']

In [15]:
# Columns fed to the model as predictors.
# NOTE(review): 'Avg Session Length' is also a numeric column in the schema
# but is left out here — confirm the exclusion is intentional.
feature_columns = [
    'Time on App',
    'Time on Website',
    'Length of Membership',
]

# Pack the selected columns into a single vector column named 'features'.
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')

In [16]:
# Run the assembler over the data; the result carries all original columns
# plus the assembled 'features' vector column.
output = assembler.transform(data)

In [17]:
# Confirm the schema now includes the vector-typed 'features' column.
output.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)
 |-- features: vector (nullable = true)



In [18]:
# Inspect the first row, including its assembled feature vector.
output.head(1)

[Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005, features=DenseVector([12.6557, 39.5777, 4.0826]))]

In [19]:
# Keep only what the model needs: the feature vector and the
# 'Yearly Amount Spent' column that is used as the regression label.
final_data = output.select('features', 'Yearly Amount Spent')

In [20]:
# Preview the two-column modelling frame (feature vector and label).
final_data.show()

+--------------------+-------------------+
|            features|Yearly Amount Spent|
+--------------------+-------------------+
|[12.6556511491667...|  587.9510539684005|
|[11.1094607286825...|  392.2049334443264|
|[11.3302780577775...| 487.54750486747207|
|[13.7175136651425...|  581.8523440352177|
|[12.7951885510781...|  599.4060920457634|
|[12.0269253397550...|   637.102447915074|
|[11.3663483097105...|  521.5721747578274|
|[12.3519589730029...|  549.9041461052942|
|[13.3862352756764...|  570.2004089636196|
|[11.8141282949721...|  427.1993848953282|
|[13.3389754476621...|  492.6060127179966|
|[11.5847829995352...|  522.3374046069357|
|[10.9612984001540...|  408.6403510726275|
|[12.9592260916093...|  573.4158673313865|
|[13.1487256920565...|  470.4527333009554|
|[12.6366060520001...|  461.7807421962299|
|[11.7338616908573...| 457.84769594494855|
|[12.0131946940144...| 407.70454754954415|
|[14.7153875441565...|  452.3156754800354|
|[13.9895925558252...|   605.061038804892|
+----------

In [21]:
# Split roughly 70% / 30% into training and test sets.
# A fixed seed makes the random assignment repeatable, so the summary
# statistics and evaluation metrics in later cells can be reproduced on a
# fresh kernel run instead of changing every time this cell executes.
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=42)

In [22]:
# Summary statistics (count, mean, stddev, min, max) of the label column
# in the training split.
train_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                341|
|   mean| 498.40531888663855|
| stddev|   81.0241971415452|
|    min| 256.67058229005585|
|    max|  765.5184619388373|
+-------+-------------------+



In [23]:
# Summary statistics (count, mean, stddev, min, max) of the label column
# in the test split.
test_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                159|
|   mean|   501.262926974538|
| stddev|  75.73136974704744|
|    min|   266.086340948469|
|    max|  725.5848140556806|
+-------+-------------------+



In [24]:
# Configure the estimator: predictors come from the 'features' vector column
# (the library default, spelled out here) and the target is
# 'Yearly Amount Spent'.
lr = LinearRegression(
    featuresCol='features',
    labelCol='Yearly Amount Spent',
)

In [25]:
# Fit the linear regression on the training split only.
lr_model = lr.fit(train_data)

In [26]:
# Evaluate the fitted model on the held-out test split; the returned summary
# exposes the residuals and metrics read in the following cells.
test_results = lr_model.evaluate(test_data)

In [27]:
# Show the per-row residuals computed on the test split.
test_results.residuals.show()

+-------------------+
|          residuals|
+-------------------+
|  47.64063280375777|
| 13.106272031798483|
|-28.316936789954013|
| 17.031018659367987|
| 43.589351485832424|
|-20.196706834311613|
|   21.4534554243603|
| 12.138671106160757|
| 43.567632039893454|
| -35.61480556404291|
|  -4.67726999932745|
| 52.832813046834644|
|  -9.27725582639971|
| 15.006372285193606|
| -4.746316078502559|
| 14.637643431420145|
| -34.41949452292954|
| -33.21921405730302|
|  17.35786011505519|
|   9.86815822288736|
+-------------------+
only showing top 20 rows



In [28]:
# Root mean squared error on the test split (same units as the label).
test_results.rootMeanSquaredError

27.875791947385853

In [29]:
# R-squared (coefficient of determination) on the test split.
test_results.r2

0.8636538615736016

In [30]:
# Summary statistics of the label over the full dataset; its mean and stddev
# give a scale to compare the test RMSE against.
final_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                500|
|   mean|  499.3140382585909|
| stddev|   79.3147815497068|
|    min| 256.67058229005585|
|    max|  765.5184619388373|
+-------+-------------------+



In [31]:
# Drop the label column so the frame holds features only, standing in for
# data whose target value is not known.
unlabeled_data = test_data.select('features')

In [32]:
# Preview the features-only frame.
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[8.50815217603260...|
|[9.95399500605983...|
|[10.0473147350711...|
|[10.1631790600525...|
|[10.3201162550591...|
|[10.4805068326944...|
|[10.5345534994610...|
|[10.6279492261562...|
|[10.6517937834741...|
|[10.7191497406283...|
|[10.7321313403036...|
|[10.8898282830050...|
|[10.9025562270197...|
|[10.9713924308754...|
|[10.9725543142924...|
|[10.9828055346628...|
|[11.0104821324000...|
|[11.0523236533066...|
|[11.1094563331684...|
|[11.1213660605094...|
+--------------------+
only showing top 20 rows



In [33]:
# Apply the fitted model to the features-only frame; the result adds a
# 'prediction' column alongside 'features'.
predictions = lr_model.transform(unlabeled_data)

In [34]:
# Preview the predicted yearly spend for each feature vector.
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[8.50815217603260...| 254.5489150058942|
|[9.95399500605983...| 397.4966719174166|
|[10.0473147350711...| 421.3091927039849|
|[10.1631790600525...| 504.2097615764269|
|[10.3201162550591...|341.03722008357295|
|[10.4805068326944...| 499.9286444796161|
|[10.5345534994610...|357.02011102354084|
|[10.6279492261562...| 410.2300655006659|
|[10.6517937834741...| 424.8781051875129|
|[10.7191497406283...|413.94571247084673|
|[10.7321313403036...|482.93939639616326|
|[10.8898282830050...| 617.1543274548683|
|[10.9025562270197...| 533.9152204405045|
|[10.9713924308754...|463.17668742668604|
|[10.9725543142924...|419.68137673032186|
|[10.9828055346628...|437.09021988616666|
|[11.0104821324000...|459.18213003284063|
|[11.0523236533066...| 534.9684473656953|
|[11.1094563331684...| 485.0519251878501|
|[11.1213660605094...|500.67126347674093|
+--------------------+------------