# Model Development Data Prep

Notebook that joins the curated events, weather, and taxi data into the final feature set for surge prediction.

In [49]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
from pyspark.sql.window import Window

In [3]:
# Create (or reuse) the Spark session for this notebook. The jar configured under
# spark.jars is the spark-bigquery connector published in the spark-lib bucket.
spark = (
    SparkSession.builder
    .appName("Final Project Advanced Model")
    .config("spark.jars", "gs://spark-lib/bigquery/spark-bigquery-latest_2.12.jar")
    .getOrCreate()
)

**Data Ingestion**

In [4]:
# Location of the curated inputs: one folder per source under a shared GCS prefix.
bucket = "msca-bdp-student-gcs/Group_1_final_project"
prefix = "curated"

# Build the common root once, then hang each dataset folder off it.
curated_root = f"gs://{bucket}/{prefix}"

weather_path = f"{curated_root}/weather_curated"
events_path = f"{curated_root}/events_curated"
taxi_path = f"{curated_root}/taxi_curated"

In [32]:
# Load the three curated parquet datasets, in the order weather, events, taxi.
weather_features, events_features, taxi_features = (
    spark.read.parquet(source_path)
    for source_path in (weather_path, events_path, taxi_path)
)

**Final Fixes**

In [33]:
def _zone_id_as_int(frame):
    """Return `frame` with its PULocationID column cast to Spark's IntegerType."""
    return frame.withColumn(
        "PULocationID",
        F.col("PULocationID").cast(IntegerType()),
    )


# PULocationID is one of the join keys used later in this notebook, so give it
# the same integer type in all three tables.
taxi_features = _zone_id_as_int(taxi_features)
events_features = _zone_id_as_int(events_features)
weather_features = _zone_id_as_int(weather_features)

In [34]:
# Bring half_hour to a timestamp column in every table, since it is used as a
# join key further down.
#
# NOTE(review): in this notebook's saved output the earliest taxi bucket is
# 2021-12-31 19:00 while events and weather both start at 2022-01-01 00:00, and
# those early taxi rows end up with null weather after the join. Confirm whether
# that is a timezone offset in the taxi data or genuinely out-of-range trips.
half_hour_pattern = "dd/MM/yyyy HH:mm"

# Taxi and events are converted with to_timestamp and the explicit day-first pattern.
taxi_half_hour = F.to_timestamp(F.col("half_hour"), half_hour_pattern)
taxi_features = taxi_features.withColumn("half_hour", taxi_half_hour)

events_half_hour = F.to_timestamp(F.col("half_hour"), half_hour_pattern)
events_features = events_features.withColumn("half_hour", events_half_hour)

# Weather is only cast to timestamp, with no pattern; the original author's
# assumption is that the column already holds the correct time (a date-only
# value would still be cast).
weather_half_hour = F.col("half_hour").cast("timestamp")
weather_features = weather_features.withColumn("half_hour", weather_half_hour)

In [35]:
# Show the three earliest taxi half-hour buckets, untruncated.
earliest_taxi = taxi_features.select("half_hour").orderBy("half_hour")
earliest_taxi.show(n=3, truncate=False)



+-------------------+
|half_hour          |
+-------------------+
|2021-12-31 19:00:00|
|2021-12-31 19:00:00|
|2021-12-31 19:00:00|
+-------------------+
only showing top 3 rows



                                                                                

In [36]:
# Show the three earliest events half-hour buckets, untruncated.
earliest_events = events_features.select("half_hour").orderBy("half_hour")
earliest_events.show(n=3, truncate=False)

+-------------------+
|half_hour          |
+-------------------+
|2022-01-01 00:00:00|
|2022-01-01 00:00:00|
|2022-01-01 00:00:00|
+-------------------+
only showing top 3 rows



                                                                                

In [37]:
# Show the three earliest weather half-hour buckets, untruncated.
earliest_weather = weather_features.select("half_hour").orderBy("half_hour")
earliest_weather.show(n=3, truncate=False)

+-------------------+
|half_hour          |
+-------------------+
|2022-01-01 00:00:00|
|2022-01-01 00:00:00|
|2022-01-01 00:00:00|
+-------------------+
only showing top 3 rows



In [39]:
# Give the weather-side borough its own name (weather_borough) so it can be kept
# for EDA if needed. Only the weather column is renamed in this cell; no events
# column is renamed here.
# NOTE(review): presumably taxi_features also has a `borough` column, in which
# case an un-renamed weather `borough` would be dropped by the overlap clean-up
# in the next cell. Confirm against the taxi schema.
weather_features = weather_features.withColumnRenamed(
    "borough",
    "weather_borough",
)

In [42]:
# Join keys shared by the taxi, events and weather tables.
keys = ["PULocationID", "half_hour"]

# Names already present on the taxi side. When a non-key column exists in taxi
# and in another table, the taxi copy is the one that is kept.
taxi_columns = set(taxi_features.columns)


def _non_key_overlap(frame):
    """List `frame`'s columns that taxi_features also has, excluding the join keys.

    The result preserves the column order of `frame`.
    """
    return [
        name
        for name in frame.columns
        if name not in keys and name in taxi_columns
    ]


