In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType, LongType
from delta import *
import os
import time
# !pip install delta-spark

In [2]:
# Create SparkSession with Delta Lake support
builder = SparkSession.builder.appName("SensorDataWindow") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .config("spark.driver.memory", "5g")

spark = configure_spark_with_delta_pip(builder).getOrCreate()

In [3]:
schema = StructType([
    StructField("medallion", StringType(), True),
    StructField("hack_license", StringType(), True),
    StructField("pickup_datetime", TimestampType(), True),
    StructField("dropoff_datetime", TimestampType(), True),
    StructField("trip_time_in_secs", IntegerType(), True),
    StructField("trip_distance", DoubleType(), True),
    StructField("pickup_longitude", DoubleType(), True),
    StructField("pickup_latitude", DoubleType(), True),
    StructField("dropoff_longitude", DoubleType(), True),
    StructField("dropoff_latitude", DoubleType(), True),
    StructField("payment_type", StringType(), True),
    StructField("fare_amount", DoubleType(), True),
    StructField("surcharge", DoubleType(), True),
    StructField("mta_tax", DoubleType(), True),
    StructField("tip_amount", DoubleType(), True),
    StructField("tolls_amount", DoubleType(), True),
    StructField("total_amount", DoubleType(), True)
])

# Query 0

In [65]:
#original, unmodified dataset
rides_df = (spark.read
            .schema(schema)
            .csv("input/minified_sorted_data.csv") #path
            )

In [66]:
rides_df.printSchema()

