In [16]:
import os
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator

# Initializing Spark
spark = SparkSession.builder.appName("TaxiML").getOrCreate()

# This is to suppress Spark INFO/WARN logs for cleaner output
spark.sparkContext.setLogLevel("ERROR")

df = spark.read.csv("2019-04.csv", header=True, inferSchema=True)

# Resolve the required columns by position ONCE (0-based indices into the CSV
# header) and reuse the resulting names everywhere below. This keeps the
# selection, feature assembly, model label, display and evaluator in
# agreement even if the header spelling differs from what we expect.
feature_cols = [
    df.columns[3],  # passenger_count (4th col)
    df.columns[7],  # pulocationid (8th col)
    df.columns[8],  # dolocationid (9th col)
]
label_col = df.columns[16]  # total_amount (17th col)

# Create a dataset that only contains passenger_count, pulocationid,
# dolocationid, and total_amount from the 2019-04.csv dataset. Rows with a
# null in any of these columns are dropped, because VectorAssembler's default
# handleInvalid="error" would otherwise fail on them.
selected_df = df.select(*feature_cols, label_col).dropna()

# Show the first 10 entries in the created dataset.
print("First 10 entries of the dataset:")
selected_df.show(10)

# Creating trainDF and testDF. Split sets (80/20). The fixed seed keeps the
# split repeatable, given the same input data and partitioning.
trainDF, testDF = selected_df.randomSplit([0.8, 0.2], seed=42)

# Prepare the features vector from the three predictor columns.
# NOTE(review): pulocationid/dolocationid look like zone identifiers. They are
# fed in as plain numbers, so the tree splits on ID order. Confirm this is
# intended; a categorical encoding may suit them better.
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

# Create a decision tree regressor to predict total_amount from the other
# three features.
dt = DecisionTreeRegressor(
    featuresCol="features",
    labelCol=label_col,  # same column the evaluator scores against below
    maxDepth=10,
    minInstancesPerNode=20,
)

# Create pipeline: assemble the feature vector, then fit the tree.
pipeline = Pipeline(stages=[assembler, dt])

# Train the model on the training split only.
model = pipeline.fit(trainDF)

# Apply the fitted pipeline to the held-out split; this adds a "prediction"
# column next to the original columns.
predictions = model.transform(testDF)

# Show the first 10 predictions along with the three features and the label.
print("First 10 predictions:")
predictions.select(*feature_cols, label_col, "prediction").show(10)

# Evaluate with RMSE on the held-out test split.
evaluator = RegressionEvaluator(
    labelCol=label_col,
    predictionCol="prediction",
    metricName="rmse",
)
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Squared Error (RMSE) on test data = {rmse:.4f}")


                                                                                

First 10 entries of the dataset:
+---------------+------------+------------+------------+
|passenger_count|pulocationid|dolocationid|total_amount|
+---------------+------------+------------+------------+
|            1.0|       239.0|       239.0|         8.8|
|            1.0|       230.0|       100.0|         8.3|
|            1.0|        68.0|       127.0|       47.75|
|            1.0|        68.0|        68.0|         7.3|
|            1.0|        50.0|        42.0|       23.15|
|            1.0|        95.0|       196.0|         9.8|
|            1.0|       211.0|       211.0|         6.8|
|            1.0|       237.0|       162.0|         7.8|
|            1.0|       148.0|        37.0|        20.3|
|            1.0|       265.0|       265.0|        0.31|
+---------------+------------+------------+------------+
only showing top 10 rows


                                                                                

First 10 predictions:


                                                                                

+---------------+------------+------------+------------+------------------+
|passenger_count|pulocationid|dolocationid|total_amount|        prediction|
+---------------+------------+------------+------------+------------------+
|            0.0|         1.0|         1.0|       103.3|32.160820895522384|
|            0.0|         4.0|         4.0|         6.8|32.160820895522384|
|            0.0|         4.0|        33.0|       31.55| 19.49882598124357|
|            0.0|         4.0|        79.0|         7.8|18.576614979520222|
|            0.0|         4.0|       107.0|        11.8| 24.33321598826554|
|            0.0|         4.0|       144.0|        11.3|15.865456332145953|
|            0.0|         4.0|       234.0|        11.0|22.176947630922726|
|            0.0|         7.0|       121.0|        28.8|30.028722115997393|
|            0.0|         7.0|       223.0|         6.8|21.074034556944994|
|            0.0|         7.0|       223.0|         8.3|21.074034556944994|
+-----------

[Stage 78:>                                                         (0 + 8) / 8]

Root Mean Squared Error (RMSE) on test data = 12.5084


                                                                                