# Real-Time Foot Traffic Prediction using Spark Streaming + Kafka + Trained Random Forest Models

This script connects to a live Kafka stream (`mta_turnstile_topic`), parses NYC subway turnstile records, extracts the station and timestamp fields, and applies pre-trained Random Forest regression models to **predict ENTRIES and EXITS** for each station, hour and day-of-week combination in real time.

Key components:
- **Kafka Stream Ingestion** (using `spark-sql-kafka`)
- **Timestamp Feature Extraction** (`hour`, `day_of_week`)
- **Batch Inference** using saved `PipelineModel`s from prior training
- **Streaming Output** printed using `joined_pred.show()` for debugging

This pipeline enables predictive monitoring of NYC subway foot traffic as data flows in.


```python
# Stream Prediction Script (fixed to match ENTRY_COUNT/EXIT_COUNT training pipeline)
from pyspark.sql import SparkSession
from pyspark.sql.functions import split, to_timestamp, hour, col, concat_ws, expr
from pyspark.ml import PipelineModel
```

## Connect to Kafka Topic

```python
# Start Spark session with the Kafka connector
# (the connector's Scala suffix and version should match the installed Spark build)
spark = SparkSession.builder \
    .appName("MTA_Stream_Predict") \
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0") \
    .getOrCreate()
spark.sparkContext.setLogLevel("WARN")
```
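The `spark.jars.packages` coordinate has to agree with the Spark installation: the `_2.12` suffix names the Scala build and the trailing `3.5.0` the Spark version. A small sketch, assuming a hypothetical helper `kafka_connector_coordinate` (not part of PySpark) that only reproduces that naming pattern:

```python
def kafka_connector_coordinate(spark_version: str, scala_version: str = "2.12") -> str:
    """Return the Maven coordinate for the Spark Structured Streaming Kafka connector.

    Hypothetical helper: it only reproduces the naming pattern of the
    coordinate used in the session config, spark-sql-kafka-0-10_<scala>:<spark>.
    """
    for version in (spark_version, scala_version):
        parts = version.split(".")
        # Expect plain numeric versions such as "3.5.0" or "2.12"
        if len(parts) < 2 or not all(part.isdigit() for part in parts):
            raise ValueError(f"unexpected version string: {version!r}")
    return f"org.apache.spark:spark-sql-kafka-0-10_{scala_version}:{spark_version}"


print(kafka_connector_coordinate("3.5.0"))
# org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0
```

In a notebook the first argument could come from `pyspark.__version__`, so the connector tracks the installed PySpark; the Scala suffix still has to match the Scala build of that Spark distribution.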


- Connects to the Kafka broker at `localhost:9092`.
- Subscribes to the topic named `mta_turnstile_topic`.
- Starts from the latest offset, so only messages published after the query starts are read.
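The stream read that follows splits each Kafka message value on commas and takes `STATION`, `DATE` and `TIME` from positions 3, 6 and 7. Those positions match the column order of the MTA turnstile CSV files (C/A, UNIT, SCP, STATION, LINENAME, DIVISION, DATE, TIME, DESC, ENTRIES, EXITS). Assuming the producer forwards rows in that order, a sketch of the message value it would publish is shown here; `TURNSTILE_COLUMNS` and `to_message_value` are hypothetical names, the record values are made-up samples, and actually sending the bytes to `mta_turnstile_topic` is left to whichever Kafka client the producer uses.

```python
# Column order of the MTA turnstile CSV files (assumed to be what the producer sends)
TURNSTILE_COLUMNS = ["C/A", "UNIT", "SCP", "STATION", "LINENAME",
                     "DIVISION", "DATE", "TIME", "DESC", "ENTRIES", "EXITS"]


def to_message_value(record: dict) -> bytes:
    """Serialize one turnstile record as the comma-separated Kafka message value.

    The Spark side casts `value` to STRING and splits on ",", with no CSV
    quoting, so a field containing a comma would shift every later index.
    """
    fields = [str(record[name]) for name in TURNSTILE_COLUMNS]
    if any("," in field for field in fields):
        raise ValueError("fields must not contain commas")
    return ",".join(fields).encode("utf-8")


# Made-up sample record for illustration
value = to_message_value({
    "C/A": "A002", "UNIT": "R051", "SCP": "02-00-00", "STATION": "59 ST",
    "LINENAME": "NQR456W", "DIVISION": "BMT", "DATE": "05/05/2025",
    "TIME": "11:00:00", "DESC": "REGULAR", "ENTRIES": 7000000, "EXITS": 2400000,
})
parts = value.decode("utf-8").split(",")
print(parts[3], parts[6], parts[7])  # 59 ST 05/05/2025 11:00:00
```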

```python
# Step 1: Read the Kafka stream
kafka_df = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "mta_turnstile_topic") \
    .option("startingOffsets", "latest") \
    .load()

# Step 2: Decode the binary Kafka value to a CSV string and split it into fields
turnstile_values = kafka_df.selectExpr("CAST(value AS STRING) as csv")
split_col = split(turnstile_values["csv"], ",")

# Step 3: Extract STATION, DATE and TIME, then derive the model features
station_df = turnstile_values.select(
    split_col.getItem(3).alias("STATION"),
    split_col.getItem(6).alias("DATE"),
    split_col.getItem(7).alias("TIME")
).withColumn(
    "timestamp", to_timestamp(concat_ws(" ", col("DATE"), col("TIME")), "MM/dd/yyyy HH:mm:ss")
).withColumn(
    "hour", hour(col("timestamp"))
).withColumn(
    # DAYOFWEEK: 1 = Sunday ... 7 = Saturday
    "day_of_week", expr("extract(DAYOFWEEK FROM timestamp)")
).select("STATION", "hour", "day_of_week")
```

- Casts the binary Kafka `value` column to a string.
- Splits the CSV string into individual fields.
- Extracts the relevant columns:
    - `STATION` (index 3)
    - `DATE` (index 6)
    - `TIME` (index 7)
- Combines `DATE` and `TIME` into a proper timestamp.
- Derives the `hour` and `day_of_week` features for prediction.
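To check individual messages outside Spark, the same field extraction and feature derivation can be mirrored in plain Python. This is a sketch, not the streaming code: `parse_turnstile_line` is a hypothetical helper that assumes the `MM/dd/yyyy HH:mm:ss` format used above and encodes `day_of_week` the way Spark's `extract(DAYOFWEEK FROM ...)` does (Sunday = 1 through Saturday = 7). Like Spark's `to_timestamp` with ANSI mode off (the default in Spark 3.5), an unparseable date/time yields nulls instead of an error.

```python
from datetime import datetime
from typing import Optional, Tuple


def parse_turnstile_line(line: str) -> Tuple[Optional[str], Optional[int], Optional[int]]:
    """Return (STATION, hour, day_of_week) for one comma-separated turnstile line."""
    fields = line.split(",")
    # Same positions as split_col.getItem(3), getItem(6) and getItem(7)
    station = fields[3] if len(fields) > 3 else None
    date = fields[6] if len(fields) > 6 else None
    time = fields[7] if len(fields) > 7 else None
    try:
        ts = datetime.strptime(f"{date} {time}", "%m/%d/%Y %H:%M:%S")
    except ValueError:
        # Mirror Spark's null result for an unparseable timestamp
        return station, None, None
    # isoweekday(): Monday=1 .. Sunday=7  ->  Spark DAYOFWEEK: Sunday=1 .. Saturday=7
    day_of_week = ts.isoweekday() % 7 + 1
    return station, ts.hour, day_of_week


print(parse_turnstile_line(
    "A002,R051,02-00-00,59 ST,NQR456W,BMT,05/05/2025,11:00:00,REGULAR,0,0"
))  # ('59 ST', 11, 2)
```

A CSV header row that reaches the topic would parse to `('STATION', None, None)`, which is one reason to filter out rows with a null `hour` before scoring.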

## Define Batch Prediction Function

- Loads the pre-trained Random Forest models for predicting ENTRIES and EXITS.
- Applies both models to each micro-batch of incoming data.
- Joins the two predictions into a single DataFrame and prints it to the console.
- Starts the streaming query on the parsed station data, passing each micro-batch to the prediction function through `foreachBatch`.
- Keeps running until manually terminated.
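Joining the two prediction sets on `STATION`, `hour` and `day_of_week` assumes each combination appears at most once per micro-batch. If a batch contains the same combination twice (for example, two turnstiles at one station reporting in the same hour), an inner join on those keys returns 2 × 2 = 4 rows. The plain-Python sketch below contrasts that with joining on a per-row id assigned before both models run; the lists of dicts stand in for Spark DataFrames, `join_on` is a hypothetical helper, and the prediction values are placeholders.

```python
from typing import Dict, List, Sequence

Row = Dict[str, object]


def join_on(left: List[Row], right: List[Row], keys: Sequence[str]) -> List[Row]:
    """Nested-loop inner join of two lists of dicts on the given key columns."""
    out = []
    for left_row in left:
        for right_row in right:
            if all(left_row[k] == right_row[k] for k in keys):
                out.append({**left_row, **right_row})
    return out


# Two input rows with identical features, e.g. two turnstiles at one station
batch = [
    {"row_id": 0, "STATION": "23 ST", "hour": 11, "day_of_week": 2},
    {"row_id": 1, "STATION": "23 ST", "hour": 11, "day_of_week": 2},
]
# Stand-ins for the two model outputs (placeholder prediction values)
entries = [{**row, "predicted_ENTRIES": 1.0} for row in batch]
exits = [{**row, "predicted_EXITS": 2.0} for row in batch]

by_features = join_on(entries, exits, ["STATION", "hour", "day_of_week"])
by_row_id = join_on(entries, exits, ["row_id"])
print(len(by_features), len(by_row_id))  # 4 2
```

In Spark, one way to get the row-for-row behaviour is to add an id column with `pyspark.sql.functions.monotonically_increasing_id` and persist the batch before running both models, so both transforms read the same ids.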

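```python
# Step 4: Define the function applied to each micro-batch
from pyspark.sql.functions import monotonically_increasing_id

# Load the trained models once, on the driver, instead of reloading them for every micro-batch
entries_model = PipelineModel.load("/Users/gopalakrishnaabba/mta_rf_entries_model")
exits_model   = PipelineModel.load("/Users/gopalakrishnaabba/mta_rf_exits_model")

# Feature columns expected by both training pipelines
input_cols = ["STATION", "hour", "day_of_week"]


def predict_batch(batch_df, batch_id):
    # Give every input row an id and cache the batch, so both models score the
    # same rows and their outputs can be joined row-for-row. Joining on the
    # feature columns instead would duplicate rows whenever the same
    # (STATION, hour, day_of_week) occurs more than once in a batch.
    features = batch_df.select(*input_cols) \
        .withColumn("row_id", monotonically_increasing_id()) \
        .persist()

    entries_pred = entries_model.transform(features) \
        .select("row_id", *input_cols, "prediction") \
        .withColumnRenamed("prediction", "predicted_ENTRIES")

    exits_pred = exits_model.transform(features) \
        .select("row_id", "prediction") \
        .withColumnRenamed("prediction", "predicted_EXITS")

    # Combine both predictions into one row per input record
    joined_pred = entries_pred.join(exits_pred, on="row_id", how="inner").drop("row_id")

    joined_pred.show(truncate=False)
    features.unpersist()


# Step 5: Run the streaming query; foreachBatch passes each micro-batch to predict_batch
station_df.writeStream \
    .outputMode("append") \
    .foreachBatch(predict_batch) \
    .start() \
    .awaitTermination()
```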

```text
+-------+----+-----------+-----------------+---------------+
|STATION|hour|day_of_week|predicted_ENTRIES|predicted_EXITS|
+-------+----+-----------+-----------------+---------------+
+-------+----+-----------+-----------------+---------------+

+-------------+----+-----------+-----------------+----------------+
|STATION      |hour|day_of_week|predicted_ENTRIES|predicted_EXITS |
+-------------+----+-----------+-----------------+----------------+
|TIME SQ-42 ST|11  |2          |8.327236983356865|8.85129257750241|
+-------------+----+-----------+-----------------+----------------+

+---------------+----+-----------+-----------------+-----------------+
|STATION        |hour|day_of_week|predicted_ENTRIES|predicted_EXITS  |
+---------------+----+-----------+-----------------+-----------------+
|34 ST-HERALD SQ|11  |2          |8.763883182670865|8.962312789392662|
+---------------+----+-----------+-----------------+-----------------+

+-------+----+-----------+-----------------+----------------+
|STATION|hour|day_of_week|predicted_ENTRIES|predicted_EXITS |
+-------+----+-----------+-----------------+----------------+
|23 ST  |11  |2          |8.439633063456519|8.69550202229375|
+-------+----+-----------+-----------------+----------------+

+---------------+----+-----------+-----------------+-----------------+
|STATION        |hour|day_of_week|predicted_ENTRIES|predicted_EXITS  |
+---------------+----+-----------+-----------------+-----------------+
|WORLD TRADE CTR|11  |2          |8.327743794050125|8.711265964234162|
+---------------+----+-----------+-----------------+-----------------+

+-------+----+-----------+-----------------+----------------+
|STATION|hour|day_of_week|predicted_ENTRIES|predicted_EXITS |
+-------+----+-----------+-----------------+----------------+
|23 ST  |11  |2          |8.439633063456519|8.69550202229375|
+-------+----+-----------+-----------------+----------------+

+---------------+----+-----------+-----------------+-----------------+
|STATION        |hour|day_of_week|predicted_ENTRIES|predicted_EXITS  |
+---------------+----+-----------+-----------------+-----------------+
|34 ST-HERALD SQ|11  |2          |8.763883182670865|8.962312789392662|
+---------------+----+-----------+-----------------+-----------------+
```
