### Python UDF

In [None]:
import sys
from scipy.constants import convert_temperature

def main(temp_c: float) -> float:
    """Return the Fahrenheit equivalent of a Celsius temperature.

    The input is coerced with ``float()`` first, so numeric strings are
    accepted as well; the conversion itself is delegated to
    ``scipy.constants.convert_temperature``.
    """
    celsius = float(temp_c)
    fahrenheit = convert_temperature(celsius, 'C', 'F')
    return fahrenheit


# For local debugging: run this file directly with a Celsius value as the
# first command-line argument.
if __name__ == '__main__':
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Please provide a temperature value in Celsius.")
    else:
        print(main(float(cli_args[0])))  # Convert input


## Register UDF

In [None]:
-- Registers a Python UDF that converts a Celsius temperature to Fahrenheit
-- by delegating to scipy.constants.convert_temperature.
--   * RETURNS NULL ON NULL INPUT: a NULL argument yields NULL instead of
--     reaching the handler, where float(None) would raise a TypeError.
--   * RUNTIME_VERSION: moved off Python 3.8, which Snowflake has deprecated.
--   * The handler no longer imports `sys`, which it never used.
CREATE OR REPLACE FUNCTION convert_celsius_to_fahrenheit(temp FLOAT)
RETURNS FLOAT
LANGUAGE PYTHON
RETURNS NULL ON NULL INPUT
RUNTIME_VERSION = '3.10'
HANDLER = 'main'
PACKAGES = ('scipy')
AS
$$
from scipy.constants import convert_temperature

def main(temp_c: float) -> float:
    """Convert Celsius to Fahrenheit using scipy."""
    return convert_temperature(float(temp_c), 'C', 'F')
$$;


### SQL UDF

In [None]:
-- Maps a temperature reading to a coarse text label.
--   temp <  0          -> 'Freezing'
--   0  <= temp <= 10   -> 'Cold'
--   10 <  temp <= 25   -> 'Mild'
--   25 <  temp <= 35   -> 'Warm'
--   temp >  35         -> 'Hot'
--   NULL               -> NULL
-- The argument is a FLOAT, so the buckets are written as a cascade of upper
-- bounds rather than integer BETWEEN ranges: with BETWEEN 0 AND 10 followed
-- by BETWEEN 11 AND 25, fractional readings such as 10.5 matched no branch
-- and fell through to 'Hot'. Whole-number inputs are classified as before.
CREATE OR REPLACE FUNCTION categorize_temperature(temp FLOAT)
RETURNS STRING
LANGUAGE SQL
AS
$$
    CASE
        WHEN temp IS NULL THEN NULL
        WHEN temp < 0 THEN 'Freezing'
        WHEN temp <= 10 THEN 'Cold'
        WHEN temp <= 25 THEN 'Mild'
        WHEN temp <= 35 THEN 'Warm'
        ELSE 'Hot'
    END
$$;

In [None]:
#------------------------------------------------------------------------------
# Hands-On Lab: Data Engineering with Snowpark
# Script:       07_daily_metrics_process_sp/app.py
# Author:       Jeremiah Hansen, Caleb Baechtold
# Last Updated: 1/9/2023
#------------------------------------------------------------------------------

import time
from snowflake.snowpark import Session
import snowflake.snowpark.types as T
import snowflake.snowpark.functions as F


def table_exists(session, schema='', name=''):
    """Report whether INFORMATION_SCHEMA.TABLES lists the given table.

    Args:
        session: Object exposing ``sql(query).collect()``; the first row of
            the result must be indexable by ``'TABLE_EXISTS'``.
        schema: Value compared against ``TABLE_SCHEMA``.
        name: Value compared against ``TABLE_NAME``.

    Returns:
        The ``TABLE_EXISTS`` value from the first row of the query result.

    Note:
        ``schema`` and ``name`` are interpolated into the SQL text verbatim,
        so only pass trusted identifiers.
    """
    query = (
        "SELECT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES "
        f"WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{name}') AS TABLE_EXISTS"
    )
    first_row = session.sql(query).collect()[0]
    return first_row['TABLE_EXISTS']

