In [1]:
import os

from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession

In [2]:
# Obtain the SparkSession for this notebook under the application name
# 'LinReg'; getOrCreate() hands back an already-running session if one exists.
spark = (
    SparkSession.builder
    .appName('LinReg')
    .getOrCreate()
)

In [3]:
# Base URI of the directory holding the dataset. Setting the
# REGRESSION_DATA_DIR environment variable points the notebook at another
# location; when it is unset or empty the original lab path is used, so the
# default behavior is unchanged.
preset = (
    os.environ.get('REGRESSION_DATA_DIR')
    or 'file:///root/lab/ws/dsml-learning/spark_python_handson/dataset/'
)
# The file name is appended by plain string concatenation below, so the base
# must end with a slash.
if not preset.endswith('/'):
    preset += '/'

# Load the raw text file as an RDD of strings, one element per input line.
lines = spark.sparkContext.textFile(preset + 'regression.txt')

In [4]:
# Turn each comma-separated line into a (label, features) pair:
#   field 0 -> the label as a float
#   field 1 -> a one-element dense feature vector
data = (
    lines
    .map(lambda line: line.split(','))
    .map(lambda fields: (float(fields[0]), Vectors.dense(float(fields[1]))))
)

In [5]:
# Name the two tuple positions so the RDD becomes a DataFrame with a 'label'
# column and a 'features' column -- the column names LinearRegression reads
# by default.
colNames = ['label', 'features']
df = data.toDF(schema=colNames)

In [6]:
# Split the rows roughly 50/50 into a training set and a test set.
# A fixed seed keeps the random assignment repeatable for the same input data
# and partitioning, so a fresh "Restart & Run All" fits the same model and
# prints the same predictions instead of different ones on every run.
trainTest = df.randomSplit([0.5, 0.5], seed=42)
trainingDF, testDF = trainTest

In [7]:
# Configure (but do not yet fit) the linear regression estimator.
lir = LinearRegression(
    maxIter=10,           # upper bound on optimizer iterations
    regParam=0.3,         # overall regularization strength
    elasticNetParam=0.6,  # L1/L2 mix: 0 is pure L2, 1 is pure L1
)

In [8]:
# Fit the configured estimator on the training split only; the test split is
# kept aside for scoring.
model = lir.fit(trainingDF)

In [9]:
# Score the held-out test rows with the fitted model; the result carries a
# 'prediction' column next to the original 'label'. It is cached because the
# cells below read it more than once.
fullPredictions = model.transform(testDF).cache()

In [12]:
# Extract each column of the scored DataFrame as an RDD of bare values:
# select() yields one-field Rows and index 0 unwraps that single field.
predictions = (
    fullPredictions
    .select('prediction')
    .rdd
    .map(lambda row: row[0])
)
labels = (
    fullPredictions
    .select('label')
    .rdd
    .map(lambda row: row[0])
)

In [13]:
# Collect (prediction, label) tuples to the driver.
# Both values are taken from the same row in a single pass, so every
# prediction is paired with its own label by construction. Pairing two
# separately derived RDDs with RDD.zip only works when both have the same
# number of partitions and the same number of elements per partition, and it
# evaluates two lineages instead of one.
predictionAndLabel = (
    fullPredictions
    .select('prediction', 'label')
    .rdd
    .map(lambda row: (row[0], row[1]))
    .collect()
)

In [14]:
# Print every collected (prediction, label) tuple, one per line.
for pair in predictionAndLabel:
    print(pair)

(-1.899133771443206, -2.36)
(-1.5804090506777633, -2.27)
(-1.6238715126003238, -2.26)
(-1.5804090506777633, -2.17)
(-1.3703404847187217, -2.12)
(-1.421046690295042, -2.09)
(-1.4717528958713626, -2.07)
(-1.4572654085638423, -2.0)
(-1.3196342791424014, -1.91)
(-1.4282904339488023, -1.87)
(-1.3268780227961614, -1.8)
(-1.2471968426048008, -1.79)
(-1.1964906370284802, -1.77)
(-1.2182218679897605, -1.75)
(-1.1964906370284802, -1.74)
(-1.1820031497209602, -1.66)
(-1.1675156624134402, -1.59)
(-1.1892468933747202, -1.58)
(-0.9936658147231988, -1.48)
(-1.0516157639532793, -1.47)
(-1.015397045684479, -1.46)
(-1.14578443145216, -1.42)
(-0.9502033528006386, -1.4)
(-0.9502033528006386, -1.39)
(-0.9357158654931185, -1.34)
(-1.0661032512607993, -1.33)
(-0.8198159670329574, -1.3)
(-0.8560346853017577, -1.27)
(-0.8415471979942376, -1.26)
(-0.8705221726092779, -1.26)
(-0.9502033528006386, -1.25)
(-0.8560346853017577, -1.23)
(-0.7980847360716774, -1.2)
(-0.8632784289555178, -1.2)
(-0.892253403570558, -1.1