# MLP Model Development

In [85]:
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.ml.functions import vector_to_array

from pyspark.ml import Pipeline
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [53]:
spark = SparkSession.builder \
    .appName("MLP Model") \
    .config("spark.jars", "gs://spark-lib/bigquery/spark-bigquery-latest_2.12.jar") \
    .getOrCreate()

In [54]:
final_path = "gs://msca-bdp-student-gcs/Group_1_final_project/curated/final_features" 
events_taxi_df = spark.read.parquet(final_path)

**Get the most busy zone**

In [55]:
# Compute per-zone 90th percentile of trip_count
zone_q = (
    events_taxi_df
    .groupBy("PULocationID")
    .agg(F.expr("percentile_approx(trip_count, 0.90)").alias("q90"))
)

events_taxi_surge = events_taxi_df.join(zone_q, "PULocationID", "left")

# Binary surge label: 1 if trip_count > zone-specific 90th percentile
events_taxi_surge = events_taxi_surge.withColumn(
    "is_surge",
    (F.col("trip_count") > F.col("q90")).cast("int")
)

### Class Balancing

In [56]:
label_col = "is_surge"

In [57]:
drop_cols = [
    "borough",
    "event_borough",
    "weather_borough",
    "station_name",
    "season",
    "weather_date",
    "date"
]

events_taxi_surge_final = events_taxi_surge.drop(*[c for c in drop_cols if c in events_taxi_surge.columns])

### Lag Features

In [58]:
w = Window.partitionBy("PULocationID").orderBy("half_hour")

events_taxi_surge_final = (
    events_taxi_surge_final
    .withColumn("lag1_trip_count", F.lag("trip_count", 1).over(w))
    .withColumn("lag2_trip_count", F.lag("trip_count", 2).over(w))
    .withColumn("lag4_trip_count", F.lag("trip_count", 4).over(w))
    .withColumn("rolling_mean_2hr",
        F.avg("trip_count").over(w.rowsBetween(-4, -1))
    )
)

events_taxi_surge_final = events_taxi_surge_final.fillna(
    {
        "lag1_trip_count": 0.0,
        "lag2_trip_count": 0.0,
        "lag4_trip_count": 0.0,
        "rolling_mean_2hr": 0.0
    }
)

In [59]:
# Ensure no null labels
events_taxi_surge_final = (
    events_taxi_surge_final
    .filter(F.col("trip_count").isNotNull())
    .filter(F.col("is_surge").isNotNull())
)

In [60]:
# Identify numeric columns to impute (exclude labels)
numeric_impute_cols = [
    c for c, t in events_taxi_surge_final.dtypes
    if t in ("int", "bigint", "double") and c not in ("trip_count", "is_surge")
]

print("Imputing nulls in numeric columns:", len(numeric_impute_cols))

events_taxi_surge_final = events_taxi_surge_final.fillna(0, subset=numeric_impute_cols)

Imputing nulls in numeric columns: 56


### Train, Test Split

In [75]:
train_cutoff = "2024-01-01"
test_cutoff   = "2024-03-01"

train_df = events_taxi_surge_final.filter(F.col("half_hour") < train_cutoff)
test_df  = events_taxi_surge_final.filter(F.col("half_hour") >= test_cutoff)

**Compute Class weights**

In [76]:
label_col = "is_surge"

# Separate surge / non-surge in train
train_surge = train_df.filter(F.col(label_col) == 1)
train_nonsurge = train_df.filter(F.col(label_col) == 0)

n_pos = train_surge.count()
n_neg = train_nonsurge.count()

print("Train counts before balancing:", {"0": n_neg, "1": n_pos})

# e.g. keep about 3x as many non-surge as surge
target_ratio = 3.0
fraction_neg = min(1.0, (target_ratio * n_pos) / float(n_neg))

train_nonsurge_down = train_nonsurge.sample(
    withReplacement=False,
    fraction=fraction_neg,
    seed=42
)

train_balanced = train_surge.union(train_nonsurge_down)

n_pos_b = train_balanced.filter(F.col(label_col) == 1).count()
n_neg_b = train_balanced.filter(F.col(label_col) == 0).count()
print("Train counts after balancing:", {"0": n_neg_b, "1": n_pos_b})

                                                                                

