In [31]:
import csv
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import rand

# Train and evaluate a decision-tree regressor that predicts total_amount
# from passenger_count, PULocationID and DOLocationID.

# Create (or reuse) the Spark session for this job.
spark = SparkSession.builder.appName("SparkMLib").getOrCreate()

# Input CSV location; a plain string because nothing is interpolated into it.
path = "gs://dataproc-staging-us-central1-359639680738-vfiiktvj/data/2019-01-h1.csv"
# header=True takes column names from the first row; inferSchema=True asks
# Spark to derive the column types from the data instead of reading strings.
df = spark.read.csv(path, header=True, inferSchema=True)

# Keep only the three feature columns and the label column.
columns = ["passenger_count", "PULocationID", "DOLocationID", "total_amount"]
# Cache the projection: it feeds several separate actions below (show, both
# counts, fit, and the test-side transform/evaluate), and without caching
# each of them would go back to the CSV on GCS.
taxiDF = df.select(columns).cache()
taxiDF.show(10)

# Split the data into training (80%) and test (20%); the fixed seed keeps
# the split reproducible between runs.
trainDF, testDF = taxiDF.randomSplit([0.8, 0.2], seed=42)

print(f"There are {trainDF.count()} rows in training and {testDF.count()} in test")

# Assemble the feature columns into a single vector column named "features".
vecAssembler = VectorAssembler(
    inputCols=["passenger_count", "PULocationID", "DOLocationID"],
    outputCol="features",
)
# NOTE(review): the location ID columns are assembled as plain numeric
# values, so the tree presumably treats them as continuous features;
# maxBins=270 hints that categorical handling may have been intended.
# Confirm, and consider indexing them as categorical if so.
dt = DecisionTreeRegressor(featuresCol="features", labelCol="total_amount", maxBins=270)
pipeline = Pipeline(stages=[vecAssembler, dt])

# Train the model on the training split.
model = pipeline.fit(trainDF)

# Make predictions on the held-out test split.
predDF = model.transform(testDF)
predDF.select("passenger_count", "PULocationID", "DOLocationID", "prediction").show(10)

# Evaluate the predictions against the true total_amount using RMSE.
evaluator = RegressionEvaluator(labelCol="total_amount", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predDF)

print(f"Root Mean Squared Error: {rmse}")


                                                                                

+---------------+------------+------------+------------+
|passenger_count|PULocationID|DOLocationID|total_amount|
+---------------+------------+------------+------------+
|            1.0|       151.0|       239.0|        9.95|
|            1.0|       239.0|       246.0|        16.3|
|            3.0|       236.0|       236.0|         5.8|
|            5.0|       193.0|       193.0|        7.55|
|            5.0|       193.0|       193.0|       55.55|
|            5.0|       193.0|       193.0|       13.31|
|            5.0|       193.0|       193.0|       55.55|
|            1.0|       163.0|       229.0|        9.05|
|            1.0|       229.0|         7.0|        18.5|
|            2.0|       141.0|       234.0|        13.0|
+---------------+------------+------------+------------+
only showing top 10 rows





There are 2920849 rows in training and 730150 in test


                                                                                

+---------------+------------+------------+------------------+
|passenger_count|PULocationID|DOLocationID|        prediction|
+---------------+------------+------------+------------------+
|            0.0|         4.0|         4.0|17.837203422019126|
|            0.0|         4.0|        33.0|17.837203422019126|
|            0.0|         4.0|        68.0|17.837203422019126|
|            0.0|         4.0|        79.0|17.837203422019126|
|            0.0|         4.0|       125.0|17.837203422019126|
|            0.0|         4.0|       170.0|17.837203422019126|
|            0.0|         7.0|         7.0|17.837203422019126|
|            0.0|         7.0|         7.0|17.837203422019126|
|            0.0|         7.0|       112.0|17.837203422019126|
|            0.0|         7.0|       138.0|17.837203422019126|
+---------------+------------+------------+------------------+
only showing top 10 rows





Root Mean Squared Error: 24.659222998956626


                                                                                