In [45]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, dayofweek, hour, month, unix_timestamp

In [46]:
# Build (or reuse) the Spark session that executes every job in this notebook.
spark_settings = {
    "spark.sql.repl.eagerEval.enabled": True,
    "spark.sql.parquet.cacheMetadata": "true",
    "spark.sql.session.timeZone": "Etc/UTC",
    "spark.driver.memory": "15g",
}
session_builder = SparkSession.builder.appName("MAST30034 Tutorial 1")
for setting_name, setting_value in spark_settings.items():
    session_builder = session_builder.config(setting_name, setting_value)
spark = session_builder.getOrCreate()

In [47]:
# Read the raw TLC trip records (parquet) into a single DataFrame
tlc_all = spark.read.format("parquet").load("../data/raw/tlc_data/")

## Feature Engineering

In [48]:
# Derive calendar features from the pickup timestamp.
# Spark's dayofweek() numbers days 1 (Sunday) through 7 (Saturday).
tlc_all = (
    tlc_all
    .withColumn("pickup_hour_of_day", hour("pickup_datetime"))
    .withColumn("pickup_day_of_week", dayofweek("pickup_datetime"))
    .withColumn("pickup_month", month("pickup_datetime"))
)
tlc_all.show(5)

+-----------------+--------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+-----------+----+----------+-------------------+-----------------+------------------+----------------+--------------+------------------+------------------+------------+
|hvfhs_license_num|dispatching_base_num|   request_datetime|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|trip_miles|trip_time|base_passenger_fare|tolls| bcf|sales_tax|congestion_surcharge|airport_fee|tips|driver_pay|shared_request_flag|shared_match_flag|access_a_ride_flag|wav_request_flag|wav_match_flag|pickup_hour_of_day|pickup_day_of_week|pickup_month|
+-----------------+--------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+-----------+----+----------+-------

In [49]:
# Waiting time = seconds elapsed between the ride request and the pickup
tlc_all = tlc_all.withColumn(
    "waiting_time",
    unix_timestamp("pickup_datetime") - unix_timestamp("request_datetime"),
)

In [50]:
# Load the taxi zone lookup table; the first CSV row supplies the column names.
# No schema is given or inferred, so every column (including LocationID) is read as a string.
zones = spark.read.option("header", True).csv("../data/taxi_zones/taxi+_zone_lookup.csv")
zones.show(10)

+----------+-------------+--------------------+------------+
|LocationID|      Borough|                Zone|service_zone|
+----------+-------------+--------------------+------------+
|         1|          EWR|      Newark Airport|         EWR|
|         2|       Queens|         Jamaica Bay|   Boro Zone|
|         3|        Bronx|Allerton/Pelham G...|   Boro Zone|
|         4|    Manhattan|       Alphabet City| Yellow Zone|
|         5|Staten Island|       Arden Heights|   Boro Zone|
|         6|Staten Island|Arrochar/Fort Wad...|   Boro Zone|
|         7|       Queens|             Astoria|   Boro Zone|
|         8|       Queens|        Astoria Park|   Boro Zone|
|         9|       Queens|          Auburndale|   Boro Zone|
|        10|       Queens|        Baisley Park|   Boro Zone|
+----------+-------------+--------------------+------------+


                                                                                

### Aggregating the movement data for use in `movement_plot.ipynb`

In [51]:
def _with_borough(trips, location_col, borough_col):
    """Left-join `trips` to the zone lookup table and append the matching Borough.

    Args:
        trips: Trip DataFrame that contains the column `location_col`.
        location_col: Name of the trip column compared with `zones.LocationID`.
        borough_col: Name given to the appended Borough column.

    Returns:
        A DataFrame holding every column of `trips` plus `borough_col`. The join
        is a left join, so trips without a matching zone row are kept.
    """
    return (
        trips.alias("tlc")
        .join(zones.alias("zone"), trips[location_col] == zones.LocationID, how='left')
        .select("tlc.*", "zone.Borough")
        .withColumnRenamed("Borough", borough_col)
    )