Train counts before balancing: {'0': 1412253, '1': 109374}




Train counts after balancing: {'0': 328489, '1': 109374}


                                                                                

### Feature Selection

In [77]:
events_taxi_surge_final.dtypes

[('PULocationID', 'int'),
 ('half_hour', 'timestamp'),
 ('trip_count', 'bigint'),
 ('total_passengers', 'bigint'),
 ('avg_passenger_count', 'double'),
 ('avg_trip_distance', 'double'),
 ('avg_speed_mph', 'double'),
 ('avg_fare_amount', 'double'),
 ('avg_total_amount', 'double'),
 ('avg_tip_rate', 'double'),
 ('avg_fare_per_mile', 'double'),
 ('sum_total_fees', 'double'),
 ('num_high_ppm_trips', 'bigint'),
 ('num_extreme_speed_trips', 'bigint'),
 ('num_near_station_pickups', 'bigint'),
 ('avg_hour_sin', 'double'),
 ('avg_hour_cos', 'double'),
 ('day_of_week', 'int'),
 ('is_weekend', 'int'),
 ('month', 'int'),
 ('hour_of_day', 'int'),
 ('total_events', 'bigint'),
 ('avg_event_duration_min', 'double'),
 ('total_event_importance', 'double'),
 ('event_start_flag', 'int'),
 ('event_end_flag', 'int'),
 ('miscellaneous', 'bigint'),
 ('lat', 'double'),
 ('lon', 'double'),
 ('elevation_m', 'double'),
 ('temp_avg', 'double'),
 ('dew_point', 'double'),
 ('pressure_sea_level', 'double'),
 ('visibil

In [78]:
categorical_feats = [
    "PULocationID",
    "day_of_week",
    "month",
    "hour_of_day"
]

In [79]:
numeric_feats = [
    # taxi stats (no trip_count)
    "total_passengers",
    "avg_passenger_count",
    "avg_trip_distance",
    "avg_speed_mph",
    "avg_fare_amount",
    "avg_total_amount",
    "avg_tip_rate",
    "avg_fare_per_mile",
    "sum_total_fees",
    "num_high_ppm_trips",
    "num_extreme_speed_trips",
    "num_near_station_pickups",

    # time encodings
    "avg_hour_sin",
    "avg_hour_cos",

    # events
    "total_events",
    "avg_event_duration_min",
    "total_event_importance",
    "miscellaneous",

    # geo
    "lat", "lon", "elevation_m",

    # weather continuous
    "temp_avg", "dew_point", "pressure_sea_level",
    "visibility_mi", "wind_speed_avg", "wind_speed_max",
    "wind_gust", "temp_max", "temp_min",
    "precip_in", "snow_depth_in",
    "temp_range", "wind_ratio", "wind_severity",

    # lags / rolling
    "lag1_trip_count",
    "lag2_trip_count",
    "lag4_trip_count",
    "rolling_mean_2hr"
]

In [80]:
binary_feats = [
    "is_weekend",
    "event_start_flag",
    "event_end_flag",
    "is_precip",
    "is_heavy_rain",
    "is_snow",
    "is_low_visibility",
    "is_windy",
    "is_storm_gust",
    "is_cold_snap",
    "is_heat_wave"
]

### Clean MLP Feature Pipeline

**Index + one-hot encode for categorical features**

In [81]:
# 1. Index + one-hot for categoricals
indexers = [
    StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep")
    for c in categorical_feats
]

encoders = [
    OneHotEncoder(
        inputCols=[f"{c}_idx"],
        outputCols=[f"{c}_oh"]
    )
    for c in categorical_feats
]

# 2. Assemble and scale numeric (continuous + binary flags)
numeric_assembler = VectorAssembler(
    inputCols=numeric_feats + binary_feats,
    outputCol="numeric_vector"
)

numeric_scaler = StandardScaler(
    inputCol="numeric_vector",
    outputCol="numeric_scaled",
    withMean=True,
    withStd=True
)

# 3. Final features vector: scaled numeric + raw one-hot
final_assembler = VectorAssembler(
    inputCols=["numeric_scaled"] + [f"{c}_oh" for c in categorical_feats],
    outputCol="features"
)

feature_stages = indexers + encoders + [numeric_assembler, numeric_scaler, final_assembler]
feature_pipeline = Pipeline(stages=feature_stages)

### Feature Vectors

In [82]:
# Fit on TRAIN ONLY
feature_model = feature_pipeline.fit(train_balanced)

# Transform train and test
train_prepared = feature_model.transform(train_balanced)
test_prepared  = feature_model.transform(test_df)

                                                                                

In [83]:
# Get feature dimension from one row
first_vec = train_prepared.select("features").first()["features"]
input_dim = len(first_vec)
print("Input dimension:", input_dim)

[Stage 151:>                                                        (0 + 1) / 1]

Input dimension: 344


                                                                                

### MLP Classifer - Surge Prediction

In [86]:
layers = [input_dim, 2 * input_dim, 2]   # simple 1-hidden-layer MLP

mlp = MultilayerPerceptronClassifier(
    labelCol=label_col,
    featuresCol="features",
    layers=layers,
    blockSize=256,
    maxIter=80,
    stepSize=0.05,
    seed=42
)

mlp_model = mlp.fit(train_prepared)

25/12/07 02:22:58 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
25/12/07 02:22:58 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS

In [87]:
model_path = "gs://msca-bdp-student-gcs/Group_1_final_project/Workspace/Data_pipelines/mlp_model"

# Save the full trained pipeline model
mlp_model.write().overwrite().save(model_path)
print("Model saved at:", model_path)

25/12/07 02:35:12 WARN org.apache.spark.scheduler.TaskSetManager: Stage 449 contains a task of very large size (1908 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

Model saved at: gs://msca-bdp-student-gcs/Group_1_final_project/Workspace/Data_pipelines/mlp_model


**Evaluate on Test Set**

In [None]:
predictions = mlp_model.transform(test_prepared)

# 6.1 ROC AUC and PR AUC
evaluator_roc = BinaryClassificationEvaluator(
    labelCol=label_col,
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)
evaluator_pr = BinaryClassificationEvaluator(
    labelCol=label_col,
    rawPredictionCol="rawPrediction",
    metricName="areaUnderPR"
)

auc_roc = evaluator_roc.evaluate(predictions)
auc_pr  = evaluator_pr.evaluate(predictions)

print(f"MLP AUC-ROC: {auc_roc:.4f}")
print(f"MLP AUC-PR : {auc_pr:.4f}")

25/12/07 02:35:30 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 2.1 MiB
25/12/07 02:36:01 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 2.1 MiB
                                                                                

MLP AUC-ROC: 0.9606
MLP AUC-PR : 0.8563


In [90]:
threshold = 0.5  # adjust based on your PR curve / business trade-off

# Convert probability vector -> array, then take element 1 (class "1" = surge)
predictions = predictions.withColumn(
    "probability_arr",
    vector_to_array("probability")
).withColumn(
    "prob_surge",
    F.col("probability_arr")[1]
).withColumn(
    "pred_surge_custom",
    F.when(F.col("prob_surge") >= threshold, F.lit(1)).otherwise(F.lit(0))
)

tp = predictions.filter((F.col(label_col) == 1) & (F.col("pred_surge_custom") == 1)).count()
fp = predictions.filter((F.col(label_col) == 0) & (F.col("pred_surge_custom") == 1)).count()
tn = predictions.filter((F.col(label_col) == 0) & (F.col("pred_surge_custom") == 0)).count()
fn = predictions.filter((F.col(label_col) == 1) & (F.col("pred_surge_custom") == 0)).count()

precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall    = tp / (tp + fn) if (tp + fn) > 0 else 0.0
f1        = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

print(f"\nMLP at threshold = {threshold}")
print(f"TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}")
print(f"Precision: {precision:.4f}")
print(f"Recall   : {recall:.4f}")
print(f"F1       : {f1:.4f}")

25/12/07 02:37:34 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 2.0 MiB
25/12/07 02:37:52 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 2.0 MiB
25/12/07 02:38:11 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 2.0 MiB
25/12/07 02:38:28 WARN org.apache.spark.scheduler.DAGScheduler: Broadcasting large task binary with size 2.0 MiB


MLP at threshold = 0.5
TP: 83905, FP: 45353, TN: 708137, FN: 14180
Precision: 0.6491
Recall   : 0.8554
F1       : 0.7381


                                                                                