# Model Development

In [2]:
# %pip install tensorflow matplotlib --quiet

In [24]:
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType
from pyspark.ml.feature import VectorAssembler

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.ml.functions import vector_to_array

from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

from pyspark.ml.regression import LinearRegression
from pyspark.ml.classification import GBTClassifier

from pyspark.ml.evaluation import RegressionEvaluator, BinaryClassificationEvaluator

In [4]:
spark = SparkSession.builder \
    .appName("Final Project Advanced Model") \
    .config("spark.jars", "gs://spark-lib/bigquery/spark-bigquery-latest_2.12.jar") \
    .getOrCreate()

In [5]:
final_path = "gs://msca-bdp-student-gcs/Group_1_final_project/curated/final_features" 
events_taxi_df = spark.read.parquet(final_path)

**Get the most busy zone**

In [4]:
# events_taxi_df.groupBy("PULocationID").agg(
#         F.sum(F.col("trip_count")).alias("sum")
#     ).orderBy("sum", ascending=False).show()



+------------+-------+
|PULocationID|    sum|
+------------+-------+
|         161|2455341|
|         237|2280101|
|         236|1835086|
|         186|1782760|
|         162|1742994|
|         230|1707065|
|         138|1666951|
|         163|1472952|
|         142|1439463|
|         170|1410807|
|         234|1343727|
|          68|1322331|
|          48|1200222|
|         239|1157558|
|          43|1132719|
|         164|1112757|
|         249|1033037|
|          79|1006890|
|         141| 945925|
|         132| 932973|
+------------+-------+
only showing top 20 rows



                                                                                

In [6]:
# Compute per-zone 90th percentile of trip_count
zone_q = (
    events_taxi_df
    .groupBy("PULocationID")
    .agg(F.expr("percentile_approx(trip_count, 0.90)").alias("q90"))
)

events_taxi_surge = events_taxi_df.join(zone_q, "PULocationID", "left")

# Binary surge label: 1 if trip_count > zone-specific 90th percentile
events_taxi_surge = events_taxi_surge.withColumn(
    "is_surge",
    (F.col("trip_count") > F.col("q90")).cast("int")
)

### Class Balancing

In [7]:
# df_surge has the column "is_surge" (0 or 1)
class_counts = events_taxi_surge.groupBy("is_surge").count()

class_counts.show()



+--------+-------+
|is_surge|  count|
+--------+-------+
|       1| 221468|
|       0|2310894|
+--------+-------+



                                                                                

In [8]:
counts = class_counts.rdd.collectAsMap()

n_neg = counts.get(0, 0)
n_pos = counts.get(1, 0)

print("Negative (no surge):", n_neg)
print("Positive (surge):   ", n_pos)

                                                                                

Negative (no surge): 2310894
Positive (surge):    221468


In [9]:
pos_weight = n_neg / n_pos
print("Positive class weight =", pos_weight)

Positive class weight = 10.43443748080987


In [10]:
events_taxi_surge = events_taxi_surge.withColumn(
    "class_weight",
    F.when(F.col("is_surge") == 1, F.lit(pos_weight)).otherwise(F.lit(1.0))
)

In [11]:
drop_cols = [
    "borough",
    "event_borough",
    "weather_borough",
    "station_name",
    "season",
    "weather_date",
    "date"
]

events_taxi_surge_final = events_taxi_surge.drop(*[c for c in drop_cols if c in events_taxi_surge.columns])

### Lag Features

In [12]:
w = Window.partitionBy("PULocationID").orderBy("half_hour")

events_taxi_surge_final = (
    events_taxi_surge_final
    .withColumn("lag1_trip_count", F.lag("trip_count", 1).over(w))
    .withColumn("lag2_trip_count", F.lag("trip_count", 2).over(w))
    .withColumn("lag4_trip_count", F.lag("trip_count", 4).over(w))
    .withColumn("rolling_mean_2hr",
        F.avg("trip_count").over(w.rowsBetween(-4, -1))
    )
)

events_taxi_surge_final = events_taxi_surge_final.fillna(
    {
        "lag1_trip_count": 0.0,
        "lag2_trip_count": 0.0,
        "lag4_trip_count": 0.0,
        "rolling_mean_2hr": 0.0
    }
)

In [13]:
# Ensure no null labels
events_taxi_surge_final = (
    events_taxi_surge_final
    .filter(F.col("trip_count").isNotNull())
    .filter(F.col("is_surge").isNotNull())
)

### Train Test Split

In [14]:
# Identify numeric columns to impute (exclude labels)
numeric_impute_cols = [
    c for c, t in events_taxi_surge_final.dtypes
    if t in ("int", "bigint", "double") and c not in ("trip_count", "is_surge")
]