def create_daily_metrics_table(session):
    """Create ANALYTICS.DAILY_METRICS as an empty table (overwriting any existing one).

    The table is materialised by building a one-row all-NULL DataFrame with
    the target schema, dropping that row with ``na.drop()``, and saving the
    resulting empty DataFrame in ``overwrite`` mode.

    Args:
        session: Snowpark session used to build the DataFrame and write the table.
    """
    # The averages merged into this table are rounded to two decimal places,
    # so the decimal columns need an explicit scale: a bare DecimalType()
    # defaults to scale 0 and would store whole numbers only.
    shared_columns = [
        T.StructField("DATE", T.DateType()),
        T.StructField("ZIP_CODE", T.StringType()),
        T.StructField("STATION", T.StringType()),
        T.StructField("AVG_TEMPERATURE_FAHRENHEIT", T.DecimalType(38, 2)),
        T.StructField("AVG_TEMPERATURE_CELSIUS", T.DecimalType(38, 2)),
    ]
    daily_metrics_schema = T.StructType(
        [*shared_columns, T.StructField("META_UPDATED_AT", T.TimestampType())]
    )

    # One all-NULL row gives create_dataframe something to attach the schema
    # to; na.drop() then removes it so only the table structure is written.
    placeholder_row = [None] * len(daily_metrics_schema.names)
    session.create_dataframe([placeholder_row], schema=daily_metrics_schema) \
        .na.drop() \
        .write.mode('overwrite').save_as_table('ANALYTICS.DAILY_METRICS')


def merge_daily_metrics(session):
    """Aggregate PUBLIC.WEATHER_DATA per day and merge it into ANALYTICS.DAILY_METRICS.

    Steps:
      1. Scale warehouse HOL_WH up to XLARGE for the duration of the job.
      2. Group the weather rows by DATE, ZIP_CODE and STATION and compute the
         average of TEMPERATURE, plus the average of TEMPERATURE passed
         through ``ANALYTICS.FAHRENHEIT_TO_CELSIUS_UDF``; both are rounded to
         two decimals.
      3. Merge into ANALYTICS.DAILY_METRICS on (DATE, ZIP_CODE, STATION):
         matched rows are updated, unmatched rows inserted, and
         META_UPDATED_AT is set to the current timestamp either way.
      4. Scale the warehouse back down to XSMALL.

    Args:
        session: Snowpark session used for all statements.

    NOTE(review): ``ANALYTICS.FAHRENHEIT_TO_CELSIUS_UDF`` is not defined in
    the visible part of this file -- confirm it is registered before running.
    """
    session.sql('ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XLARGE WAIT_FOR_COMPLETION = TRUE').collect()

    # Everything after the scale-up runs under try/finally so that a failure
    # in the aggregation or merge cannot leave the warehouse at XLARGE.
    try:
        weather = session.table("PUBLIC.WEATHER_DATA")

        daily_averages = weather.group_by(F.col('DATE'), F.col('ZIP_CODE'), F.col('STATION')).agg(
            F.avg('TEMPERATURE').alias("AVG_TEMPERATURE_F"),
            F.avg(F.call_udf("ANALYTICS.FAHRENHEIT_TO_CELSIUS_UDF", F.col("TEMPERATURE"))).alias("AVG_TEMPERATURE_C"),
        )
        weather_agg = daily_averages.select(
            F.col("DATE"),
            F.col("ZIP_CODE"),
            F.col("STATION"),
            F.round(F.col("AVG_TEMPERATURE_F"), 2).alias("AVG_TEMPERATURE_FAHRENHEIT"),
            F.round(F.col("AVG_TEMPERATURE_C"), 2).alias("AVG_TEMPERATURE_CELSIUS"),
        )

        # Every aggregated column is written as-is; the audit timestamp is
        # added on top for both the update and the insert branch.
        updates = {c: weather_agg[c] for c in weather_agg.schema.names}
        updates["META_UPDATED_AT"] = F.current_timestamp()

        dcm = session.table('ANALYTICS.DAILY_METRICS')
        join_condition = (
            (dcm['DATE'] == weather_agg['DATE'])
            & (dcm['ZIP_CODE'] == weather_agg['ZIP_CODE'])
            & (dcm['STATION'] == weather_agg['STATION'])
        )
        dcm.merge(
            weather_agg,
            join_condition,
            [F.when_matched().update(updates), F.when_not_matched().insert(updates)],
        )
    finally:
        session.sql('ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XSMALL').collect()

def main(session: Session) -> str:
    """Entry point: make sure ANALYTICS.DAILY_METRICS exists, then refresh it.

    Args:
        session: Snowpark session passed through to the helper functions.

    Returns:
        A fixed success message once the merge has completed.
    """
    # The table is only (re)created when it is missing, so existing data
    # is left in place for the merge to update.
    metrics_table_present = table_exists(session, schema='ANALYTICS', name='DAILY_METRICS')
    if not metrics_table_present:
        create_daily_metrics_table(session)

    merge_daily_metrics(session)

    return "Successfully processed DAILY_METRICS"


# For local debugging
# Be aware you may need to type-convert arguments if you add input parameters
if __name__ == '__main__':
    import sys

    # Create a local Snowpark session
    with Session.builder.getOrCreate() as session:
        # Any extra command-line arguments are forwarded positionally; with
        # none supplied the slice is empty and unpacks to nothing.
        print(main(session, *sys.argv[1:]))  # type: ignore