In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType

# Initialise the SparkSession.
# spark.sql.files.maxPartitionBytes caps the number of bytes packed into a single
# input partition when reading files, which controls read parallelism. Without a
# sensible cap Spark may read very large chunks per task and risk OutOfMemory (OOM).
# Note: 128m is also Spark's default value; it is set explicitly here for clarity.
spark = (
    SparkSession.builder
    .appName("NYC_Taxi_Analytics")
    .config("spark.sql.files.maxPartitionBytes", "128m")
    .getOrCreate()
)

# The rest of the notebook refers to the session as lowercase `spark`; the
# capitalised name is kept as an alias so existing references keep working.
Spark = spark




In [0]:
# Define the schema explicitly.
# This saves Spark from scanning the whole file to infer column types, which
# improves read performance.
#
# IMPORTANT: for CSV sources the schema is applied BY POSITION, not by name, so
# the field order below must match the column order of the file. In the 2019
# yellow-trip CSV the rate code and store-and-forward flag come BEFORE the
# pickup/drop-off location IDs. (Verify against the file's header line if the
# data source changes.)
taxi_schema = StructType([
    StructField("VendorID", IntegerType(), True),
    StructField("tpep_pickup_datetime", TimestampType(), True),
    StructField("tpep_dropoff_datetime", TimestampType(), True),
    StructField("passenger_count", DoubleType(), True),
    StructField("trip_distance", DoubleType(), True),
    StructField("rateCodeId", IntegerType(), True),
    StructField("store_and_fwd_flag", StringType(), True),
    StructField("PULocationID", IntegerType(), True),
    StructField("DOLocationID", IntegerType(), True),
    StructField("payment_type", IntegerType(), True),
    StructField("fare_amount", DoubleType(), True),
    StructField("extra", DoubleType(), True),
    StructField("mta_tax", DoubleType(), True),
    StructField("tip_amount", DoubleType(), True),
    StructField("tolls_amount", DoubleType(), True),
    StructField("improvement_surcharge", DoubleType(), True),
    StructField("total_amount", DoubleType(), True),
    StructField("congestion_surcharge", DoubleType(), True)
])


In [0]:
# Read the gzip-compressed January 2019 yellow-taxi CSV with the explicit schema.
# - header=true: the file starts with a header line; without this option the
#   header is parsed as a data row. DROPMALFORMED cannot be relied on to remove
#   it, because Spark only validates the columns a query actually requests, so
#   count() would include the header line.
# - mode=DROPMALFORMED: rows whose requested columns fail to parse are dropped.
taxi_df = (
    spark.read
    .schema(taxi_schema)
    .option("header", "true")
    .option("mode", "DROPMALFORMED")
    .csv("dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2019-01.csv.gz")
)
print(f"Initial Count : {taxi_df.count()}")

# For df.count(), Spark does not pull all the data to the driver; if it did, it would exhaust the driver's memory. Instead, it performs a distributed partial count.
# As mentioned, maxPartitionBytes is 128m, which caps the size of each input slice for splittable files. (Note: a gzip-compressed file such as .csv.gz is not splittable, so it is read by a single task.)
# The driver instructs the executors to count the rows in their own slice.
# After counting, the executors do not send the rows back to the driver; each one sends only the integer row count for its slice, and the driver sums them.
# When reading CSV, Spark must read all the data to find the newline characters (\n) in order to count rows.
# With a Parquet file, Spark can skip the data and read the metadata in the file footer instead, so the count completes in milliseconds.
# Because we are using "DROPMALFORMED" mode, Spark must physically scan the data to see whether any rows are malformed (corrupt).

In [0]:
from pyspark.sql.functions import col, month, to_date, when, lit
# Remove garbage records: keep only trips with a strictly positive fare and a
# strictly positive distance (rows where either value is null are dropped too).
# NOTE: a zero-distance trip with a fare above $0 may be legitimate (e.g. a
# cancellation fee), but such trips are deliberately excluded from this analysis.
taxi_clean_df = (
    taxi_df
    .where(col("fare_amount") > 0)
    .where(col("trip_distance") > 0)
)



In [0]:
# Derive three analysis columns from the cleaned trips:
#   trip_duration_seconds - drop-off minus pickup, both cast to epoch seconds
#   revenue_per_mile      - total_amount divided by trip_distance
#   is_high_value         - 'HIGH' when total_amount exceeds 50, else 'STANDARD'
taxi_df_transformed = (
    taxi_clean_df
    .withColumn(
        "trip_duration_seconds",
        col("tpep_dropoff_datetime").cast("long") - col("tpep_pickup_datetime").cast("long"),
    )
    .withColumn(
        "revenue_per_mile",
        col("total_amount") / col("trip_distance"),
    )
    .withColumn(
        "is_high_value",
        when(col("total_amount") > 50, lit("HIGH")).otherwise(lit("STANDARD")),
    )
)

# Lazy evaluation: when we run the lines above, Spark only builds a logical plan; it does not execute anything until an action is called.
# Catalyst optimizer: when you apply a filter, Spark pushes the predicate (e.g. fare_amount > 0) down to the file reader, so only the good rows are pulled into memory; the bad rows are left behind on disk.
# Whole-stage code generation: Spark does not run these transformations line by line in Python, as that would be slow. It compiles the 3-4 chained operations into one single optimized Java function (bytecode), which eliminates the per-row object overhead and runs tight CPU instructions on binary data.

In [0]:
# Lookup table for the broadcast join.
# Rationale: the taxi-zone lookup is tiny (<10MB), so it can be shipped to every
# executor, which lets the join avoid a shuffle of the large trips table.
zone_schema = "LocationID INT, Borough STRING, zone STRING, service_zone STRING"
zones_df = (
    spark.read
    .schema(zone_schema)
    .option("header", True)
    .csv("dbfs:/databricks-datasets/nyctaxi/taxizone/taxi_zone_lookup.csv")
)

from pyspark.sql.functions import broadcast

# Left-join each trip to its pickup zone, with a broadcast hint on the small
# lookup table. The lookup's key column is dropped afterwards because it
# duplicates PULocationID.
taxi_joined_df = (
    taxi_df_transformed
    .join(
        broadcast(zones_df),
        on=taxi_df_transformed["PULocationID"] == zones_df["LocationID"],
        how="left",
    )
    .drop("LocationID")
)