print("Imputing nulls in numeric columns:", len(numeric_impute_cols))

events_taxi_surge_final = events_taxi_surge_final.fillna(0, subset=numeric_impute_cols)

Imputing nulls in numeric columns: 57


In [15]:
# Columns that should NEVER go into features
exclude_cols = {
    "trip_count",      # regression label
    "is_surge",        # classification label
    "PULocationID",    # ID
    "half_hour",       # timestamp key
    "q90",             # threshold used to define is_surge
    "class_weight"     # sample weight, not feature
}

# Build feature_cols ON THE FULL df_ts
feature_cols = [
    c for c, t in events_taxi_surge_final.dtypes
    if t not in ("string", "timestamp") and c not in exclude_cols
]

print("Feature columns:", len(feature_cols))

Feature columns: 54


In [16]:
train_cutoff = "2024-01-01"
val_cutoff   = "2024-03-01"

train_df = events_taxi_surge_final.filter(F.col("half_hour") < train_cutoff)
val_df   = events_taxi_surge_final.filter((F.col("half_hour") >= train_cutoff) & (F.col("half_hour") < val_cutoff))
test_df  = events_taxi_surge_final.filter(F.col("half_hour") >= val_cutoff)

### Feature Vectors

In [17]:
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features",
    handleInvalid="keep"   # extra guardrail
)

### GBT Classifer - Surge Prediction

In [18]:
gbt_cls = GBTClassifier(
    featuresCol="features",
    labelCol="is_surge",
    weightCol="class_weight",   # key for imbalance
    maxDepth=8,
    maxIter=80,
    stepSize=0.05,
    subsamplingRate=0.8
)

gbt_cls_pipeline = Pipeline(stages=[assembler, gbt_cls])
gbt_cls_model = gbt_cls_pipeline.fit(train_df)

25/12/01 00:44:17 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 1005.3 KiB
25/12/01 00:44:17 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 1011.0 KiB
25/12/01 00:44:18 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 1011.5 KiB
25/12/01 00:44:18 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 1012.2 KiB
25/12/01 00:44:18 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 1013.2 KiB
25/12/01 00:44:18 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 1015.5 KiB
25/12/01 00:44:19 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 1020.1 KiB
25/12/01 00:44:19 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 1029.4 KiB
25/12/01 00:44:19 WARN org.apache.spark.scheduler.DAGScheduler: Broadcas

In [19]:
model_path = "gs://msca-bdp-student-gcs/Group_1_final_project/Workspace/Data_pipelines/gbt_surge_model"

# Save the full trained pipeline model
gbt_cls_model.write().overwrite().save(model_path)
print("Model saved at:", model_path)

                                                                                

Model saved at: gs://msca-bdp-student-gcs/Group_1_final_project/Workspace/Data_pipelines/gbt_surge_model


In [20]:
val_preds_cls = gbt_cls_model.transform(val_df)

evaluator_auc = BinaryClassificationEvaluator(
    labelCol="is_surge",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)

val_auc = evaluator_auc.evaluate(val_preds_cls)
print("GBTClassifier – Validation AUC:", val_auc)

25/12/01 00:47:51 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.5 MiB
                                                                                

GBTClassifier – Validation AUC: 0.9560734062661168


In [25]:
val_probs = (
    val_preds_cls
    .withColumn("prob_array", vector_to_array("probability"))
    .select("is_surge", "prob_array")
)
eps = 1e-9

def metrics_for_threshold(th):
    preds_th = val_probs.withColumn(
        "pred_is_surge",
        (F.col("prob_array")[1] > F.lit(th)).cast("int")   # prob of class 1
    )
    
    tp = preds_th.filter("pred_is_surge = 1 AND is_surge = 1").count()
    fp = preds_th.filter("pred_is_surge = 1 AND is_surge = 0").count()
    fn = preds_th.filter("pred_is_surge = 0 AND is_surge = 1").count()
    
    precision = tp / (tp + fp + eps)
    recall    = tp / (tp + fn + eps)
    f1        = 2 * precision * recall / (precision + recall + eps)
    
    return precision, recall, f1

thresholds = [0.2, 0.3, 0.4, 0.5]

for th in thresholds:
    p, r, f1 = metrics_for_threshold(th)
    print(f"th={th:.2f}  precision={p:.3f}  recall={r:.3f}  f1={f1:.3f}")

25/12/01 00:50:58 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:51:11 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:51:26 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
                                                                                

th=0.20  precision=0.248  recall=0.982  f1=0.396


25/12/01 00:51:30 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:51:34 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:51:37 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
                                                                                

th=0.30  precision=0.296  recall=0.965  f1=0.454


