%md
# Spark UDF
- A User-Defined Function (UDF) in PySpark is used when you need to apply custom logic that cannot be easily expressed using built-in Spark SQL functions.
- The function must be written in native Python and wrapped with PySpark's udf() so it can be applied to distributed data.
- UDFs allow you to wrap any Python function and use it with Spark DataFrames.
- Steps
  - Define the function:
    def upper_case(name):
        return name.upper()
  - Register as UDF:
    upper_udf = udf(upper_case, StringType())
  - Apply UDF to DataFrame:
    df.withColumn("name_upper", upper_udf(df["name"]))
- UDFs are not optimized by Catalyst (Spark’s optimization engine).
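- They run in separate Python worker processes on the executors, so every row is serialized between the JVM and Python, which slows performance. (Py4J is only used on the driver for Python-to-JVM calls.)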
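
The steps above can be sketched as one snippet. This is a minimal sketch, assuming a DataFrame with a string column called name. The helper names add_upper_column and add_upper_column_builtin are hypothetical, and the None guard is an addition: a null cell reaches the Python function as None, and None has no upper() method.

```python
def upper_case(name):
    """Plain-Python logic; Spark passes None for null cells, so guard for it."""
    return None if name is None else name.upper()


def add_upper_column(df):
    """Wrap upper_case as a UDF and apply it to the 'name' column of df."""
    # Imported here so upper_case stays testable without a Spark install.
    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType

    upper_udf = udf(upper_case, StringType())
    return df.withColumn("name_upper", upper_udf(df["name"]))


def add_upper_column_builtin(df):
    """Same result with the built-in upper(), which Catalyst can optimize."""
    from pyspark.sql.functions import upper

    return df.withColumn("name_upper", upper(df["name"]))
```

For a case this simple the built-in version is usually the better choice. A UDF earns its cost only when no built-in function expresses the logic.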

## Why wrap the function:
- PySpark needs to convert a normal Python function into a Spark-compatible expression that can be applied to a DataFrame column.
- Spark runs distributed: your Python function (e.g., upper_case) is defined on the driver node.
- Data lives on executor nodes.
- Spark needs to serialize and ship your logic to all the executors.
- So, udf() acts like a translator that prepares the function for execution on distributed data and declares the return type Spark should expect.

## Registering Function:
- There are two ways to register a function: one for use with DataFrame column objects and one for use in SQL expressions.
- The registration process differs between the two.
- The DataFrame UDF method: complete_date_udf = udf(flight_date_generator, DateType())
  - Does not register the UDF in the catalog
  - Use this if you want to use the UDF in DataFrame transformations
- For SQL-based use: spark.udf.register("complete_date_udf", flight_date_generator, DateType())
  - Registers a SQL function under the given string name and creates an entry in the catalog
  - Use this if you want to use the function in SQL expressions
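
The two registration routes can be put side by side. This is a minimal sketch, assuming an active SparkSession is passed in. register_both_ways is a hypothetical helper name, and flight_date_generator here is a simplified stand-in that builds the date with datetime.date directly.

```python
import datetime


def flight_date_generator(year, month, day):
    """Build a date from three parts; return None for nulls or invalid dates."""
    try:
        return datetime.date(int(year), int(month), int(day))
    except (TypeError, ValueError):
        return None


def register_both_ways(spark):
    """Register the function for DataFrame use and for SQL use."""
    # Imported here so flight_date_generator stays testable without Spark.
    from pyspark.sql.functions import udf
    from pyspark.sql.types import DateType

    # DataFrame-only: a column function object, not visible to SQL.
    complete_date_udf = udf(flight_date_generator, DateType())

    # SQL: registered under a string name so SQL expressions can call it.
    spark.udf.register("complete_date_udf", flight_date_generator, DateType())

    return complete_date_udf
```

After calling register_both_ways(spark), the returned object can be used in withColumn, while the registered name can be used in SQL, for example spark.sql("SELECT complete_date_udf(2008, 1, 3) AS FlightDate") or expr("complete_date_udf(Year, Month, DayofMonth)").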

In [3]:
# loading the dataset
from pyspark.sql import SparkSession
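# explicit imports avoid shadowing Python built-ins such as sum, min and max
from pyspark.sql.functions import col, expr, udf, year
from pyspark.sql.types import DateType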

if __name__ == "__main__":

    spark = (
        SparkSession.builder
        .appName('Spark col transformation')
        .getOrCreate()
    )
             

    flights_df = (
        spark.read
        .format('csv')
        .option('inferSchema','true')
        .option('header','true')
        .option('samplingRatio','0.001')
        .load(
            path = r'C:\Users\shubh\OneDrive\Documents\Visual Studio 2017\datasets\flights.csv', # download the flights dataset first and point this path at it
            encoding = 'utf-8'
        )   
    )

In [10]:
import datetime

def flight_date_generator(year, month, day):
    try:
        return datetime.datetime.strptime(f"{int(year):04d}{int(month):02d}{int(day):02d}", "%Y%m%d").date()
    except (TypeError, ValueError):  # null inputs or an impossible date
        return None

complete_date_udf = udf( flight_date_generator,DateType())

transformed_df = (
  flights_df.withColumn(
    'FlightDate',
    complete_date_udf(
      flights_df['Year'],
      flights_df['Month'],
      flights_df['DayofMonth']
    )
  ).withColumn(
      'Max Year', year(col('FlightDate'))
  ).withColumn(
      'Min Year', year(col('FlightDate'))
  ).select(
    'Origin',
    'Dest',
    'FlightNum',
    'FlightDate',
    'Max Year',
    'Min Year',
    'Year',
    'Distance'
  )
)

transformed_df.show()

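#learnings:
  # the function has to be written in core Python; PySpark column functions cannot be used inside it
  # withColumn has to come before select because select can only pick columns that already exist
  # both withColumn and select are lazy transformations; show() is the action that runs them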


+------+----+---------+----------+--------+--------+----+--------+
|Origin|Dest|FlightNum|FlightDate|Max Year|Min Year|Year|Distance|
+------+----+---------+----------+--------+--------+----+--------+
|   IAD| TPA|      335|2008-01-03|    2008|    2008|2008|     810|
|   IAD| TPA|     3231|2008-01-03|    2008|    2008|2008|     810|
|   IND| BWI|      448|2008-01-03|    2008|    2008|2008|     515|
|   IND| BWI|     1746|2008-01-03|    2008|    2008|2008|     515|
|   IND| BWI|     3920|2008-01-03|    2008|    2008|2008|     515|
|   IND| JAX|      378|2008-01-03|    2008|    2008|2008|     688|
|   IND| LAS|      509|2008-01-03|    2008|    2008|2008|    1591|
|   IND| LAS|      535|2008-01-03|    2008|    2008|2008|    1591|
|   IND| MCI|       11|2008-01-03|    2008|    2008|2008|     451|
|   IND| MCI|      810|2008-01-03|    2008|    2008|2008|     451|
|   IND| MCO|      100|2008-01-03|    2008|    2008|2008|     828|
|   IND| MCO|     1333|2008-01-03|    2008|    2008|2008|     

In [None]:
# a related transformation: derive a Decade column with a SQL CASE WHEN inside expr()
transformed_df.withColumn(
    'Decade',
    expr(
        """
        CASE
            WHEN Year BETWEEN 2000 AND 2010 THEN '2000s'
            ELSE NULL
        END
        """
    )
    ).select(
    'Origin',
    'Dest',
    'FlightNum',
    'FlightDate',
    'Year',
    'Decade',
    'Distance'
).show(5)

+------+----+---------+----------+----+------+--------+
|Origin|Dest|FlightNum|FlightDate|Year|Decade|Distance|
+------+----+---------+----------+----+------+--------+
|   IAD| TPA|      335|2008-01-03|2008| 2000s|     810|
|   IAD| TPA|     3231|2008-01-03|2008| 2000s|     810|
|   IND| BWI|      448|2008-01-03|2008| 2000s|     515|
|   IND| BWI|     1746|2008-01-03|2008| 2000s|     515|
|   IND| BWI|     3920|2008-01-03|2008| 2000s|     515|
+------+----+---------+----------+----+------+--------+
only showing top 5 rows

