# Setup

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, window, count, when, lit, row_number, current_timestamp, unix_timestamp, expr, from_json, udf
from pyspark.sql.types import StructType, StringType, DoubleType, TimestampType
from pyspark.sql import Window
import math

In [2]:
# Initialize SparkSession with the Kafka JAR
spark = SparkSession.builder \
    .appName("KafkaTaxiStream") \
    .getOrCreate()

print("✅ Spark Session created successfully!")

✅ Spark Session created successfully!


----------------------------------------
Exception occurred during processing of request from ('127.0.0.1', 58962)
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/conda/lib/python3.11/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/opt/conda/lib/python3.11/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/conda/lib/python3.11/socketserver.py", line 755, in __init__
    self.handle()
  File "/usr/local/spark/python/pyspark/accumulators.py", line 295, in handle
    poll(accum_updates)
  File "/usr/local/spark/python/pyspark/accumulators.py", line 267, in poll
    if self.rfile in r and func():
                           ^^^^^^
  File "/usr/local/spark/python/pyspark/accumulators.py", line 271, in accum_updates
    num_updates =

In [3]:
# Define Schema for Incoming Data
schema = StructType() \
    .add("medallion", StringType()) \
    .add("hack_license", StringType()) \
    .add("pickup_datetime", TimestampType()) \
    .add("dropoff_datetime", TimestampType()) \
    .add("trip_time_in_secs", DoubleType()) \
    .add("trip_distance", DoubleType()) \
    .add("pickup_longitude", DoubleType()) \
    .add("pickup_latitude", DoubleType()) \
    .add("dropoff_longitude", DoubleType()) \
    .add("dropoff_latitude", DoubleType()) \
    .add("payment_type", StringType()) \
    .add("fare_amount", DoubleType()) \
    .add("surcharge", DoubleType()) \
    .add("mta_tax", DoubleType()) \
    .add("tip_amount", DoubleType()) \
    .add("tolls_amount", DoubleType()) \
    .add("total_amount", DoubleType())  # CSV has 17 fields; 17th not in task description, but on DEBS 2015 website


In [4]:
taxi_stream = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "kafka:9092") \
    .option("subscribe", "taxi-trips") \
    .option("startingOffsets", "earliest") \
    .load()

# Parse the JSON data from Kafka
parsed_taxi_stream = taxi_stream.selectExpr("CAST(value AS STRING)") \
    .select(from_json(col("value"), schema).alias("data")) \
    .select("data.*")

# Query 1: Frequent Routes

## Part 1

In [5]:
# Define the reference point (Barryville cell 1.1 center)
reference_lat = 41.474937
reference_lon = -74.913585
cell_size_meters = 500
total_cells = 300  # 300x300 grid

# Function to calculate the cell based on coordinates
def calculate_cell(lat, lon):
    print("t1", lat, lon)
    return "1.1"
    # # Earth's radius in meters
    # earth_radius = 6371000  
    
    # # Calculate meters per degree of latitude (roughly constant)
    # meters_per_lat = 111320  # 111.32 km per degree of latitude
    
    # # Calculate meters per degree of longitude (varies with latitude)
    # meters_per_lon = earth_radius * math.cos(math.radians(reference_lat)) * 2 * math.pi / 360
    
    # # Calculate the cell offsets
    # lat_offset = (reference_lat - lat) * meters_per_lat / cell_size_meters
    # lon_offset = (lon - reference_lon) * meters_per_lon / cell_size_meters
    
    # # Calculate the cell coordinates (1-indexed)
    # cell_x = 1 + int(lon_offset)
    # cell_y = 1 + int(lat_offset)
    
    # # Ensure cells are within the valid range (1-300)
    # if cell_x < 1 or cell_x > total_cells or cell_y < 1 or cell_y > total_cells:
    #     return None
    
    # return f"{cell_x}.{cell_y}"

In [6]:
# Register the UDF (User Defined Function)
calculate_cell_udf = udf(calculate_cell, StringType())

In [7]:
taxi_with_cells = parsed_taxi_stream \
    .withColumn("start_cell", calculate_cell_udf(col("pickup_latitude"), col("pickup_longitude"))) \
    .withColumn("end_cell", calculate_cell_udf(col("dropoff_latitude"), col("dropoff_longitude"))) \
    .filter(col("start_cell").isNotNull() & col("end_cell").isNotNull()) \
    .withColumn("input_timestamp", current_timestamp())