25/12/01 00:51:41 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:51:44 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:51:48 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
                                                                                

th=0.40  precision=0.342  recall=0.945  f1=0.502


25/12/01 00:51:52 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:51:55 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:51:59 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB

th=0.50  precision=0.387  recall=0.913  f1=0.544


                                                                                

In [26]:
from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as F

best_th = 0.5  # chosen based on F1

val_with_probs = val_preds_cls.withColumn(
    "prob_array", vector_to_array("probability")
)

val_with_preds = val_with_probs.withColumn(
    "pred_is_surge_gbt",
    (F.col("prob_array")[1] > F.lit(best_th)).cast("int")
)

eps = 1e-9

tp = val_with_preds.filter("pred_is_surge_gbt = 1 AND is_surge = 1").count()
fp = val_with_preds.filter("pred_is_surge_gbt = 1 AND is_surge = 0").count()
fn = val_with_preds.filter("pred_is_surge_gbt = 0 AND is_surge = 1").count()

precision = tp / (tp + fp + eps)
recall    = tp / (tp + fn + eps)
f1        = 2 * precision * recall / (precision + recall + eps)

print("FINAL (validation) – th =", best_th)
print("Precision:", precision)
print("Recall:   ", recall)
print("F1:       ", f1)

25/12/01 00:53:17 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:53:22 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:53:27 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB

FINAL (validation) – th = 0.5
Precision: 0.3872531200775359
Recall:    0.9125562138624518
F1:        0.5437570443063383


                                                                                

In [27]:
# Get test predictions from your saved/loaded model
test_preds_cls = gbt_cls_model.transform(test_df)

test_with_probs = test_preds_cls.withColumn(
    "prob_array", vector_to_array("probability")
)

test_with_preds = test_with_probs.withColumn(
    "pred_is_surge_gbt",
    (F.col("prob_array")[1] > F.lit(best_th)).cast("int")
)

tp_t = test_with_preds.filter("pred_is_surge_gbt = 1 AND is_surge = 1").count()
fp_t = test_with_preds.filter("pred_is_surge_gbt = 1 AND is_surge = 0").count()
fn_t = test_with_preds.filter("pred_is_surge_gbt = 0 AND is_surge = 1").count()

precision_t = tp_t / (tp_t + fp_t + eps)
recall_t    = tp_t / (tp_t + fn_t + eps)
f1_t        = 2 * precision_t * recall_t / (precision_t + recall_t + eps)

print("TEST – th =", best_th)
print("Precision:", precision_t)
print("Recall:   ", recall_t)
print("F1:       ", f1_t)

25/12/01 00:53:54 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:54:14 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/12/01 00:54:23 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 3.4 MiB

TEST – th = 0.5
Precision: 0.4530585524779987
Recall:    0.9102003364428719
F1:        0.6049827365038947


                                                                                

## Model Insights from Test set

**Recall is extremely high (91%)**

The model correctly identifies 9 out of 10 actual surge periods, which is exactly what we’d want in a real ride-hailing system:
- Missing true surges leads to underpricing + lost revenue + unhappy drivers
- Flagging too many false surges leadss to light inefficiencies, but tolerable

So high recall is desirable.

**Precision is very reasonable (45%) in an imbalanced setting**

Given surge is only ~8–10% of the data, a naïve model might have precision below 10%.

45% precision is actually very good:
- When the model predicts “surge”, 45% of the time it’s correct.
- That’s more than 4× better than random guessing at the base rate.

**F1 score (0.605) is strong**

An F1 ≈ 0.60 is quite solid for:
- imbalanced data
- time-series structured
- noisy weather + event + taxi interactions
- only half-hour granularity

### Surge Prediction Model Results

We evaluated the final surge classifier on the held-out test set using the tuned threshold of 0.50, selected during validation based on the F1 score. The model achieved:
- Precision: 0.453
- Recall: 0.910
- F1 Score: 0.605

Despite the underlying class imbalance (surge represents only ~9% of all half-hour × zone observations), the model demonstrates strong performance. The high recall indicates that the classifier successfully identifies the vast majority of actual surge periods (≈91%), which aligns with the operational objective of minimizing missed surges. While precision is lower (≈45%), this is substantially above the surge base rate and reflects acceptable trade-offs: falsely predicting a surge primarily leads to conservative staffing or pricing adjustments, whereas failing to detect a true surge can cause supply shortages and revenue loss.

The F1 score of 0.605 represents a solid balance between precision and recall and shows no evidence of overfitting, improving from validation (0.544) to test (0.605). This suggests that the model generalizes well and captures meaningful signals from weather, events, and lagged taxi activity patterns.