In [1]:
from pyspark.ml.regression import LinearRegression

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

In [2]:
spark = SparkSession.builder.config("spark.sql.warehouse.dir","file:///C:/temp").appName("LinearRegression").getOrCreate()

In [3]:
inputLines = spark.sparkContext.textFile("regression.txt")
data = inputLines.map(lambda x: x.split(",")).map(lambda x: (float(x[0]), Vectors.dense(float(x[1]))))

In [4]:
colNames = ["label","features"]
df = data.toDF(colNames)

In [5]:
trainTest = df.randomSplit([0.5,0.5])
trainingDF = trainTest[0]
testDF = trainTest[1]

In [6]:
lir = LinearRegression(maxIter=10, regParam=0.3,elasticNetParam=0.8)

In [7]:
model = lir.fit(trainingDF)


In [8]:
fullPredictions = model.transform(testDF).cache()

In [9]:
predictions = fullPredictions.select("prediction").rdd.map(lambda x: x[0])
labels = fullPredictions.select("label").rdd.map(lambda x: x[0])

In [10]:
predictionsLabel = predictions.zip(labels).collect()

In [11]:
for prediction in predictionsLabel:
    print(prediction)

(-2.633624513936021, -3.74)
(-1.8039741063526527, -2.58)
(-1.6492935218879572, -2.29)
(-1.5367985513681783, -2.27)
(-1.3329014173010794, -2.12)
(-1.3821179669034827, -2.09)
(-1.431334516505886, -2.07)
(-1.3610251599310241, -1.94)
(-1.2836848676986763, -1.91)
(-1.3047776746711348, -1.91)
(-1.3258704816435933, -1.88)
(-1.2133755111238147, -1.79)
(-1.1641589615214114, -1.77)
(-1.023540248371688, -1.67)
(-1.1360352188914669, -1.59)
(-1.1571280258639252, -1.58)
(-1.1782208328363837, -1.53)
(-1.023540248371688, -1.47)
(-0.9883855700842571, -1.46)
(-1.1149424119190083, -1.42)
(-0.9251071491668817, -1.39)
(-0.8758905995644785, -1.37)
(-0.9883855700842571, -1.36)
(-0.7985503073321305, -1.3)
(-0.8266740499620753, -1.3)
(-1.0165093127142018, -1.3)
(-0.8337049856195614, -1.27)
(-0.8477668569345338, -1.26)
(-0.8407359212770477, -1.25)
(-0.9251071491668817, -1.25)
(-0.7633956290446998, -1.24)
(-0.8337049856195614, -1.23)
(-0.7774575003596722, -1.2)
(-0.8407359212770477, -1.2)
(-0.8758905995644785, -

In [None]:
spark.stop()