In [191]:
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler
from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession
from pyspark.sql.functions import expm1
from pyspark.sql.functions import log1p
from pyspark.sql.functions import col

In [192]:
# Session settings, applied one key at a time to the builder below.
# NOTE(review): spark.driver.memory can generally only take effect if this
# call is what starts the driver JVM — confirm the cell runs on a fresh kernel.
SPARK_CONF = {
    "spark.sql.repl.eagerEval.enabled": True,
    "spark.sql.parquet.cacheMetadata": "true",
    "spark.sql.session.timeZone": "Etc/UTC",
    "spark.driver.memory": "15g",
    "spark.sql.parquet.enableVectorizedReader": "false",
}

_builder = SparkSession.builder.appName("MAST30034 Tutorial 1")
for _key, _value in SPARK_CONF.items():
    _builder = _builder.config(_key, _value)
spark = _builder.getOrCreate()

In [193]:
# Load the combined parquet dataset and list its column names.
data = spark.read.format("parquet").load('../data/combined_data/')
data.columns

['PULocationID',
 'hourly_timestamp',
 'pickup_hour_of_day',
 'pickup_day_of_week',
 'pickup_month',
 'pickup_borough',
 'num_trips',
 'pickup_num_stations',
 'pickup_daytime_routes']

In [194]:
# Regression target on a log scale: log1p(x) = ln(1 + x), defined at 0 trips.
# Scored predictions are mapped back to counts with expm1.
data = data.withColumn("log_num_trips", log1p(col("num_trips")))

In [195]:
# Keep only pickup zone 100; everything below models this single zone.
# NOTE(review): in the displayed predictions, pickup_num_stations and
# pickup_daytime_routes are constant (5 and 15) after this filter, so they
# likely carry no signal for the model — confirm.
data = data.where(data["PULocationID"] == 100)

In [196]:
# Chronological hold-out: months up to and including TRAIN_LAST_MONTH are used
# for training; every later month goes to the test set (the displayed
# predictions only show month 6).
# NOTE(review): pickup_month is a bare month number, so this assumes the data
# covers a single year — TODO confirm.
TRAIN_LAST_MONTH = 5
train_data = data.filter(data['pickup_month'] <= TRAIN_LAST_MONTH)
test_data = data.filter(data['pickup_month'] > TRAIN_LAST_MONTH)

In [197]:
# Model inputs: two calendar fields (indexed + one-hot encoded downstream)
# followed by two numeric columns that are fed to the assembler unchanged.
feature_columns = (
    ['pickup_hour_of_day', 'pickup_day_of_week']
    + ['pickup_num_stations', 'pickup_daytime_routes']
)

In [198]:
# Encode the calendar features as categories rather than ordinal numbers:
# StringIndexer maps each distinct value to an index, and OneHotEncoder turns
# that index into a one-hot vector named "<column>_ohe".
# setHandleInvalid("keep") on both stages means values unseen during fit are
# bucketed instead of raising an error.
# NOTE(review): with "keep" on both stages and OneHotEncoder's default
# dropLast, every seen category may keep its own column; together with the
# model intercept that is collinear, which is a plausible cause of the
# "singular covariance matrix" warning at fit time — confirm.
categorical_columns = ['pickup_hour_of_day', 'pickup_day_of_week']

indexers = [
    StringIndexer(inputCol=column, outputCol=f"{column}_index").setHandleInvalid("keep")
    for column in categorical_columns
]

encoders = [
    OneHotEncoder(inputCol=f"{column}_index", outputCol=f"{column}_ohe").setHandleInvalid("keep")
    for column in categorical_columns
]

In [199]:
# Build the feature vector from `feature_columns`, so that list stays the single
# source of truth. Categorical columns are replaced by their one-hot vector
# ("<column>_ohe"); numeric columns pass through unchanged. The resulting order
# is hour_ohe, day_ohe, stations, routes.
# handleInvalid="keep" keeps rows with invalid (null/NaN) feature values
# instead of failing the transform.
one_hot_columns = {'pickup_hour_of_day', 'pickup_day_of_week'}
assembler = VectorAssembler(
    inputCols=[
        f"{column}_ohe" if column in one_hot_columns else column
        for column in feature_columns
    ],
    outputCol="features",
).setHandleInvalid("keep")

In [200]:
# Scale each assembled feature to unit standard deviation without mean-centring.
# withMean=False / withStd=True are StandardScaler's defaults, spelled out here.
scaler = StandardScaler(
    inputCol="features",
    outputCol="scaled_features",
    withMean=False,
    withStd=True,
)

In [201]:
# Linear regression of the log-transformed trip count on the scaled features.
# No regularisation is configured (the fit-time warning reports regParam is zero).
# Output goes to the default "prediction" column, which is back-transformed after scoring.
lr = LinearRegression(
    labelCol='log_num_trips',
    featuresCol='scaled_features',
    predictionCol='prediction',
)

In [202]:
# Stage order matters: index -> one-hot encode -> assemble -> scale -> regress.
pipeline = Pipeline(stages=[*indexers, *encoders, assembler, scaler, lr])

In [203]:
# Fit every stage (indexers, encoders, scaler, regression) on the training
# months only, so no test-month statistics are learned.
model = pipeline.fit(dataset=train_data)

24/08/21 10:37:29 WARN Instrumentation: [68097103] regParam is zero, which might cause numerical instability and overfitting.
24/08/21 10:37:29 WARN Instrumentation: [68097103] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.


In [204]:
# Score the test months, then undo the log1p target transform:
# expm1(p) = e^p - 1 maps the log-scale prediction back to a trip count.
# NOTE(review): exponentiating a log-scale mean tends to estimate a typical
# (median-like) count and can under-predict the mean — consider a bias correction.
predictions = (
    model.transform(test_data)
    .withColumn("predicted_trip_count", expm1(col("prediction")))
)

In [205]:
# Compare back-transformed predictions with raw counts, so RMSE is in trips.
evaluator = RegressionEvaluator(
    metricName="rmse", labelCol="num_trips", predictionCol="predicted_trip_count")

# A param map overrides metricName for a single call without mutating `evaluator`.
rmse, r2 = (
    evaluator.evaluate(predictions, {evaluator.metricName: metric})
    for metric in ("rmse", "r2")
)

print(f"Root Mean Squared Error (RMSE) on test data = {rmse}")
print(f"R2 on test data = {r2}")

Root Mean Squared Error (RMSE) on test data = 56.5345818534193
R2 on test data = 0.4963869905948973


In [206]:
# Inputs, actual count and back-transformed prediction, side by side, for a few test rows.
preview_columns = [
    'pickup_hour_of_day',
    'pickup_day_of_week',
    'pickup_month',
    'pickup_num_stations',
    'pickup_daytime_routes',
    'num_trips',
    'predicted_trip_count',
]
predictions.select(*preview_columns).show()

+------------------+------------------+------------+-------------------+---------------------+---------+--------------------+
|pickup_hour_of_day|pickup_day_of_week|pickup_month|pickup_num_stations|pickup_daytime_routes|num_trips|predicted_trip_count|
+------------------+------------------+------------+-------------------+---------------------+---------+--------------------+
|                20|                 2|           6|                  5|                   15|      266|  167.15795632577772|
|                 0|                 5|           6|                  5|                   15|      139|  140.59345846088226|
|                 3|                 3|           6|                  5|                   15|       22|   33.08956689445341|
|                 0|                 6|           6|                  5|                   15|      213|    146.032548120454|
|                20|                 7|           6|                  5|                   15|      211|    226.990638