In [25]:
from pyspark.sql import SparkSession

# Create (or, if one already exists in this kernel, reuse) the Spark session
# that runs every Spark job in this notebook. The settings below are applied
# to the builder in the order listed, before getOrCreate() is called.
# NOTE(review): "spark.sql.parquet.cacheMetadata" may be ignored by recent
# Spark versions — confirm it still has an effect for the version in use.
SPARK_CONF = {
    "spark.sql.repl.eagerEval.enabled": True,
    "spark.sql.parquet.cacheMetadata": "true",
    "spark.sql.session.timeZone": "Etc/UTC",  # session time zone pinned to UTC
}

spark_builder = SparkSession.builder.appName("MAST30034 Tutorial 1")
for conf_key, conf_value in SPARK_CONF.items():
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.getOrCreate()

In [26]:
# Read the raw TLC trip data (Parquet) from the directory below into one DataFrame.
RAW_TLC_DIR = '../data/raw/tlc_data/'
sdf_all = spark.read.parquet(RAW_TLC_DIR)

## Feature Engineering

In [27]:
from pyspark.sql.functions import col
from pyspark.sql.functions import unix_timestamp

# Trip duration in seconds: unix_timestamp() converts each timestamp to
# epoch seconds, so dropoff minus pickup gives the elapsed time.
pickup_epoch_s = unix_timestamp(col("tpep_pickup_datetime"))
dropoff_epoch_s = unix_timestamp(col("tpep_dropoff_datetime"))
sdf_all = sdf_all.withColumn("duration", dropoff_epoch_s - pickup_epoch_s)
sdf_all.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+--------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|duration|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+--------+
|       1| 2023-10-01 00:21:18|  2023-10-01 00:27:31|              1|          0.9|         1|                 N|         161|         186|           1|        

In [28]:
# Drop trips without a strictly positive duration (dropoff recorded at or
# before pickup). filter(duration > 0) also drops rows whose duration is null
# (e.g. a missing timestamp), because a null comparison is never true. Count
# those rows too; otherwise the printed number understates how many rows are
# actually removed.
invalid_duration = col('duration').isNull() | (col('duration') <= 0)
print("Removing rows with non-positive or null duration", sdf_all.filter(invalid_duration).count())
sdf_all = sdf_all.filter(col('duration') > 0)



Removing rows with non-positive duration 1238


                                                                                

In [29]:
# Per-trip "profit" = base fare + extras + tip. mta_tax, tolls_amount,
# improvement_surcharge, congestion_surcharge and Airport_fee are not included.
# Column names use the schema's real lowercase spelling, so this no longer
# depends on Spark's case-insensitive name resolution (spark.sql.caseSensitive).
# NOTE(review): a null in any of the three inputs makes profit null — confirm
# that is acceptable for this dataset.
sdf_all = sdf_all.withColumn("profit", col("fare_amount") + col("extra") + col("tip_amount"))
sdf_all.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+--------+-----------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|duration|           profit|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+--------+-----------------+
|       1| 2023-10-01 00:21:18|  2023-10-01 00:27:31|              1|          0.9|         1|            

In [30]:
# Earnings rate. duration is in seconds (unix_timestamp difference), so
# profit / duration is per second; multiplying by 60 gives per minute.
profit_per_second = col("profit") / col("duration")
sdf_all = sdf_all.withColumn("profit_per_minute", profit_per_second * 60)
sdf_all.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+--------+-----------------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|duration|           profit| profit_per_minute|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+--------+-----------------+------------------+
|       1| 2023-10-01 00:21:18|  2023-10-01 00:27

## Merging 

In [31]:
SAMPLE_SIZE: float = 5e-2  # fraction of rows to keep (5%), not a row count; passed to DataFrame.sample

In [32]:
# Take a seeded random sample (approximately SAMPLE_SIZE of the rows, without
# replacement) for plotting, and save it, overwriting any previous sample.
SAMPLE_OUTPUT_DIR = '../data/sample/tlc_data'
sample_df = sdf_all.sample(withReplacement=False, fraction=SAMPLE_SIZE, seed=0)
sample_df.write.parquet(SAMPLE_OUTPUT_DIR, mode='overwrite')

24/08/10 15:36:03 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/08/10 15:36:06 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                