In [1]:
from pyspark import SparkContext
from pyspark.sql import SparkSession

# Build (or reuse) a single local-mode session through the builder.
# Constructing SparkContext(master='local') directly raises ValueError when
# this cell is re-run in a live kernel, because only one SparkContext may be
# active at a time; getOrCreate() keeps the cell idempotent.
spark = (
    SparkSession.builder
    .master('local')
    .appName('Linear regression example')
    .getOrCreate()
)

# Keep the `sc` name available for code that needs the low-level context.
sc = spark.sparkContext

## Import the data

In [2]:
# Load the advertising data: the first CSV row supplies the column names and
# Spark infers each column's type from the values.
ad_df = (
    spark.read
    .option('header', True)
    .option('inferSchema', True)
    .csv('Advertising.csv')
)
ad_df.show(n=5)

+-----+-----+---------+-----+
|   TV|Radio|Newspaper|Sales|
+-----+-----+---------+-----+
|230.1| 37.8|     69.2| 22.1|
| 44.5| 39.3|     45.1| 10.4|
| 17.2| 45.9|     69.3|  9.3|
|151.5| 41.3|     58.5| 18.5|
|180.8| 10.8|     58.4| 12.9|
+-----+-----+---------+-----+
only showing top 5 rows



## Convert feature columns to vectors

In [4]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors

# Convention used by this notebook: every column except the last is a
# predictor, and the last column is the regression target.
feature_cols = ad_df.columns[:-1]
label_col = ad_df.columns[-1]

# VectorAssembler packs the predictor columns into one vector column while
# staying in the DataFrame API, instead of mapping a Python lambda over the
# underlying RDD and rebuilding a DataFrame from the result.
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')
ad_df2 = (
    assembler.transform(ad_df)
    .withColumnRenamed(label_col, 'label')
    .select('features', 'label')
)
ad_df2.show(5)

  return f(*args, **kwds)
  return f(*args, **kwds)


+-----------------+-----+
|         features|label|
+-----------------+-----+
|[230.1,37.8,69.2]| 22.1|
| [44.5,39.3,45.1]| 10.4|
| [17.2,45.9,69.3]|  9.3|
|[151.5,41.3,58.5]| 18.5|
|[180.8,10.8,58.4]| 12.9|
+-----------------+-----+
only showing top 5 rows



## Build the regression model

In [5]:
from pyspark.ml.regression import LinearRegression

# Linear regression estimator reading the assembled 'features' vector column
# and the 'label' target column; all other parameters are left at their
# defaults.
lr = LinearRegression(
    labelCol='label',
    featuresCol='features',
)

## Fit the model

In [6]:
# Train the estimator on the vectorised DataFrame; the result is the fitted model.
lr_reg = lr.fit(dataset=ad_df2)

## Prediction

In [7]:
# Append a 'prediction' column to the input rows.
# NOTE: ad_df2 is the same DataFrame the model was fitted on, so these are
# in-sample predictions.
lr_pred = lr_reg.transform(dataset=ad_df2)
lr_pred.show(n=5)

+-----------------+-----+------------------+
|         features|label|        prediction|
+-----------------+-----+------------------+
|[230.1,37.8,69.2]| 22.1| 20.52397440971517|
| [44.5,39.3,45.1]| 10.4|12.337854820894362|
| [17.2,45.9,69.3]|  9.3|12.307670779994238|
|[151.5,41.3,58.5]| 18.5| 17.59782951168913|
|[180.8,10.8,58.4]| 12.9|13.188671856831299|
+-----------------+-----+------------------+
only showing top 5 rows



## Model evaluation

In [9]:
from pyspark.ml.evaluation import RegressionEvaluator

# Score the predictions with the coefficient of determination (R^2).
# NOTE: lr_pred was produced from the data the model was fitted on, so this
# is an in-sample score.
evaluator = RegressionEvaluator(
    metricName='r2',
    predictionCol='prediction',
    labelCol='label',
)
evaluator.evaluate(lr_pred)

0.897210638178952