### Simple Linear Regression

In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression

In [2]:
# Create the Spark session (or reuse one that already exists) under the
# application name 'Linear Regression'.
spark = (
    SparkSession.builder
    .appName('Linear Regression')
    .getOrCreate()
)

In [6]:
# Load the customer CSV: the first row supplies the column names and the
# column types are inferred from the data rather than read as strings.
ecomm_cust = spark.read.csv(
    './Ecommerce_Customers.csv',
    header=True,
    inferSchema=True,
)

In [7]:
# Check the inferred schema: three string columns (Email, Address, Avatar)
# and five double columns, including the label 'Yearly Amount Spent'.
ecomm_cust.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)



In [11]:
# Preview the first five rows without shortening long cell values.
ecomm_cust.show(n=5, truncate=False)


+-----------------------------+--------------------------------------------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|Email                        |Address                                                 |Avatar          |Avg Session Length|Time on App       |Time on Website   |Length of Membership|Yearly Amount Spent|
+-----------------------------+--------------------------------------------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|mstephenson@fernandez.com    |835 Frank TunnelWrightmouth, MI 82180-9605              |Violet          |34.49726772511229 |12.65565114916675 |39.57766801952616 |4.0826206329529615  |587.9510539684005  |
|hduke@hotmail.com            |4547 Archer CommonDiazchester, CA 06566-8576            |DarkGreen       |31.92627202636016 |11.109460728682564|37.268958868297744|2.66403418213262    |3

In [16]:
# Take the second of the first two rows (index 1) and print each of its
# field values on its own line.
second_row = ecomm_cust.head(2)[1]
for field in second_row:
    print(field)

hduke@hotmail.com
4547 Archer CommonDiazchester, CA 06566-8576
DarkGreen
31.92627202636016
11.109460728682564
37.268958868297744
2.66403418213262
392.2049334443264


In [23]:
from pyspark.ml.linalg import Vector
from pyspark.ml.feature import VectorAssembler

In [25]:
# List the column names so the feature columns can be picked by exact name.
ecomm_cust.columns

['Email',
 'Address',
 'Avatar',
 'Avg Session Length',
 'Time on App',
 'Time on Website',
 'Length of Membership',
 'Yearly Amount Spent']

In [31]:
# Numeric predictor columns to pack into a single vector column. The string
# columns (Email, Address, Avatar) and the label column are not included.
feature_cols = [
    'Avg Session Length',
    'Time on App',
    'Time on Website',
    'Length of Membership',
]
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

# transform() keeps the original columns and appends the new 'features' column.
output = assembler.transform(ecomm_cust)
output.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)
 |-- features: vector (nullable = true)



In [32]:
# Inspect the first assembled row; 'features' holds the four predictor values
# as a DenseVector.
output.head()

Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005, features=DenseVector([34.4973, 12.6557, 39.5777, 4.0826]))

In [36]:
# Keep only what the model needs: the label column and the feature vector.
data = output.select(
    "Yearly Amount Spent",
    "features",
)
data.show(n=5, truncate=False)

+-------------------+----------------------------------------------------------------------------+
|Yearly Amount Spent|features                                                                    |
+-------------------+----------------------------------------------------------------------------+
|587.9510539684005  |[34.49726772511229,12.65565114916675,39.57766801952616,4.0826206329529615]  |
|392.2049334443264  |[31.92627202636016,11.109460728682564,37.268958868297744,2.66403418213262]  |
|487.54750486747207 |[33.000914755642675,11.330278057777512,37.110597442120856,4.104543202376424]|
|581.8523440352177  |[34.30555662975554,13.717513665142507,36.72128267790313,3.120178782748092]  |
|599.4060920457634  |[33.33067252364639,12.795188551078114,37.53665330059473,4.446308318351434]  |
+-------------------+----------------------------------------------------------------------------+
only showing top 5 rows



In [37]:
# Random split weighted 70% / 30% into training and test sets, with an
# explicit seed for the random sampling.
train, test = data.randomSplit(weights=[0.7, 0.3], seed=42)

In [43]:
# Summary statistics (count, mean, stddev, min, max) for the training split;
# the output lists only the 'Yearly Amount Spent' label column.
train.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                374|
|   mean| 502.75221548670083|
| stddev|  77.01978422136051|
|    min| 256.67058229005585|
|    max|  765.5184619388373|
+-------+-------------------+



In [44]:
# Same summary statistics for the test split, to compare its label
# distribution with the training split.
test.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                126|
|   mean| 489.10865505769056|
| stddev|  85.27268740986126|
|    min|  275.9184206503857|
|    max|  712.3963268096637|
+-------+-------------------+



In [50]:
# Configure the regression: predictors come from the 'features' vector column
# and the target is 'Yearly Amount Spent'.
lr = LinearRegression(
    featuresCol='features',
    labelCol='Yearly Amount Spent',
)

# Fit on the training split, then evaluate on the held-out test split.
lr_model = lr.fit(train)
test_result = lr_model.evaluate(test)
print('R2 : ',test_result.r2)

R2 :  0.9875958599044449


In [52]:
# Mean squared error of the model on the test split.
test_result.meanSquaredError

89.48001143304037

In [53]:
# Mean absolute error of the model on the test split.
test_result.meanAbsoluteError

7.560255569514905

In [55]:
# Drop the label from the test split, leaving only the feature vectors.
unlabeled_data = test.select("features")
unlabeled_data.show(n=5)

+--------------------+
|            features|
+--------------------+
|[31.5171218025062...|
|[33.6666156834513...|
|[32.9048536673539...|
|[30.3931845423455...|
|[31.7216523605090...|
+--------------------+
only showing top 5 rows



In [60]:
# Apply the fitted model to the unlabeled rows; the result carries a
# 'prediction' column alongside 'features'.
prediction = lr_model.transform(unlabeled_data)
prediction.show(n=5, truncate=False)

+-----------------------------------------------------------------------------+-----------------+
|features                                                                     |prediction       |
+-----------------------------------------------------------------------------+-----------------+
|[31.51712180250623,10.745188554182882,38.79123468689964,1.4288238768282668]  |280.4980687766465|
|[33.66661568345138,10.9857637851215,36.35250276938114,0.9364975973183264]    |313.4568476590914|
|[32.904853667353976,12.556107616938169,37.805509432449185,0.2699010899842742]|314.5866834999131|
|[30.3931845423455,11.80298577760313,36.315763151803424,2.0838141920346707]   |332.4068022030717|
|[31.721652360509037,11.75502370305383,36.7657223578584,1.8473704233395083]   |350.0240759102519|
+-----------------------------------------------------------------------------+-----------------+
only showing top 5 rows

