# Loading the dataset

In [121]:
from urllib.request import urlretrieve
import os, ssl

# directory path to store data
output_relative_dir = './data'

# Create the data directory if needed. exist_ok=True makes this a no-op when the
# directory already exists, without a separate (race-prone) existence check.
os.makedirs(output_relative_dir, exist_ok=True)

In [132]:
# Download configuration: TLC yellow-taxi trip files for the given months of YEAR.
YEAR = '2021'
MONTHS = range(6, 8)  # end-exclusive, i.e. months 6 and 7

# Each month's file lives at f"{URL_TEMPLATE}{YEAR}-{MM}.parquet"
URL_TEMPLATE = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_"

In [133]:
tlc_output_dir = output_relative_dir

# When PYTHONHTTPSVERIFY is unset/empty, make urllib use an unverified HTTPS
# context (works around local certificate-store problems).
# NOTE(review): this disables certificate verification for EVERY later HTTPS
# request in this kernel, not only these downloads. It only needs to run once,
# so it is done before the loop instead of on every iteration.
if (not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None)):
    ssl._create_default_https_context = ssl._create_unverified_context

for month in MONTHS:
    # 0-fill i.e 1 -> 01, 2 -> 02, etc
    month_str = str(month).zfill(2)
    print(f"Begin month {month_str}")

    # generate url and output file location
    url = f'{URL_TEMPLATE}{YEAR}-{month_str}.parquet'
    output_path = f"{tlc_output_dir}/{YEAR}-{month_str}.parquet"

    # Skip months already on disk so Restart & Run All doesn't re-download them.
    if os.path.exists(output_path):
        print(f"Skipping month {month_str} (already downloaded)")
        continue

    # Download under a temporary (dot-prefixed, i.e. hidden) name and rename only
    # once complete, so an interrupted download never leaves a truncated file
    # that the skip check above would accept.
    partial_path = f"{tlc_output_dir}/.{YEAR}-{month_str}.parquet.part"
    urlretrieve(url, partial_path)
    os.replace(partial_path, output_path)

    print(f"Completed month {month_str}")

Begin month 06
Completed month 06
Begin month 07
Completed month 07


In [None]:
from pyspark.sql import SparkSession

def _build_spark_session():
    """Create (or reuse) the "ADS" Spark session that runs the Spark jobs below.

    Session-wide SQL settings are applied from a table; timestamps are handled
    in the "Etc/UTC" session time zone.
    """
    settings = {
        "spark.sql.repl.eagerEval.enabled": True,
        "spark.sql.parquet.cacheMetadata": "true",
        "spark.sql.session.timeZone": "Etc/UTC",
    }
    builder = SparkSession.builder.appName("ADS")
    for key, value in settings.items():
        builder = builder.config(key, value)
    return builder.getOrCreate()


spark = _build_spark_session()

In [None]:
# Read only the monthly trip files fetched by the download cell. Reading the
# whole './data' directory would also traverse './data/taxi_data' (the zone CSV
# and shapefile used below), which are not Parquet files.
trip_files = [
    f"{output_relative_dir}/{YEAR}-{str(month).zfill(2)}.parquet"
    for month in MONTHS
]
sdf = spark.read.parquet(*trip_files)

In [None]:
from pyspark.sql import functions as F
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
import folium

In [None]:
# Zone lookup table and zone polygons (shapefile); both are joined on LocationID later.
zones = pd.read_csv("./data/taxi_data/taxi+_zone_lookup.csv")

# Reproject the zone polygons to plain longitude/latitude on WGS84.
# (Adapted from tutorial code.)
sf = gpd.read_file("./data/taxi_data/taxi_zones.shp").to_crs(
    "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs"
)

In [None]:
# Attach the zone lookup attributes (e.g. 'Zone') to each zone polygon; only
# LocationIDs present in both tables are kept (inner join).
zones_with_shapes = pd.merge(zones, sf, on='LocationID', how='inner')
gdf = gpd.GeoDataFrame(zones_with_shapes)