root
 |-- medallion: string (nullable = true)
 |-- hack_license: string (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- trip_time_in_secs: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- pickup_longitude: double (nullable = true)
 |-- pickup_latitude: double (nullable = true)
 |-- dropoff_longitude: double (nullable = true)
 |-- dropoff_latitude: double (nullable = true)
 |-- payment_type: string (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- surcharge: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- total_amount: double (nullable = true)



In [67]:
# remove rows with null values
initial_count = rides_df.count()
rides_df = rides_df.dropna()

print(f"Removed {initial_count - rides_df.count()} lines")

Removed 0 lines


In [68]:
# check if medallion and hash license are valid md5, remove the row otherwise

MD5_PATTERN = r"^[a-fA-F0-9]{32}$"

initial_count = rides_df.count()

rides_df = rides_df.filter(
    col("medallion").rlike(MD5_PATTERN) &
    col("hack_license").rlike(MD5_PATTERN)
)

print(f"Removed {initial_count - rides_df.count()} lines")

Removed 0 lines


In [69]:
# remove rows with invalid pickup of dropoff times

initial_count = rides_df.count()

rides_df = rides_df.filter(
    col("dropoff_datetime") > col("pickup_datetime")
)

print(f"Removed {initial_count - rides_df.count()} lines")

Removed 266 lines


In [70]:
# check for illogical numeric values

initial_count = rides_df.count()

rides_df = rides_df.filter(
    (col("trip_time_in_secs") > 0) &
    (col("trip_distance") > 0) &
    (col("fare_amount") >= 0) &
    (col("surcharge") >= 0) &
    (col("mta_tax") >= 0) &
    (col("tip_amount") >= 0) &
    (col("tolls_amount") >= 0)
)

print(f"Removed {initial_count - rides_df.count()} lines")

Removed 319 lines


In [71]:
# remove lines with invalid fare calculation, eg where total_amount does not equal the sum of all fees

initial_count = rides_df.count()

rides_df = rides_df.filter(
    col("total_amount") ==
    (
        col("fare_amount") + 
        col("surcharge") + 
        col("mta_tax") + 
        col("tip_amount") + 
        col("tolls_amount")
    )
)

print(f"Removed {initial_count - rides_df.count()} lines")

Removed 845 lines


In [72]:
rides_df.count()

73570

In [73]:
rides_df_1000 = rides_df.limit(1000)

rides_df_1000.write.csv("input/rides_df_1000", header=True, mode="overwrite")

In [74]:
# Write the cleaned dataset to file
rides_df.write.csv("input/cleaned_minified_data", header=True, mode="overwrite")

# Query 1

In [19]:
# For inspecting the data
rides_df = (spark.read
            .schema(schema)
            .option("header", "true")
            .csv("input/rides_df_1000") #path
            )

In [20]:
rides_df.show(5, False)

+--------------------------------+--------------------------------+-------------------+-------------------+-----------------+-------------+----------------+---------------+-----------------+----------------+------------+-----------+---------+-------+----------+------------+------------+
|medallion                       |hack_license                    |pickup_datetime    |dropoff_datetime   |trip_time_in_secs|trip_distance|pickup_longitude|pickup_latitude|dropoff_longitude|dropoff_latitude|payment_type|fare_amount|surcharge|mta_tax|tip_amount|tolls_amount|total_amount|
+--------------------------------+--------------------------------+-------------------+-------------------+-----------------+-------------+----------------+---------------+-----------------+----------------+------------+-----------+---------+-------+----------+------------+------------+
|5EE2C4D3BF57BDB455E74B03B89E43A7|E96EF8F6E6122591F9465376043B946D|2013-01-01 00:00:09|2013-01-01 00:00:36|26               |0.1        

In [4]:
rides_stream = (spark.readStream
                .schema(schema)
                .option("header", "true")
                .csv("input/cleaned_minified_data") #path
                )

In [5]:
import math

# UDF for converting latitude and longitude to a grid cell ID
# Courtesy of Claude 3.7
def lat_long_to_grid(lat, long):
    # Barryville reference point (center of cell 1.1)
    reference_lat = 41.474937
    reference_long = -74.913585
    
    # Calculate distance from reference point
    # For latitude: 1 degree ~ 111 km (varies slightly with latitude)
    # Moving south means decreasing latitude
    lat_dist_km = (reference_lat - lat) * 111.0  # Distance south in km
    
    # For longitude: 1 degree ~ 111 * cos(latitude) km
    # Moving east means increasing longitude
    long_dist_km = (long - reference_long) * 111.0 * math.cos(math.radians(reference_lat))  # Distance east in km
    
    # Convert to meters
    lat_dist_m = lat_dist_km * 1000
    long_dist_m = long_dist_km * 1000
    
    # Check if outside the grid (more than 150km south or east from reference)
    if lat_dist_m < 0 or lat_dist_m > 150000 or long_dist_m < 0 or long_dist_m > 150000:
        return None
    
    # Calculate cell IDs
    # Cell 1.1 starts at reference point (center of the cell)
    # Each cell is 500m x 500m
    # To get the cell number, divide by 500 and add 1
    
    # For the first component (east direction)
    cell_east = int(long_dist_m / 500) + 1
    
    # For the second component (south direction)
    cell_south = int(lat_dist_m / 500) + 1
    
    # Cell ID as "east.south"
    return f"{cell_east}.{cell_south}"

# Register the UDF with Spark
lat_long_to_grid_udf = udf(lat_long_to_grid, StringType())

## Part 1

In [6]:
most_frequent_routes_query = (rides_stream
    .withColumn("start_cell", lat_long_to_grid_udf(col("pickup_latitude"), col("pickup_longitude")))
    .withColumn("end_cell", lat_long_to_grid_udf(col("dropoff_latitude"), col("dropoff_longitude")))
    .filter(col("start_cell").isNotNull() & col("end_cell").isNotNull()) # The udf returns None for invalid cells
    .groupBy(
        window(col("pickup_datetime"), "30 minutes"),
        col("start_cell"),
        col("end_cell")
    )
    .agg(count("*").alias("Number_of_Rides"))
    .orderBy(col("Number_of_Rides").desc())
    .limit(10)
)

In [7]:
# Function to create table if not exists
def create_table_if_exists(output_path, table_name):
    data_exists = False
    for _i in range(5):  # Retry for 60 seconds
        try:
            time.sleep(1)
            files = os.listdir(output_path)
            for _f in files:
                if ".parquet" in _f:
                    if len(os.listdir(f"{output_path}/_delta_log")) > 0:
                        print("data exists")
                        data_exists = True
                        break
            if data_exists:
                spark.sql(f"CREATE TABLE IF NOT EXISTS {table_name} USING DELTA LOCATION '{output_path}'")
                break
        except Exception as e:
            print(e)  # Uncomment if you want to see exceptions
            pass

In [11]:
checkpoint_path = "output/_checkpoint"
output_path = "output/most_frequent_routes"
os.makedirs(output_path, exist_ok=True)

table_name = "most_frequent_routes"
create_table_if_exists(output_path, table_name)

(most_frequent_routes_query.
    writeStream
    .format("delta")
    .outputMode("complete")
    .option("checkpointLocation", checkpoint_path)
    .queryName("most_frequent_routes")
    .trigger(processingTime="5 seconds")
    .start(output_path)
    .awaitTermination(timeout = 120)
)

data exists


ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/socket.py", line 718, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

In [10]:
# Check for active streams and stop them if they exist
for query in spark.streams.active:
    if query.name.startswith("most_frequent_routes"):
        print(f"Stopping existing query: {query.name}")
        query.stop()

Stopping existing query: most_frequent_routes


In [77]:
df = spark.read.format("delta").load(output_path)
df.toPandas().sort_values(by=['window'], ascending=True).head(10)

Unnamed: 0,window,start_cell,end_cell,Number_of_Rides
2711708,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",155.166,173.161,1
2320583,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",157.163,152.168,1
4136225,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",160.156,157.162,3
3103709,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",162.152,158.149,1
3378809,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",151.166,152.166,2
3103654,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",157.159,156.166,1
338437,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",157.16,162.16,1
2552291,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",151.163,160.147,1
3570852,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",151.163,158.169,2
2320450,"(2013-01-01 00:00:00, 2013-01-01 00:30:00)",160.155,161.153,1


## Part 2

In [101]:
checkpoint_path_2 = "output/_checkpoint_part2"
output_path_2 = "output/most_frequent_routes_part2"
os.makedirs(output_path_2, exist_ok=True)
os.makedirs(checkpoint_path_2, exist_ok=True)

table_name_2 = "most_frequent_routes_part2"
create_table_if_exists(output_path_2, table_name_2)

In [102]:
rides_stream_2 = (spark.readStream
                .schema(schema)
                .option("header", "true")
                .option("maxFilesPerTrigger", 1) #process one file at a time
                .csv("input/cleaned_minified_data"))

# Include the ingestion time
rides_stream_2 = rides_stream_2.withColumn("ingestion_time", current_timestamp())

In [103]:
# Convert pickup and dropoff to grid IDs
rides_stream_2 = (rides_stream_2
                .withColumn("start_cell", lat_long_to_grid_udf(col("pickup_latitude"), col("pickup_longitude")))
                .withColumn("end_cell", lat_long_to_grid_udf(col("dropoff_latitude"), col("dropoff_longitude")))
                .filter(col("start_cell").isNotNull() & col("end_cell").isNotNull())
)

In [104]:
# Add windowing and group together rides on the same route
rides_stream_2 = (rides_stream_2
                .groupBy(
                    window(col("pickup_datetime"), "30 minutes"),
                    col("start_cell"),
                    col("end_cell")
                )
                .agg(
                    count("*").alias("Number_of_Rides"), 
                    min(col("ingestion_time")).alias("ingestion_time")
                )
                .orderBy(col("Number_of_Rides").desc())
)

In [105]:
result_schema = StructType([
    StructField("pickup_datetime", TimestampType(), True),
    StructField("dropoff_datetime", TimestampType(), True),
    StructField("start_cell_id_1", StringType(), True),
    StructField("end_cell_id_1", StringType(), True),
    StructField("start_cell_id_2", StringType(), True),
    StructField("end_cell_id_2", StringType(), True),
    StructField("start_cell_id_3", StringType(), True),
    StructField("end_cell_id_3", StringType(), True),
    StructField("start_cell_id_4", StringType(), True),
    StructField("end_cell_id_4", StringType(), True),
    StructField("start_cell_id_5", StringType(), True),
    StructField("end_cell_id_5", StringType(), True),
    StructField("start_cell_id_6", StringType(), True),
    StructField("end_cell_id_6", StringType(), True),
    StructField("start_cell_id_7", StringType(), True),
    StructField("end_cell_id_7", StringType(), True),
    StructField("start_cell_id_8", StringType(), True),
    StructField("end_cell_id_8", StringType(), True),
    StructField("start_cell_id_9", StringType(), True),
    StructField("end_cell_id_9", StringType(), True),
    StructField("start_cell_id_10", StringType(), True),
    StructField("end_cell_id_10", StringType(), True),
    StructField("delay", LongType(), True)
])

previous_top10 = None

def process_batch(batch_df, batch_id):
    global previous_top10
    
    if batch_df.isEmpty():
        return
    
    # Compare the current batch with the previous one
    current_top10 = batch_df.orderBy(col("Number_of_Rides").desc()).limit(10).collect()

    if previous_top10 is None or current_top10 != previous_top10:
        previous_top10 = current_top10

        window_data = batch_df.select("window").first()
        if window_data:
            # Create a row with explicit types
            result = spark.createDataFrame(
                [(
                    window_data["window"].start,  # pickup_datetime
                    window_data["window"].end,    # dropoff_datetime
                    None, None, None, None, None, None, None, None, None, None,
                    None, None, None, None, None, None, None, None, None, None,
                    0  # placeholder for delay
                )], 
                schema=result_schema  # Use the explicit schema
            )
            
            # Populate the result with actual top 10 routes
            for i, route in enumerate(current_top10):
                result = result.withColumn(f"start_cell_id_{i+1}", lit(route["start_cell"]))
                result = result.withColumn(f"end_cell_id_{i+1}", lit(route["end_cell"]))
            
            # Calculate delay in milliseconds
            min_time_row = batch_df.agg(min("ingestion_time").alias("min_time")).first()
            if min_time_row and min_time_row["min_time"]:
                earliest_processing_time = int(min_time_row["min_time"].timestamp())
                current_time = int(time.time())
                delay_ms = (current_time - earliest_processing_time) * 1000
                result = result.withColumn("delay", lit(delay_ms))
            
            # Write to Delta table
            result.write.format("delta").mode("append").save(output_path_2)

In [106]:
for query in spark.streams.active:
    if query.name.startswith("most_frequent_routes"):
        print(f"Stopping existing query: {query.name}")
        query.stop()

(rides_stream_2
    .writeStream
    .format("delta")
    .foreachBatch(process_batch)
    .outputMode("complete")
    .option("checkpointLocation", checkpoint_path_2)
    .option("mergeSchema", "true")
    .queryName("most_frequent_routes_part2")
    .trigger(processingTime="5 seconds")
    .start(output_path_2)
    .awaitTermination(timeout = 120)
)

ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/socket.py", line 718, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

In [43]:
results_df = spark.read.format("delta").load(output_path_2)

results_df.show(truncate=False)

+-------------------+-------------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+----------------+--------------+------+
|pickup_datetime    |dropoff_datetime   |start_cell_id_1|end_cell_id_1|start_cell_id_2|end_cell_id_2|start_cell_id_3|end_cell_id_3|start_cell_id_4|end_cell_id_4|start_cell_id_5|end_cell_id_5|start_cell_id_6|end_cell_id_6|start_cell_id_7|end_cell_id_7|start_cell_id_8|end_cell_id_8|start_cell_id_9|end_cell_id_9|start_cell_id_10|end_cell_id_10|delay |
+-------------------+-------------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-------------+---------------+-

In [107]:
import time
import pandas as pd

for query in spark.streams.active:
    if query.name.startswith("most_frequent_routes"):
        print(f"Stopping existing query: {query.name}")
        query.stop()

# Start the stream in a separate cell (don't use awaitTermination)
streaming_query = (rides_stream_2
    .writeStream
    .foreachBatch(process_batch)
    .outputMode("complete")
    .option("checkpointLocation", checkpoint_path_2)
    .queryName("most_frequent_routes_part2")
    .trigger(processingTime="5 seconds")
    .start(output_path_2)
)

# Then in another cell, run this monitoring loop:
try:
    # Track what records we've already seen
    seen_records = set()
    
    print("Starting real-time monitoring")
    
    while streaming_query.isActive:
        # Read the latest results
        results = spark.read.format("delta").load(output_path_2)
        
        # Create a unique identifier for each record (combination of timestamps)
        results = results.withColumn(
            "record_id", 
            concat(col("pickup_datetime").cast("string"), col("dropoff_datetime").cast("string"))
        )
        
        # Convert to pandas for easier handling
        results_pd = results.toPandas()
        
        # Check for new records
        for _, row in results_pd.iterrows():
            record_id = row['record_id']
            if record_id not in seen_records:
                # New record found!
                seen_records.add(record_id)
                
                # Print the new record details
                print("\n" + "="*50)
                print(f"NEW UPDATE DETECTED at {time.strftime('%H:%M:%S')}")
                print(f"Window: {row['pickup_datetime']} to {row['dropoff_datetime']}")
                print(f"Processing delay: {row['delay']} ms")
                print("-"*50)
                print("Top 10 Routes:")
                
                for i in range(1, 11):
                    start = row[f'start_cell_id_{i}']
                    end = row[f'end_cell_id_{i}']
                    if pd.notna(start) and pd.notna(end):
                        print(f"  #{i}: {start} → {end}")
                    else:
                        print(f"  #{i}: No data")
                print("="*50)
        
        # Wait 5 seconds before checking again
        time.sleep(5)
        
except KeyboardInterrupt:
    print("\nMonitoring stopped.")
    
finally:
    # Optional: You can stop the query here or let it run
    # streaming_query.stop()
    pass

Stopping existing query: most_frequent_routes_part2
Starting real-time monitoring

NEW UPDATE DETECTED at 17:53:36
Window: 2013-01-08 06:30:00 to 2013-01-08 07:00:00
Processing delay: 21000 ms
--------------------------------------------------
Top 10 Routes:
  #1: 154.160 → 155.159
  #2: 154.160 → 155.159
  #3: 154.160 → 156.160
  #4: 154.160 → 156.160
  #5: 154.160 → 155.159
  #6: 154.160 → 156.160
  #7: 154.160 → 156.159
  #8: 154.160 → 156.159
  #9: 154.160 → 156.160
  #10: 154.160 → 156.160

NEW UPDATE DETECTED at 17:53:36
Window: 2013-01-17 07:30:00 to 2013-01-17 08:00:00
Processing delay: 12000 ms
--------------------------------------------------
Top 10 Routes:
  #1: 154.160 → 156.160
  #2: 154.160 → 155.159
  #3: 154.160 → 156.160
  #4: 154.160 → 155.159
  #5: 154.161 → 156.161
  #6: 154.160 → 156.160
  #7: 159.157 → 157.160
  #8: 154.160 → 155.159
  #9: 155.163 → 156.161
  #10: 154.160 → 156.159


ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/socket.py", line 718, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt



Monitoring stopped.


# Query 2

## Part 1

In [75]:
# RE-RUN QUERY 0 TO RE-INIT THE CLEANED DATASET

# For inspecting the data
rides_df = (spark.read
            .schema(schema)
            .option("header", "true")
            .csv("input/cleaned_minified_data") #path
            )

rides_df.show(5, False)

+--------------------------------+--------------------------------+-------------------+-------------------+-----------------+-------------+----------------+---------------+-----------------+----------------+------------+-----------+---------+-------+----------+------------+------------+
|medallion                       |hack_license                    |pickup_datetime    |dropoff_datetime   |trip_time_in_secs|trip_distance|pickup_longitude|pickup_latitude|dropoff_longitude|dropoff_latitude|payment_type|fare_amount|surcharge|mta_tax|tip_amount|tolls_amount|total_amount|
+--------------------------------+--------------------------------+-------------------+-------------------+-----------------+-------------+----------------+---------------+-----------------+----------------+------------+-----------+---------+-------+----------+------------+------------+
|9C07428094868EDE6CCC840C0332EE34|9C9DB7B440AACF2D056E19B784B0AA3F|2013-01-01 01:28:00|2013-01-01 01:49:00|1260             |3.77       

In [76]:
rides_stream_3 = (spark.readStream
                .schema(schema)
                .option("header", "true")
                .csv("input/cleaned_minified_data") #path
                )

In [97]:
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import time

# Reuse the grid conversion UDF
lat_long_to_grid_udf = udf(lat_long_to_grid, StringType())

# Create output directories
os.makedirs("output/profitable_areas", exist_ok=True)
os.makedirs("output/empty_taxis_temp", exist_ok=True)
os.makedirs("output/_checkpoint_profitable_areas", exist_ok=True)
os.makedirs("output/_checkpoint_empty_taxis", exist_ok=True)

# Init Delta tables
# Create empty taxis table schema
empty_taxis_schema = StructType([
    StructField("empty_window", StructType([
        StructField("start", TimestampType()),
        StructField("end", TimestampType())
    ])),
    StructField("cell_id", StringType()),
    StructField("empty_taxis_count", LongType())
])
# Create empty taxis table
if not os.path.exists("output/empty_taxis_temp/_delta_log"):
    spark.createDataFrame([], empty_taxis_schema) \
         .write.format("delta").save("output/empty_taxis_temp")

# Stream + Calculate profit metrics per cell
profit_per_cell = (
    rides_stream_3
    .withColumn("pickup_cell", lat_long_to_grid_udf(col("pickup_latitude"), col("pickup_longitude")))
    .withColumn("dropoff_cell", lat_long_to_grid_udf(col("dropoff_latitude"), col("dropoff_longitude")))
    .filter(col("pickup_cell").isNotNull() & col("dropoff_cell").isNotNull())
    .withColumn("trip_profit", col("fare_amount") + col("tip_amount"))
    .withWatermark("dropoff_datetime", "30 minutes")
    .groupBy(
        window(col("dropoff_datetime"), "15 minutes").alias("profit_window"),
        col("pickup_cell")
    )
    .agg(
        approx_percentile("trip_profit", 0.5).alias("median_profit"),
        count("*").alias("trips_count")
    )
)

# Calc Empty Taxis
def calculate_empty_taxis(batch_df, batch_id):
    batch_df = (
        batch_df
        .withColumn("pickup_cell", lat_long_to_grid_udf(col("pickup_latitude"), col("pickup_longitude")))
        .withColumn("dropoff_cell", lat_long_to_grid_udf(col("dropoff_latitude"), col("dropoff_longitude")))
        .filter(col("dropoff_cell").isNotNull())
    )
    
    # Create windows for each taxi
    window_spec = Window.partitionBy("medallion", "hack_license").orderBy("dropoff_datetime")
    
    # Find next pickup after each dropoff
    batch_df = (
        batch_df
        .withColumn("next_pickup", lead("pickup_datetime").over(window_spec))
        .withColumn("next_pickup_diff", 
            when(col("next_pickup").isNotNull(), 
                 unix_timestamp(col("next_pickup")) - unix_timestamp(col("dropoff_datetime")))
        .otherwise(lit(None)))
    )
    
    # ID empty taxis (empty at any point in the last 15 minutes)
    empty_taxis = (
        batch_df
        .filter(
            (col("next_pickup").isNull()) |
            ((col("next_pickup_diff") > 900) & (col("next_pickup_diff") <= 1800)) # Empty >= 15 mins <= 30
        )
        .groupBy(
            window(col("dropoff_datetime"), "15 minutes").alias("empty_window"),
            col("dropoff_cell").alias("cell_id")
        )
        .agg(count("*").alias("empty_taxis_count"))
    )
    
    # Keep relevant empty_taxi_data (last 30 minutes)
    empty_taxis = empty_taxis.withWatermark("empty_window", "30 minutes")
    
    # write
    empty_taxis.write.format("delta").mode("append").save("output/empty_taxis_temp")


# Start empty taxis stream
empty_taxis_stream = (
    rides_stream_3
    .writeStream
    .foreachBatch(calculate_empty_taxis)
    .outputMode("update")
    .option("checkpointLocation", "output/_checkpoint_empty_taxis")
    .start()
)

def calculate_profitability(batch_df, batch_id):
    try:
        empty_taxis = spark.read.format("delta").load("output/empty_taxis_temp")
        
        profitability = (
            batch_df
            .join(
                empty_taxis,
                (col("profit_window") == col("empty_window")) &
                (col("pickup_cell") == col("cell_id")),
                "inner"
            )
            .withColumn("profitability", 
                col("median_profit") / greatest(col("empty_taxis_count"), lit(1))
            )
            .select(
                col("profit_window.start").alias("analysis_window_start"),
                col("profit_window.end").alias("analysis_window_end"),
                col("pickup_cell").alias("profitable_cell_id"),
                col("empty_taxis_count").alias("empty_taxis_in_cell"),
                col("median_profit").alias("median_profit_in_cell"),
                col("profitability").alias("profitability_of_cell"),
                col("trips_count").alias("trips_count")
            )
            .orderBy(col("profitability_of_cell").desc())
            .limit(10)
        )
        
        profitability.write.format("delta").mode("overwrite").save("output/profitable_areas")
    except Exception as e:
        print(f"Error processing batch {batch_id}: {str(e)}")
        raise e


# Final Output Stream
profitability_stream = (
    profit_per_cell
    .writeStream
    .foreachBatch(calculate_profitability)
    .outputMode("complete")
    .option("checkpointLocation", "output/_checkpoint_profitable_areas")
    .trigger(processingTime="15 minutes")
    .start()
)

# Handle stream interruption gracefully
try:
    profitability_stream.awaitTermination(300)
except KeyboardInterrupt:
    print("Stopping streams gracefully...")
    empty_taxis_stream.stop()
    profitability_stream.stop()
    time.sleep(5)


In [98]:
profitability_df = spark.read.format("delta").load("output/profitable_areas")
display(profitability_df.orderBy(col("analysis_window_start").desc()))

DataFrame[analysis_window_start: timestamp, analysis_window_end: timestamp, profitable_cell_id: string, empty_taxis_in_cell: bigint, median_profit_in_cell: double, profitability_of_cell: double, trips_count: bigint]

In [99]:
profitability_pd = profitability_df.orderBy(col("profitability_of_cell").desc()).toPandas()
display(profitability_pd)

Unnamed: 0,analysis_window_start,analysis_window_end,profitable_cell_id,empty_taxis_in_cell,median_profit_in_cell,profitability_of_cell,trips_count
0,2013-01-01 00:30:00,2013-01-01 00:45:00,188.126,1,100.4,100.4,4
1,2013-01-01 00:30:00,2013-01-01 00:45:00,188.126,1,100.4,100.4,4
2,2013-01-01 01:15:00,2013-01-01 01:30:00,147.161,1,87.5,87.5,4
3,2013-01-01 01:15:00,2013-01-01 01:30:00,147.161,1,87.5,87.5,4
4,2013-01-01 01:30:00,2013-01-01 01:45:00,147.162,1,78.0,78.0,4
5,2013-01-01 01:00:00,2013-01-01 01:15:00,146.164,1,77.4,77.4,4
6,2013-01-01 01:00:00,2013-01-01 01:15:00,146.164,1,77.4,77.4,4
7,2013-01-01 01:45:00,2013-01-01 02:00:00,147.164,1,77.0,77.0,4
8,2013-01-01 01:45:00,2013-01-01 02:00:00,147.164,1,77.0,77.0,4
9,2013-01-01 01:30:00,2013-01-01 01:45:00,147.161,1,75.0,75.0,4


## Part 2

In [107]:
from datetime import datetime # Stick w/ 1 date library next time (-_-)

def transform_to_wide_format(batch_df, batch_id):
    start_time = datetime.now()
    
    try:
        # Read the profitability data + grab headers
        profitability_df = spark.read.format("delta").load("output/profitable_areas")

        latest_record = profitability_df.orderBy(col("analysis_window_start").desc()).first()
        
        if latest_record is None:
            return
        
        # Create a window for ranking
        window_spec = Window.orderBy(col("profitability_of_cell").desc())
        
        # Rank +pivot data
        ranked_df = (
            profitability_df
            .withColumn("rank", rank().over(window_spec))
            .filter(col("rank") <= 10)  # Only keep top 10
        )
        
        # Append 10 columns per type + format column names
        wide_df = (
            ranked_df
            .groupBy("analysis_window_start", "analysis_window_end")
            .pivot("rank", range(1, 11))
            .agg(
                first("profitable_cell_id").alias("profitable_cell_id"),
                first("empty_taxis_in_cell").alias("empty_taxis_in_cell"),
                first("median_profit_in_cell").alias("median_profit_in_cell"),
                first("profitability_of_cell").alias("profitability_of_cell")
            )
        )

        for i in range(1, 11):
            wide_df = (
                wide_df
                .withColumnRenamed(f"{i}_profitable_cell_id", f"profitable_cell_id_{i}")
                .withColumnRenamed(f"{i}_empty_taxis_in_cell", f"empty_taxis_in_cell_{i}")
                .withColumnRenamed(f"{i}_median_profit_in_cell", f"median_profit_in_cell_{i}")
                .withColumnRenamed(f"{i}_profitability_of_cell", f"profitability_of_cell_{i}")
            )
        
        # Fill NULLs for ranks that don't exist
        for i in range(1, 11):
            for col_suffix in ["profitable_cell_id", "empty_taxis_in_cell", 
                              "median_profit_in_cell", "profitability_of_cell"]:
                col_name = f"{col_suffix}_{i}"
                if col_name not in wide_df.columns:
                    wide_df = wide_df.withColumn(col_name, lit(None))
        
        # Calculate processing delay + add delay column
        processing_time = datetime.now()
        delay_seconds = (processing_time - start_time).total_seconds()
        
        wide_df = wide_df.withColumn("delay", lit(delay_seconds))

        # write
        wide_df.write.format("delta").mode("overwrite").save("output/profitable_areas_wide_format")
        
    except Exception as e:
        print(f"Error processing batch {batch_id}: {str(e)}")
        raise e

# Create the output stream
wide_format_stream = (
    profit_per_cell  # Same profit stream from Query 2 Part 1
    .writeStream
    .foreachBatch(transform_to_wide_format)
    .outputMode("update")
    .option("checkpointLocation", "output/profitable_areas_part2")
    .trigger(processingTime="1 minute")
    .start()
)


In [106]:
part2_df = spark.read.format("delta").load("output/profitable_areas_part2")
display(part2_df.orderBy(col("analysis_window_start").desc()))

DataFrame[analysis_window_start: timestamp, analysis_window_end: timestamp, 1_profitable_cell_id: string, 1_empty_taxis_in_cell: bigint, 1_median_profit_in_cell: double, 1_profitability_of_cell: double, 2_profitable_cell_id: string, 2_empty_taxis_in_cell: bigint, 2_median_profit_in_cell: double, 2_profitability_of_cell: double, 3_profitable_cell_id: string, 3_empty_taxis_in_cell: bigint, 3_median_profit_in_cell: double, 3_profitability_of_cell: double, 4_profitable_cell_id: string, 4_empty_taxis_in_cell: bigint, 4_median_profit_in_cell: double, 4_profitability_of_cell: double, 5_profitable_cell_id: string, 5_empty_taxis_in_cell: bigint, 5_median_profit_in_cell: double, 5_profitability_of_cell: double, 6_profitable_cell_id: string, 6_empty_taxis_in_cell: bigint, 6_median_profit_in_cell: double, 6_profitability_of_cell: double, 7_profitable_cell_id: string, 7_empty_taxis_in_cell: bigint, 7_median_profit_in_cell: double, 7_profitability_of_cell: double, 8_profitable_cell_id: string, 8_emp

In [108]:
from IPython.display import display, HTML
from pyspark.sql.functions import col

# Read wide_data
part2_delta_df = spark.read.format("delta").load("output/profitable_areas_part2")
pdf = part2_delta_df.orderBy(col("analysis_window_start").desc()).limit(20).toPandas()

# Display
display(HTML(pdf.to_html()))

Unnamed: 0,analysis_window_start,analysis_window_end,1_profitable_cell_id,1_empty_taxis_in_cell,1_median_profit_in_cell,1_profitability_of_cell,2_profitable_cell_id,2_empty_taxis_in_cell,2_median_profit_in_cell,2_profitability_of_cell,3_profitable_cell_id,3_empty_taxis_in_cell,3_median_profit_in_cell,3_profitability_of_cell,4_profitable_cell_id,4_empty_taxis_in_cell,4_median_profit_in_cell,4_profitability_of_cell,5_profitable_cell_id,5_empty_taxis_in_cell,5_median_profit_in_cell,5_profitability_of_cell,6_profitable_cell_id,6_empty_taxis_in_cell,6_median_profit_in_cell,6_profitability_of_cell,7_profitable_cell_id,7_empty_taxis_in_cell,7_median_profit_in_cell,7_profitability_of_cell,8_profitable_cell_id,8_empty_taxis_in_cell,8_median_profit_in_cell,8_profitability_of_cell,9_profitable_cell_id,9_empty_taxis_in_cell,9_median_profit_in_cell,9_profitability_of_cell,10_profitable_cell_id,10_empty_taxis_in_cell,10_median_profit_in_cell,10_profitability_of_cell,delay
0,2013-01-01 01:45:00,2013-01-01 02:00:00,,,,,,,,,,,,,,,,,147.164,1.0,77.0,77.0,,,,,,,,,,,,,,,,,,,,,0.96011
1,2013-01-01 01:30:00,2013-01-01 01:45:00,,,,,,,,,147.162,1.0,78.0,78.0,,,,,,,,,147.161,1.0,75.0,75.0,,,,,,,,,,,,,,,,,0.96011
2,2013-01-01 01:15:00,2013-01-01 01:30:00,,,,,147.161,1.0,87.5,87.5,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0.96011
3,2013-01-01 01:00:00,2013-01-01 01:15:00,,,,,,,,,,,,,146.164,1.0,77.4,77.4,,,,,,,,,,,,,,,,,,,,,,,,,0.96011
4,2013-01-01 00:30:00,2013-01-01 00:45:00,188.126,1.0,100.4,100.4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0.96011
