In [1]:
from __future__ import print_function
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler

In [2]:
if __name__ == "__main__":

    # Obtain the SparkSession for this job under the application name
    # "DecisionTree". No extra .config(...) options are applied here;
    # getOrCreate() hands back an already-running session when one exists.
    sessionBuilder = SparkSession.builder.appName("DecisionTree")
    spark = sessionBuilder.getOrCreate()

In [3]:
    # Read the real-estate CSV into a DataFrame: the first row supplies the
    # column names and column types are inferred from the values.
    csvPath = "file:///C:/SparkCourse/Machine Learning/Dataset/realestate.csv"
    csvReader = spark.read.option("header", "true").option("inferSchema", "true")
    data = csvReader.csv(csvPath)

In [4]:
    # Pack the three predictor columns into a single vector column called
    # "features", which is the column the regressor is pointed at later.
    featureColumns = ["HouseAge", "DistanceToMRT", "NumberConvenienceStores"]
    assembler = VectorAssembler()
    assembler.setInputCols(featureColumns)
    assembler.setOutputCol("features")

    # Keep only the label column and the assembled feature vector.
    assembled = assembler.transform(data)
    df = assembled.select("PriceOfUnitArea", "features")

In [5]:
    # Split the data 50/50 into a training set and a test set.
    # A fixed seed is passed so the same rows land in each half on every run;
    # without it the split (and therefore the fitted tree and the printed
    # predictions) changes each time the notebook is executed.
    # NOTE(review): reproducibility also assumes the input DataFrame is read
    # with the same partitioning each run -- confirm for your environment.
    splitSeed = 42
    trainTest = df.randomSplit([0.5, 0.5], seed=splitSeed)
    trainingDF, testDF = trainTest

In [6]:
    # Configure the decision tree regressor: it reads inputs from the
    # "features" vector column and learns to predict "PriceOfUnitArea".
    dtr = DecisionTreeRegressor()
    dtr.setFeaturesCol("features")
    dtr.setLabelCol("PriceOfUnitArea")

In [7]:
    # Fit the regressor on the training half of the data; the result is the
    # trained model used for prediction below.
    model = dtr.fit(dataset=trainingDF)

In [8]:
    # Score the held-out test rows with the trained model. The transformed
    # DataFrame is cached because it is read more than once afterwards.
    scoredTestDF = model.transform(testDF)
    fullPredictions = scoredTestDF.cache()

In [9]:
    # Pull the predicted values and the known labels out as two RDDs of plain
    # values (the first and only field of each single-column row).
    predictionColumn = fullPredictions.select("prediction")
    labelColumn = fullPredictions.select("PriceOfUnitArea")
    predictions = predictionColumn.rdd.map(lambda row: row[0])
    labels = labelColumn.rdd.map(lambda row: row[0])

In [10]:
    # Build the list of (prediction, actual label) tuples on the driver.
    # Both values are taken from the same row of the scored DataFrame rather
    # than zipping two separately derived RDDs: RDD.zip requires both sides to
    # have the same number of partitions and the same number of elements per
    # partition, so pairing within a row is the more robust way to keep each
    # prediction aligned with its label.
    predictionAndLabel = [
        (row["prediction"], row["PriceOfUnitArea"])
        for row in fullPredictions.select("prediction", "PriceOfUnitArea").collect()
    ]

In [11]:
    # Show every (predicted, actual) pair, one tuple per line.
    for pair in predictionAndLabel:
        print(pair)

(37.47837837837837, 7.6)
(13.900000000000002, 11.2)
(25.335294117647063, 11.6)
(25.335294117647063, 12.8)
(18.700000000000006, 12.8)
(25.335294117647063, 12.9)
(17.4, 13.2)
(18.700000000000006, 13.8)
(25.335294117647063, 14.7)
(13.900000000000002, 15.0)
(25.335294117647063, 15.4)
(22.1, 15.5)
(25.335294117647063, 15.6)
(13.900000000000002, 15.6)
(12.200000000000003, 17.4)
(25.335294117647063, 18.2)
(25.335294117647063, 18.3)
(25.335294117647063, 18.8)
(12.200000000000003, 18.8)
(18.700000000000006, 19.0)
(15.8, 19.2)
(13.900000000000002, 19.2)
(25.335294117647063, 20.5)
(18.700000000000006, 20.7)
(30.231818181818177, 21.3)
(25.335294117647063, 21.8)
(23.921428571428574, 22.1)
(30.231818181818177, 22.3)
(23.921428571428574, 22.3)
(12.200000000000003, 22.6)
(18.700000000000006, 22.8)
(30.231818181818177, 23.1)
(23.921428571428574, 23.1)
(25.335294117647063, 23.2)
(30.231818181818177, 23.5)
(25.335294117647063, 23.6)
(25.335294117647063, 23.7)
(30.231818181818177, 23.9)
(25.33529411764706

In [None]:
    # Stop the SparkSession now that all results have been collected and
    # printed; nothing after this point uses `spark`.
    spark.stop()