## Apache Spark: UDF example
> Download the dataset from [the official TLC Trip Record Data website](https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page)

### This cell only shows how to document code
```python
# Load file
local_file = 'datasets/your-downloaded-from-TLC-taxis-file-here.parquet'

# Show data
spark.read.parquet(local_file).show()
```

In [None]:
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.types import IntegerType

### What is `master("local[N]")`?
The `--master` option (or `.master()` on the session builder) specifies the master URL: the URL of a distributed cluster, `local` to run locally with one thread, or `local[N]` to run locally with N threads.

<b>Source</b>: See Spark [docs here](https://spark.apache.org/docs/latest/). See all [options here](https://spark.apache.org/docs/latest/submitting-applications.html#master-urls)

In [None]:
# Create SparkSession
spark = SparkSession.builder\
             .master("local[1]")\
             .appName("spark-app-version-x")\
             .getOrCreate()

In [None]:
# Read taxi data
local_file = 'datasets/parquet/'
df = spark.read.parquet(local_file)

In [None]:
# A DataFrame is like a relational table in memory. Let's see the columns
df.printSchema()

### Let's create a real "dimension" table for our RatecodeID
1. Standard rate
2. JFK
3. Newark
4. Nassau or Westchester
5. Negotiated fare
6. Group ride
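
The list above can also be kept as a plain Python dictionary, so the mapping lives in one place. This is only a minimal sketch: the names `RATE_CODE_NAMES` and `rate_code_name` are hypothetical and not part of the notebook, and the fallback label is the one used by the UDF in the next section.

```python
# Hypothetical lookup table for the RatecodeID values listed above.
RATE_CODE_NAMES = {
    1: "Standard rate",
    2: "JFK",
    3: "Newark",
    4: "Nassau or Westchester",
    5: "Negotiated fare",
    6: "Group ride",
}


def rate_code_name(rate_code_id):
    """Return the label for a rate code, or a fallback for unknown or missing values."""
    # dict.get handles None, NaN and unknown codes by returning the default.
    # Whole-number floats such as 2.0 match the integer keys.
    return RATE_CODE_NAMES.get(rate_code_id, "rate not available")


print(rate_code_name(2.0))
print(rate_code_name(None))
print(rate_code_name(99))
```

A function like this could be wrapped with `udf` in the same way as the function defined below; the dictionary keeps the labels easy to extend.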

### Instead of SQL JOIN and DF.na.fill, let's create a Spark UDF

In [None]:
# Define a Python function
@udf
def tag_rate_code_udf(rate_code_id):
    if rate_code_id == 1:
        return "Standard rate"
    elif rate_code_id == 2:
        return "JFK"
    elif rate_code_id == 3:
        return "Newark"
    elif rate_code_id == 4:
        return "Nassau or Westchester"
    elif rate_code_id == 5:
        return "Negotiated fare"
    elif rate_code_id == 6:
        return "Group ride"
    else:
        return "rate not available"

In [None]:
# Apply the UDF to the DataFrame
df.withColumn("RateCodeName", tag_rate_code_udf(df["RatecodeID"]))\
  .select('VendorID', 'tpep_pickup_datetime', 'RatecodeID', 'RateCodeName')\
  .show(n=5)

In [None]:
# Confirm the UDF works: rows with a NULL RatecodeID should show "rate not available"
df.withColumn("RateCodeName", tag_rate_code_udf(df["RatecodeID"]))\
  .where("RatecodeID is NULL")\
  .select('VendorID', 'tpep_pickup_datetime', 'RatecodeID', 'RateCodeName')\
  .show(n=5)

In [None]:
# Apply the UDF to the DataFrame
df_na_rate_codes = df.withColumn("RateCodeName", tag_rate_code_udf(df["RatecodeID"]))

In [None]:
df_na_rate_codes.select('VendorID','tpep_pickup_datetime','RatecodeID','RateCodeName').where("RatecodeID is NULL").show(n=5)

### Or simply look for NULL and replace it, using a UDF (probably slower than the built-in `DF.na.fill`)

In [None]:
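# Define a Python function that replaces NULL or NaN rate codes with 0.
# Without an explicit return type, @udf produces a string column.
@udf
def tag_null_rate_codes_udf(rate_code_id):
    # Check None first; NaN is the only value that is not equal to itself
    if rate_code_id is None or rate_code_id != rate_code_id:
        return 0
    else:
        return rate_code_id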
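
The `rate_code_id != rate_code_id` comparison in the function above relies on NaN being the only value that is not equal to itself. A minimal sketch in plain Python, without Spark, shows the idea; the helper name `is_missing` is hypothetical.

```python
import math


def is_missing(value):
    """Return True for None or a float NaN, False for ordinary numbers."""
    # `or` short-circuits, so the NaN comparison only runs for non-None values.
    return value is None or value != value


nan = float("nan")
print(is_missing(None))
print(is_missing(nan))
print(is_missing(1.0))
# math.isnan is the explicit alternative, but it raises TypeError on None.
print(math.isnan(nan))
```

Using `or` rather than the bitwise `|` keeps the check readable and avoids evaluating the second comparison when the first already decides the result.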

In [None]:
%%time

# Apply the UDF to the DataFrame, using the same RatecodeID col
df_na_rate_codes = df.withColumn("RatecodeID", tag_null_rate_codes_udf(df["RatecodeID"]))
df_na_rate_codes.show(n=3)

In [None]:
# Show rows where the UDF replaced a NULL RatecodeID with 0
df_na_rate_codes.select('VendorID','tpep_pickup_datetime','RatecodeID').where('RatecodeID = 0').show(n=5)

In [None]:
# Count rows where RatecodeID is still NULL
df_na_rate_codes.select('VendorID','tpep_pickup_datetime','RatecodeID').where('RatecodeID is NULL').count()

In [None]:
df_na_rate_codes.explain()

In [None]:
# Count rows per rate code after the replacement
df_na_rate_codes.groupBy('RatecodeID').count().orderBy('RatecodeID').show()

In [None]:
# Count rows per rate code in the original DataFrame, for comparison
df.groupBy('RatecodeID').count().orderBy('RatecodeID').show()

---
### Using Pandas UDF
- Dependencies
```
❯ source venv/bin/activate
❯ pip3 install pandas pyarrow
```

In [None]:
# Define a pandas UDF: it receives a whole pandas Series per batch, not a single value
@pandas_udf("double", PandasUDFType.SCALAR)
def tag_null_rate_codes_pudf(rate_code_id):
    # A Series is never None, so replace missing entries element-wise
    return rate_code_id.fillna(0)
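
A scalar pandas UDF receives each batch of the column as a `pandas.Series` rather than one value at a time, so missing values have to be handled per element. The sketch below uses plain pandas without Spark; the function name `fill_missing_rate_codes` is hypothetical and the sample values are made up.

```python
import pandas as pd


def fill_missing_rate_codes(rate_codes: pd.Series) -> pd.Series:
    """Replace missing entries in a batch of rate codes with 0."""
    # fillna works element-wise; a check such as `rate_codes is None`
    # would test the Series object itself, which is never None.
    return rate_codes.fillna(0)


batch = pd.Series([1.0, None, 5.0])  # made-up sample batch
print(fill_missing_rate_codes(batch).tolist())
```

The same body can be used inside a function decorated with `pandas_udf`, since Spark passes it a Series and expects a Series of the same length back.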

In [None]:
# Apply the UDF to the DataFrame, using the same RatecodeID col
df_na_rate_codes_pudf = df.withColumn("RatecodeID", tag_null_rate_codes_pudf(df["RatecodeID"]))
df_na_rate_codes_pudf.show(n=3)

In [None]:
# Stop the session
spark.stop()