In [8]:
# Group by route (start_cell, end_cell) and count occurrences
route_counts = taxi_with_cells \
    .withWatermark("dropoff_datetime", "10 minutes") \
    .filter(col("dropoff_datetime") >= current_timestamp() - expr("INTERVAL 30 MINUTES")) \
    .groupBy("start_cell", "end_cell") \
    .count() \
    .orderBy(col("count").desc()) \
    .limit(10)

# Output to console for debugging
query = route_counts \
    .writeStream \
    .outputMode("complete") \
    .format("console") \
    .option("truncate", "false") \
    .start()

In [9]:
#spark.streams.awaitAnyTermination()
query.awaitTermination(10)

False

=============================

In [10]:
exit()

In [9]:
# Method 1: Using display() function (for Databricks notebooks)
# This will show a live updating table in the notebook
display(
    taxi_stream.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", 
                          "topic", "partition", "offset", "timestamp")
)


def process_batch(df, epoch_id):
    # Only process if the DataFrame isn't empty
    if not df.isEmpty():
        # Show the batch in the notebook
        print(f"Batch {epoch_id} received at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        df.show(10, False)  # Show 10 rows, don't truncate columns
        
        # Optional: Count messages in this batch
        count = df.count()
        print(f"Message count in this batch: {count}")

# Use foreachBatch to process each micro-batch and show results
# query = taxi_stream \
#     .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "topic", "partition", "offset") \
#     .writeStream \
#     .foreachBatch(process_batch) \
#     .start()

def print_row(df):
    df.createOrReplaceTempView("latest_taxi_data")
    spark.sql("""
            SELECT 
                pickup_datetime, 
                COUNT(*) as trip_count
            FROM latest_taxi_data
            GROUP BY pickup_datetime
            ORDER BY pickup_datetime DESC
            LIMIT 10
        """).show()

query = taxi_stream \
    .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "topic", "partition", "offset") \
    .writeStream \
    .foreach(print_row) \
    .start()

# Let it run for some time
import time
time.sleep(15)  # Run for 60 seconds
query.stop()
# query.awaitTermination(15)

DataFrame[key: string, value: string, topic: string, partition: int, offset: bigint, timestamp: timestamp]

Traceback (most recent call last):
  File "/usr/local/spark/python/pyspark/serializers.py", line 459, in dumps
    return cloudpickle.dumps(obj, pickle_protocol)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/usr/local/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 632, in dump
    return Pickler.dump(self, obj)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/pyspark/context.py", line 466, in __getnewargs__
    raise PySparkRuntimeError(
pyspark.errors.exceptions.base.PySparkRuntimeError: [CONTEXT_ONLY_VALID_ON_DRIVER] It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.


PicklingError: Could not serialize object: PySparkRuntimeError: [CONTEXT_ONLY_VALID_ON_DRIVER] It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

In [None]:
# Read Data Stream from Kafka
taxi_stream = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "kafka:9092") \
    .option("subscribe", "taxi-trips") \
    .option("startingOffsets", "earliest") \
    .load()

# Parse JSON Data from Kafka
parsed_data = taxi_stream.selectExpr("CAST(value AS STRING)") \
    .select(from_json(col("value"), schema).alias("data")) \
    .select("data.*")

# Display Streaming Data in Console
# query = parsed_data.writeStream \
#    .outputMode("append") \
#    .format("console") \
#    .start()



# query = parsed_data.writeStream \
#     .outputMode("append") \
#     .format("memory") \
#     .queryName("taxi_trips_table") \
#     .start()

query.awaitTermination(4)

In [None]:
payment_stats = parsed_data \
    .groupBy("payment_type") \
    .agg({"trip_distance": "avg"}) \
    .withColumnRenamed("avg(trip_distance)", "avg_trip_distance")

query = payment_stats \
    .writeStream \
    .outputMode("complete") \
    .format("console") \
    .start()

query.awaitTermination(5)
payment_stats.show()

In [None]:
spark.sql("SELECT * FROM taxi_trips_table LIMIT 10").show()

In [None]:
query.exception()

In [None]:
#query.awaitTermination()

In [None]:
spark.sql("SHOW TABLES").show()

In [None]:
parsed_data.isStreaming