## Spark session

In [None]:
# helper functions for spark

from pyspark.sql import SparkSession, DataFrame

def get_spark() -> SparkSession:
  """Return a Spark session, preferring Databricks Connect when it is importable.

  The function first tries to import ``databricks.connect`` and obtain a
  ``DatabricksSession``. If an ``ImportError`` is raised along the way, it
  falls back to a plain ``SparkSession``.

  Returns:
      SparkSession: the DatabricksSession or SparkSession returned by
      ``builder.getOrCreate()``.
  """
  try:
    from databricks.connect import DatabricksSession as databricks_session
    session = databricks_session.builder.getOrCreate()
  except ImportError:
    # Fall back to the regular Spark session builder.
    session = SparkSession.builder.getOrCreate()
  return session

def get_taxis(spark: SparkSession) -> DataFrame:
  """Read the ``samples.nyctaxi.trips`` table as a DataFrame.

  Args:
      spark: session used to read the table.

  Returns:
      DataFrame: the result of ``spark.read.table`` for that table.
  """
  reader = spark.read
  return reader.table("samples.nyctaxi.trips")

# Smoke test: obtain a session and print the first five rows of the sample trips table.
sample_trips = get_taxis(get_spark())
sample_trips.show(5)

## Prepare the data

* Load the data into a dataframe from the source with the given schema

In [None]:
from pyspark.sql.functions import col, lit, expr, when
from pyspark.sql.types import *
from datetime import datetime
import time
 
# Define schema: (column name, Spark type) pairs in declaration order.
# Every column is declared nullable.
_nyc_columns = [
  ('vendor', StringType),
  ('pickup_datetime', TimestampType),
  ('dropoff_datetime', TimestampType),
  ('passenger_count', IntegerType),
  ('trip_distance', DoubleType),
  ('pickup_longitude', DoubleType),
  ('pickup_latitude', DoubleType),
  ('rate_code', StringType),
  ('store_and_forward', StringType),
  ('dropoff_longitude', DoubleType),
  ('dropoff_latitude', DoubleType),
  ('payment_type', StringType),
  ('fare_amount', DoubleType),
  ('surcharge', DoubleType),
  ('mta_tax', DoubleType),
  ('tip_amount', DoubleType),
  ('tolls_amount', DoubleType),
  ('total_amount', DoubleType),
]
nyc_schema = StructType(
  [StructField(name, spark_type(), True) for name, spark_type in _nyc_columns]
)
 
# Only one file per taxi type (2019-12) is loaded for now; extend this to
# make all the data available at some point.
yellow_path = "dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2019-12.csv.gz"
green_path = "dbfs:/databricks-datasets/nyctaxi/tripdata/green/green_tripdata_2019-12.csv.gz"

# Both files are read as CSV with a header row and the same explicit schema.
yellow = (
  spark.read.format('csv')
  .options(header=True)
  .schema(nyc_schema)
  .load(yellow_path)
)
green = (
  spark.read.format('csv')
  .options(header=True)
  .schema(nyc_schema)
  .load(green_path)
)

# Stack the two DataFrames into one; the bare expression displays its schema.
df = yellow.union(green)
df

DataFrame[vendor: string, pickup_datetime: timestamp, dropoff_datetime: timestamp, passenger_count: int, trip_distance: double, pickup_longitude: double, pickup_latitude: double, rate_code: string, store_and_forward: string, dropoff_longitude: double, dropoff_latitude: double, payment_type: string, fare_amount: double, surcharge: double, mta_tax: double, tip_amount: double, tolls_amount: double, total_amount: double]

## Transform the data

This step reduces the number of columns and renames some of them to be more human-friendly. Once the columns have been reduced, a few new columns are added with additional transformations of the data:

* passenger type -> only capture single vs multi
* tolls flag -> true / false depending if the data has tolls
* amount -> round the value up to the next whole number

In [None]:
# Keep only the columns used downstream and shorten the names of the monetary ones.
restricted = df.select(
  df["vendor"],
  df["passenger_count"],
  df["trip_distance"],
  df["fare_amount"].alias("amount"),
  df["total_amount"].alias("total"),
  df["tolls_amount"].alias("tolls"),
)

In [None]:
import pyspark.sql.functions as F

# Derived columns, expressed against the columns of `restricted`.
# "multi" when more than one passenger, otherwise "single".
passenger_type = F.when(F.col("passenger_count") > 1, "multi").otherwise("single")
# True only when the tolls value is greater than zero.
has_tolls = F.when(F.col("tolls") > 0, True).otherwise(False)
# Amount rounded up with ceil.
amount_rounded = F.ceil(F.col("amount"))

# Final shape: vendor, passenger_type, has_tolls, distance, amount.
cleaned = restricted.select(
  F.col("vendor"),
  passenger_type.alias("passenger_type"),
  has_tolls.alias("has_tolls"),
  F.col("trip_distance").alias("distance"),
  amount_rounded.alias("amount"),
)

Now, filter out values that don't make sense in the grand scheme of things

* negative distance
* negative amount
* implausible combination of distance and amount (a distance below 5 with an amount above 100)

In [None]:
# Row count before any filtering, as a baseline for the filters that follow.
cleaned.count()

7346944

In [None]:

# Keep only trips whose distance and amount are both strictly positive.
has_distance = cleaned.distance > 0
has_amount = cleaned.amount > 0
filtered = cleaned.where(has_distance & has_amount)
filtered.count()

6910494

In [None]:
# A trip counts as implausible when its distance is below 5 while its amount is above 100.
implausible = (filtered.distance < 5) & (filtered.amount > 100)
filter_without_weird = filtered.where(~implausible)

filter_without_weird.count()

6910487

In [None]:
# Persist the cleaned trips as a table. Overwrite mode keeps this cell
# re-runnable: with the default save mode the write fails once the table
# exists, whereas the view cells already use "create or replace".
filter_without_weird.write.mode("overwrite").saveAsTable("main.martingrund.cleaned_trips")

## Create two views with the data

for the different vendors

In [None]:
# Create (or replace) a view over the cleaned trips restricted to rows with
# vendor = 1; the vendor column itself is left out of the view.
# NOTE(review): rows are selected by the vendor value, not by whether they were
# loaded from the yellow source file — confirm that vendor 1 really means "yellow".
spark.sql("""
          create or replace view main.martingrund.trips_yellow as select * except(vendor) from main.martingrund.cleaned_trips where vendor = 1;
          """)

DataFrame[]

In [None]:
# Create (or replace) a view over the cleaned trips restricted to rows with
# vendor = 2; the vendor column itself is left out of the view.
# NOTE(review): rows are selected by the vendor value, not by whether they were
# loaded from the green source file — confirm that vendor 2 really means "green".
spark.sql("""
          create or replace view main.martingrund.trips_green as select * except(vendor) from main.martingrund.cleaned_trips where vendor = 2;
          """)

DataFrame[]

In [None]:
# Number of rows exposed by the trips_yellow view.
spark.table("main.martingrund.trips_yellow").count()

2256012

In [None]:
# Number of rows exposed by the trips_green view.
spark.table("main.martingrund.trips_green").count()

4623027