# Events: remove every non-key column that would collide with a taxi column.
overlap_events = _non_key_overlap(events_features)
events_clean = events_features.drop(*overlap_events)

# Weather: same treatment.
overlap_weather = _non_key_overlap(weather_features)
weather_clean = weather_features.drop(*overlap_weather)

**Final Join**

In [43]:
# Feature columns are whatever remains in each cleaned table once the join keys
# are set aside.
event_feature_cols = [col for col in events_clean.columns if col not in keys]
weather_feature_cols = [col for col in weather_clean.columns if col not in keys]

# Right-hand sides of the joins: keys first, then that table's features.
events_for_join = events_clean.select(*keys, *event_feature_cols)
weather_for_join = weather_clean.select(*keys, *weather_feature_cols)

# Taxi is the base table. Both joins are left joins on the shared keys, so every
# taxi row is kept and unmatched event/weather features come through as null.
df_final = taxi_features.join(events_for_join, on=keys, how="left")
df_final = df_final.join(weather_for_join, on=keys, how="left")

In [44]:
# Zero-fill the numeric event features. After the left join, taxi rows without a
# matching events row carry nulls in these columns.
#
# DataFrame.dtypes reports Spark's simpleString type names: "tinyint", "smallint",
# "int", "bigint", "float", "double" and "decimal(p,s)". "long" and "short" are
# not names it ever returns, so matching on them would silently skip smallint,
# tinyint and decimal columns.
numeric_dtype_names = ("tinyint", "smallint", "int", "bigint", "float", "double")

event_numeric_cols = [
    name
    for name, dtype in events_clean.dtypes
    if name not in keys
    and (dtype in numeric_dtype_names or dtype.startswith("decimal"))
]

# Restricted to event columns via `subset`, so weather nulls are left untouched.
df_final = df_final.fillna(0, subset=event_numeric_cols)

In [45]:
# NOTE(review): this cell repeats the zero-fill step of the cell directly above.
# Filling already-filled columns with 0 changes nothing, so running both is
# harmless; consider deleting one of them.
#
# DataFrame.dtypes reports Spark's simpleString type names: "tinyint", "smallint",
# "int", "bigint", "float", "double" and "decimal(p,s)". "long" and "short" are
# not names it ever returns, so matching on them would silently skip smallint,
# tinyint and decimal columns.
numeric_dtype_names = ("tinyint", "smallint", "int", "bigint", "float", "double")

event_numeric_cols = [
    name
    for name, dtype in events_clean.dtypes
    if name not in keys
    and (dtype in numeric_dtype_names or dtype.startswith("decimal"))
]

# Restricted to event columns via `subset`, so weather nulls are left untouched.
df_final = df_final.fillna(0, subset=event_numeric_cols)

**Spot Check Data**

In [47]:
# Count rows where either join key is null.
key_is_missing = F.col("PULocationID").isNull() | F.col("half_hour").isNull()
df_final.filter(key_is_missing).count()

                                                                                

0

In [46]:
# Eyeball one pickup zone in time order to see how the taxi, event and weather
# features line up after the join.
sample_zone = 234  # zone picked by the original author as having a lot of trips
spot_check_cols = ["half_hour", "trip_count", "total_events", "temp_avg", "is_precip"]

zone_rows = df_final.filter(F.col("PULocationID") == sample_zone)
zone_rows.select(*spot_check_cols).orderBy("half_hour").show(20, False)

[Stage 14:>                                                         (0 + 8) / 9]

+-------------------+----------+------------+--------+---------+
|half_hour          |trip_count|total_events|temp_avg|is_precip|
+-------------------+----------+------------+--------+---------+
|2021-12-31 19:00:00|5         |0           |null    |null     |
|2021-12-31 19:30:00|7         |0           |null    |null     |
|2021-12-31 20:00:00|4         |0           |null    |null     |
|2021-12-31 20:30:00|2         |0           |null    |null     |
|2021-12-31 21:00:00|1         |0           |null    |null     |
|2022-01-01 12:00:00|2         |6           |52.3    |0        |
|2022-01-01 12:30:00|4         |6           |52.3    |0        |
|2022-01-01 13:00:00|3         |6           |52.3    |0        |
|2022-01-01 13:30:00|3         |6           |52.3    |0        |
|2022-01-01 14:00:00|4         |6           |52.3    |0        |
|2022-01-01 14:30:00|1         |6           |52.3    |0        |
|2022-01-02 08:00:00|3         |3           |53.2    |1        |
|2022-01-02 09:30:00|2   

                                                                                

In [51]:
# Write the joined feature table alongside the curated inputs. The path is built
# from the `bucket` and `prefix` defined in the ingestion config cell instead of
# repeating the literal; it resolves to
# gs://msca-bdp-student-gcs/Group_1_final_project/curated/final_features.
# NOTE(review): the hard-coded path previously carried a "TODO: replace" marker;
# confirm this is the intended destination.
final_path = f"gs://{bucket}/{prefix}/final_features"

# mode("overwrite") replaces any data already present at final_path.
df_final.write.mode("overwrite").parquet(final_path)

                                                                                