In [8]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator

# End-to-end script: load the trip CSV, fit a decision-tree regressor that
# predicts total_amount from passenger_count and the pickup/drop-off location
# ids, then report RMSE on a held-out 20% split.

# Start (or reuse) the Spark session
spark = SparkSession.builder \
    .appName("DecisionTreeRegression") \
    .getOrCreate()

# Everything after session creation runs inside try/finally so the session is
# released even when loading, training or evaluation raises.
try:
    # Column names are declared once so the select, the assembler and the
    # model/evaluator cannot drift apart.
    feature_cols = ["passenger_count", "pulocationid", "dolocationid"]
    label_col = "total_amount"

    # Load the dataset
    path = "gs://dataproc-staging-us-central1-762478826489-su9ldjls/2019-01-h1.csv"
    df = spark.read.csv(path, header=True, inferSchema=True)

    # Select the required columns. Rows with a null in any of them are dropped:
    # VectorAssembler defaults to handleInvalid="error" and would abort the
    # whole job on the first null feature value.
    df_selected = df.select(*feature_cols, label_col).na.drop()

    # Show the first 10 entries
    df_selected.show(10)

    # Split into training and testing sets (fixed seed for a repeatable split)
    trainDF, testDF = df_selected.randomSplit([0.8, 0.2], seed=42)

    # Assemble features into a single vector
    assembler = VectorAssembler(
        inputCols=feature_cols,
        outputCol="features"
    )

    # Define the Decision Tree Regressor.
    # maxBins is the number of bins used to discretize continuous features when
    # searching for split thresholds (and must be >= the category count of any
    # categorical feature). No categorical metadata is attached here, so all
    # three inputs are treated as continuous.
    # NOTE(review): pulocationid/dolocationid look like zone identifiers; if
    # they should be treated as categories, index them (e.g. VectorIndexer)
    # before the tree -- confirm intent.
    dt = DecisionTreeRegressor(
        featuresCol="features",
        labelCol=label_col,
        maxBins=1000
    )

    # Create a pipeline
    pipeline = Pipeline(stages=[assembler, dt])

    # Train the model
    model = pipeline.fit(trainDF)

    # Make predictions
    predictions = model.transform(testDF)

    # Show the predictions along with features
    predictions.select(*feature_cols, label_col, "prediction").show(10)

    # Evaluate the model
    evaluator = RegressionEvaluator(
        labelCol=label_col,
        predictionCol="prediction",
        metricName="rmse"
    )

    rmse = evaluator.evaluate(predictions)
    print(f"Root Mean Squared Error (RMSE) on test data = {rmse:.4f}")
finally:
    # Stop Spark session
    spark.stop()


25/04/26 05:42:51 INFO SparkEnv: Registering MapOutputTracker
25/04/26 05:42:51 INFO SparkEnv: Registering BlockManagerMaster
25/04/26 05:42:51 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
25/04/26 05:42:51 INFO SparkEnv: Registering OutputCommitCoordinator
                                                                                

+---------------+------------+------------+------------+
|passenger_count|pulocationid|dolocationid|total_amount|
+---------------+------------+------------+------------+
|            1.0|       151.0|       239.0|        9.95|
|            1.0|       239.0|       246.0|        16.3|
|            3.0|       236.0|       236.0|         5.8|
|            5.0|       193.0|       193.0|        7.55|
|            5.0|       193.0|       193.0|       55.55|
|            5.0|       193.0|       193.0|       13.31|
|            5.0|       193.0|       193.0|       55.55|
|            1.0|       163.0|       229.0|        9.05|
|            1.0|       229.0|         7.0|        18.5|
|            2.0|       141.0|       234.0|        13.0|
+---------------+------------+------------+------------+
only showing top 10 rows



                                                                                

+---------------+------------+------------+------------+------------------+
|passenger_count|pulocationid|dolocationid|total_amount|        prediction|
+---------------+------------+------------+------------+------------------+
|            0.0|         4.0|         4.0|         4.3|17.837203422019126|
|            0.0|         4.0|        33.0|       17.75|17.837203422019126|
|            0.0|         4.0|        68.0|        15.8|17.837203422019126|
|            0.0|         4.0|        79.0|        9.75|17.837203422019126|
|            0.0|         4.0|       125.0|         9.3|17.837203422019126|
|            0.0|         4.0|       170.0|       11.15|17.837203422019126|
|            0.0|         7.0|         7.0|        0.31|17.837203422019126|
|            0.0|         7.0|         7.0|         6.3|17.837203422019126|
|            0.0|         7.0|       112.0|        16.8|17.837203422019126|
|            0.0|         7.0|       138.0|        10.8|17.837203422019126|
+-----------



Root Mean Squared Error (RMSE) on test data = 24.6592



                                                                                