# GeoJSON of the zone polygons, keeping the first row seen for each LocationID.
geoJSON = (
    gdf[['LocationID', 'geometry']]
    .drop_duplicates('LocationID')
    .to_json()
)

In [None]:
# Serialise each zone geometry to WKT text so it can be handed to Spark (the
# centroid UDF parses it back with shapely). The column is also kept on gdf.
gdf['wkt'] = gdf['geometry'].to_wkt()
zone_columns = ['Zone', 'LocationID', 'wkt']
spark_gdf = spark.createDataFrame(gdf[zone_columns])

In [None]:
from shapely import wkt
from pyspark.sql.types import ArrayType, DoubleType, FloatType

@F.udf(ArrayType(DoubleType()))
def get_centroids(wkt_geo):
    """
    Return the centroid of a WKT geometry as [latitude, longitude].

    The zones were reprojected to a longitude/latitude CRS, so a shapely point's
    x is longitude and y is latitude; the order is swapped to (y, x) here.
    Returns None for a missing geometry instead of failing the Spark task.
    DoubleType keeps full 64-bit coordinate precision (FloatType would round
    coordinates to ~7 significant digits).
    """
    if wkt_geo is None:
        return None
    centroid = wkt.loads(wkt_geo).centroid
    return [centroid.y, centroid.x]

# NOTE(review): despite its name, this 'geometry' column holds the centroid
# [lat, lon], not the zone polygon.
spark_gdf = spark_gdf.withColumn(
    'geometry',
    get_centroids(F.col('wkt'))
)

In [None]:
# Preview the first two zone rows of the Spark frame.
spark_gdf.limit(num=2)

In [None]:
# Preview the first two rows of the merged zone GeoDataFrame.
gdf.head(n=2)

In [None]:
import time, math
from datetime import date

def extract_date_time(date_str):
    """
    Split a 'yyyy-mm-dd hh:mm:ss' string (24-hour clock), e.g.
    '2022-04-01 00:21:13', into its numeric date and time components.

    Returns a 2-tuple (date_parts, time_parts):
      - date_parts: list of ints [year, month, day]
      - time_parts: list of ints [hour, minute, second] (one int per
        ':'-separated field of the time component)
    Returns (None, None) when date_str is None or malformed (not exactly one
    space-separated date and time, or a non-numeric field), so callers can
    always unpack exactly two values.
    """
    if date_str is None:
        return (None, None)

    date_time = date_str.split()
    if len(date_time) != 2:
        return (None, None)

    try:
        dateL = list(map(int, date_time[0].split("-")))
        timeL = list(map(int, date_time[1].split(":")))
    except ValueError:
        # a non-numeric field, e.g. '2022-04-xx'
        return (None, None)

    return dateL, timeL
    

def extract_features(date_str):
    """
    Derive time-of-day and calendar features from a 'yyyy-mm-dd hh:mm:ss'
    string (24-hour clock), e.g. '2022-04-01 00:21:13'.

    Returns a 6-tuple (time_str, hour_bin, month, day_of_month, day, is_weekend):
      - time_str: zero-padded 'hh:mm', e.g. '00:21'
      - hour_bin: hour of the day as an int
      - month, day_of_month: ints taken from the date part
      - day: English weekday name, e.g. 'Friday'
      - is_weekend: 1 for Saturday/Sunday, otherwise 0
    Returns a tuple of six Nones when extract_date_time cannot parse date_str.
    Raises ValueError if the parsed fields do not form a real calendar date.
    """
    parsed = extract_date_time(date_str)
    # extract_date_time signals a malformed string with a tuple of Nones.
    if parsed[0] is None:
        return (None,) * 6
    dateL, timeL = parsed

    hour, minute = timeL[0], timeL[1]
    # zero-pad so '0:5' becomes the unambiguous, sortable '00:05'
    time_str = f'{hour:02d}:{minute:02d}'
    hour_bin = hour

    dateV = date(dateL[0], dateL[1], dateL[2])

    day_dict = {0: "Monday", 1: "Tuesday", 2: "Wednesday", 3: "Thursday",
                4: "Friday", 5: "Saturday", 6: "Sunday"}
    day = day_dict[dateV.weekday()]

    is_weekend = 1 if day in ("Saturday", "Sunday") else 0

    return (time_str, hour_bin, dateV.month, dateV.day, day, is_weekend)

