In [8]:
from functools import partial

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

from datetime import datetime
from pytz import timezone

import numpy as np
import pandas as pd

In [9]:
spark = (SparkSession
         .builder
         .master("local[*]")
         .appName("taxis")
         .config("spark.executor.cores", "4")
         .getOrCreate())

In [10]:
filepath = "../data/yellow_tripdata_2019-01.csv"
datetime_format = "%Y-%m-%d %H:%M:%S"
QUERY_MONTH = "2019-1"
columns_of_interest = [
    "VendorID", 
    "PULocationID", 
    "trip_distance",
    "fare_amount",
    "payment_type",
    "tpep_pickup_datetime"
] 

In [11]:
df = spark.read.format("csv").option("header", "true").load(filepath)
df = df.select(columns_of_interest)

In [12]:
nytimezone = timezone("US/Eastern")
to_datetime = udf(lambda x: nytimezone.localize(datetime.strptime(x, datetime_format)))

df = df.withColumn("tpep_pickup_datetime", to_datetime("tpep_pickup_datetime"))

In [13]:
def year_month(dt):
    return f"{dt.year}-{dt.month}"

year_month_str = udf(year_month)
df = df.withColumn("year_month", year_month_str("tpep_pickup_datetime"))

In [14]:
df_201901 = df.filter(df["year_month"]=="2019-1")

In [15]:
df.createOrReplaceTempView("taxi_drives")

### Create aggregated trip fact table
There are non-positive distances: we should include "and trip_distance>0"

In [65]:
trip_fact = spark.sql("""
    SELECT VendorID, PULocationID, payment_type, sum(fare_amount) as total_fare_amount, sum(trip_distance) as total_trip_distance 
    FROM taxi_drives 
    WHERE fare_amount > 0 
    GROUP BY VendorID, PULocationID, payment_type""")

In [66]:
trip_fact.createOrReplaceTempView("trip_fact")

In [119]:
# computes distance per vendor and location for payment types 1 and 2
vendor_pu_distance = spark.sql("""
    SELECT VendorID, PULocationID, sum(total_trip_distance) as vendor_pu_distance
    FROM trip_fact
    WHERE payment_type=1 or payment_type=2
    GROUP BY VendorID, PULocationID""")
vendor_pu_distance.createOrReplaceTempView("vendor_pu_distance")

In [120]:
# join the above result with the original trip fact table
trip_vendor_pu_distance = spark.sql("""
    SELECT tf.*, vpd.vendor_pu_distance
    FROM trip_fact as tf
    JOIN vendor_pu_distance as vpd
    ON tf.VendorID=vpd.VendorID and tf.PULocationID=vpd.PULocationID
    WHERE tf.payment_type=1
""")

### Scenario 1

In [122]:
spark.sql("""
    SELECT VendorID, 0.01*sum(total_trip_distance) as tax
    FROM trip_fact
    GROUP BY VendorID
""").show()

+--------+------------------+
|VendorID|               tax|
+--------+------------------+
|       1| 77307.54199999999|
|       4|1965.5500999999995|
|       2| 135357.7713999999|
+--------+------------------+



### Scenario 2

In [114]:
percentages = [
    (0, 10000, 0.0),
    (10000, 30000, 0.1),
    (30000, 70000, 0.2),
    (70000, np.inf,0.3)
]
percentages

[(0, 10000, 0.0), (10000, 30000, 0.1), (30000, 70000, 0.2), (70000, inf, 0.3)]

In [115]:
def add_tax_percentage(distance, percentages):
    for lb, ub, percentage in percentages:
        if lb <= distance < ub:
            return percentage
        
add_tax_percentage_udf = udf(partial(add_tax_percentage, percentages=percentages))

In [116]:
trip_tax_percentage = trip_vendor_pu_distance.withColumn("tax_percentage", add_tax_percentage_udf("vendor_pu_distance"))

In [117]:
trip_tax_percentage.createOrReplaceTempView("trip_tax_percentage")
tax_df = spark.sql("""
    SELECT VendorID, sum(tax_percentage*total_fare_amount) as tax
    FROM trip_tax_percentage
    GROUP BY VendorID
""")

In [118]:
tax_df.show()

+--------+--------------------+
|VendorID|                 tax|
+--------+--------------------+
|       1|         7225624.523|
|       4|           10879.253|
|       2|1.2203507080999998E7|
+--------+--------------------+



### Scenario 3

In [100]:
def progressive_percentages(distance, percentages):
    progressive_cut = 0
    for lb, ub, percentage in percentages:
        if distance > ub:
            progressive_cut += percentage*(ub-lb)
        else:
            progressive_cut += percentage*(distance-lb)
            break
    try:    
        return progressive_cut/distance
    except ZeroDivisionError:
        return 0


progressive_percentages_udf = udf(partial(progressive_percentages, percentages=percentages))

In [106]:
assert progressive_percentages(71000, percentages) == (20000*0.1+40000*0.2+0.3*1000)/71000

In [103]:
trip_tax_percentage = trip_vendor_pu_distance.withColumn("tax_percentage", progressive_percentages_udf("vendor_pu_distance"))

In [104]:
trip_tax_percentage.createOrReplaceTempView("trip_tax_percentage")
tax_df = spark.sql("""
    SELECT VendorID, sum(tax_percentage*total_fare_amount) as tax
    FROM trip_tax_percentage
    GROUP BY VendorID
""")
tax_df.show()

+--------+--------------------+
|VendorID|                 tax|
+--------+--------------------+
|       1|    5534185.76709494|
|       4|   6196.049828009542|
|       2|1.0287641340610534E7|
+--------+--------------------+