# Look up the borough of the pickup zone and of the drop-off zone in the zone lookup table
tlc_all = _with_borough(tlc_all, "PULocationID", "pickup_borough")
tlc_all = _with_borough(tlc_all, "DOLocationID", "dropoff_borough")

In [52]:
# Count trips for every (pickup borough, drop-off borough, pickup hour) combination
movement_aggregates = (
    tlc_all
    .groupBy('pickup_borough', 'dropoff_borough', 'pickup_hour_of_day')
    .agg({'*': 'count'})
    # The dict-style count aggregate is named 'count(1)'; give it a readable name
    .withColumnRenamed('count(1)', 'num_trips')
)

In [53]:
# Save the aggregates as parquet, replacing any earlier output at this path
movement_aggregates.write.parquet('../data/movement_aggregates', mode='overwrite')

ERROR:root:KeyboardInterrupt while sending command.                (0 + 8) / 23]
Traceback (most recent call last):
  File "/Users/dakshagrawal/Documents/GitHub/project-1-individual-dakshAg-v2/.venv/lib/python3.12/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dakshagrawal/Documents/GitHub/project-1-individual-dakshAg-v2/.venv/lib/python3.12/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socket.py", line 707, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt

KeyboardInterrupt: 

In [None]:
# NOTE(review): unfinished cell — the join call below is never closed, so the cell
# cannot run as written. Complete the join arguments or delete the cell.
tlc_gdf = tlc_all.join(

In [None]:
# NOTE(review): `gdf` is not defined in any cell visible here — presumably a GeoDataFrame
# of taxi zones loaded elsewhere; confirm it exists on a fresh kernel.
# Keep one geometry row per LocationID, serialise to JSON and preview the first 300 characters.
zone_shapes = gdf[['LocationID', 'geometry']].drop_duplicates('LocationID')
geoJSON = zone_shapes.to_json()
print(geoJSON[:300])

In [None]:
# Mean trip metrics and a trip count for every combination of the grouping keys below
aggregate_keys = ['PULocationID', 'pickup_hour_of_day', 'pickup_day_of_week', 'hvfhs_license_num', "pickup_month"]
aggregate_spec = {
    metric: 'mean'
    for metric in ('trip_miles', 'trip_time', 'base_passenger_fare', 'driver_pay')
}
aggregate_spec['*'] = 'count'

tlc_aggregated = (
    tlc_all
    .groupBy(*aggregate_keys)
    .agg(aggregate_spec)
    .withColumnRenamed('count(1)', 'num_trips')
)
# count() runs the aggregation and reports how many groups were produced
print(tlc_aggregated.count())

In [None]:
# Attach the `proportions` columns to every aggregate row that shares its pickup zone.
# NOTE(review): `proportions` is not defined in any cell visible here — presumably a pandas
# DataFrame indexed by LocationID; confirm it is created before this cell on a fresh kernel.
# `col` is imported in the notebook's top import cell rather than here.
proportions_sdf = (
    spark.createDataFrame(proportions.reset_index())
    # Expose the lookup key as an int column named like the join key
    .withColumn("PULocationID", col("LocationID").cast("int"))
)
# Left join: aggregate rows without a matching proportions row are kept
tlc_aggregated = tlc_aggregated.join(proportions_sdf, on='PULocationID', how='left')


In [None]:
# tlc_all.write.mode('overwrite').parquet('../data/curated/tlc_data')

In [None]:
# Fraction of rows to keep when sampling (0.05 = roughly 5% of the data), not a row count
SAMPLE_SIZE = 0.05

In [None]:
# Draw a seeded (reproducible) random subset for plotting and store it as parquet,
# replacing any earlier sample at this path
sample_df = tlc_all.sample(fraction=SAMPLE_SIZE, seed=0)
sample_df.write.parquet('../data/sample/tlc_data', mode='overwrite')