In [None]:
# Sanity check of the timestamp parser on a sample value.
extract_date_time(date_str="2022-04-01 00:21:13")

In [None]:
# Keep only rows whose values are valid according to the data dictionaries.
# The conditions are the AND of the valid ranges; OR-ing the *invalid*
# conditions would instead select exactly the bad rows.
# Rows with a null in any checked column are dropped as well, since filter()
# keeps only rows whose predicate is true and a null comparison is never true.
sdf_clean = sdf.filter(
    (sdf["total_amount"] >= 0)
    & sdf["VendorID"].between(1, 2)
    & (sdf["passenger_count"] >= 1)
    & (sdf["trip_distance"] > 0)
    & sdf["RatecodeID"].between(1, 6)
    & sdf["payment_type"].between(1, 6)
)

In [None]:
# Pull a 5% sample of the cleaned trips into pandas for plotting.
SAMPLE_FRACTION = 0.05
# Fixed seed so re-running gives the same sample (given the same input
# files/partitioning); seed=None drew a different sample on every run.
SAMPLE_SEED = 42
small_df = sdf_clean.sample(fraction=SAMPLE_FRACTION, seed=SAMPLE_SEED).toPandas()

In [None]:
CORR_COLS = ["trip_distance", "PULocationID", "DOLocationID", "total_amount"]
# NOTE(review): PULocationID / DOLocationID are zone identifiers (categorical
# codes), so their Pearson correlations with the numeric columns are not
# meaningful — consider dropping them from CORR_COLS.

fig, ax = plt.subplots()
# Pin the colour scale to [-1, 1] (centred on 0) so colours mean the same
# thing on every run; otherwise the range is inferred from the matrix values.
sns.heatmap(
    small_df[CORR_COLS].corr(),
    vmin=-1, vmax=1, center=0, annot=True, ax=ax,
)
ax.set_title('Pearson Correlation Metric')
plt.show()

In [None]:
# Attach each trip's pickup-zone 'geometry', looked up by PULocationID.
# gdf may list the same LocationID more than once (the geoJSON cell also
# de-duplicates it); merging on duplicated keys would multiply trip rows, so
# keep one row per LocationID.
# NOTE: the merge is inner (pandas default), so trips whose PULocationID has no
# zone geometry are dropped.
pickup_geometry = gdf[['LocationID', 'geometry']].drop_duplicates('LocationID')

# Dropping any existing 'geometry' column first makes this cell safe to re-run
# (a second merge would otherwise produce geometry_x / geometry_y).
small_df = small_df \
    .drop(columns='geometry', errors='ignore') \
    .merge(pickup_geometry, left_on='PULocationID', right_on='LocationID') \
    .drop('LocationID', axis=1)

In [None]:
# Print the column names and types of the filtered trip frame.
sdf_clean.printSchema()

In [None]:
# Add integer-typed copies of the pickup/drop-off zone IDs as columns 'PU' and 'DO'.
# (A leftover template call using an undefined `some_udf` was removed: it raised
# NameError, so this cell never reached the loop, and its result was discarded.)
# NOTE(review): this rebinds `sdf` only; frames already derived from it
# (sdf_clean, small_df) will not have these columns.
for field in ('PU', 'DO'):
    _field = f'{field}LocationID'
    sdf = sdf.withColumn(
        field,
        F.col(_field).cast('INT')
    )