In [None]:
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import IntegerType, FloatType
from snowflake.snowpark.functions import avg, sum, col, udf, call_udf, call_builtin, year
import pandas as pd
from sklearn.linear_model import LinearRegression
import config

# Snowflake session: every connection setting is read from the local
# `config` module under the attribute of the same name.
connection_parameters = {
    setting: getattr(config, setting)
    for setting in (
        "account",
        "user",
        "password",
        "warehouse",
        "role",
        "database",
        "schema",
    )
}
session_builder = Session.builder.configs(connection_parameters)
session = session_builder.create()

# Connectivity smoke test: show the warehouse, database, schema and
# Snowflake version the new session is attached to.
connectivity_check = session.sql("select current_warehouse() wh, current_database() db, current_schema() schema, current_version() v")
connectivity_check.show()

In [None]:
# Query the BEANIPA table with a Snowpark dataframe.
# Row filters, applied one after another in the order listed:
#   table name, indicator name, frequency 'A' (presumably annual - confirm
#   against the dataset docs), and dates from 1972-01-01 onwards.
row_filters = [
    col('Table Name') == 'Price Indexes For Personal Consumption Expenditures By Major Type Of Product',
    col('Indicator Name') == 'Personal consumption expenditures (PCE)',
    col('"Frequency"') == 'A',
    col('"Date"') >= '1972-01-01',
]
df = session.table("ECONOMY_DATA_ATLAS.ECONOMY.BEANIPA")
for row_filter in row_filters:
    df = df.filter(row_filter)

# Project the year of each observation and its value (renamed to PCE),
# then materialise the result as a pandas DataFrame.
year_column = year(col('"Date"')).alias('"Year"')
pce_column = col('"Value"').alias('PCE')
df_agg = df.select(year_column, pce_column).to_pandas()
df_agg

In [None]:
# Fit a linear regression of PCE on Year.
# The feature array is reshaped to a single column (n rows x 1 feature)
# before being passed to scikit-learn; the target stays one-dimensional.
year_values = df_agg["Year"].to_numpy()
x = year_values.reshape(-1, 1)
y = df_agg["PCE"].to_numpy()

model = LinearRegression()
model.fit(x, y)

# Local prediction for a single year using the fitted model.
predictYear = 2021
pce_pred = model.predict([[predictYear]])

In [None]:
# Create the UDF that serves the fitted regression model from Snowflake.

def predict_pce(year: int) -> float:
    """Predict the PCE value for a single year with the fitted model.

    Args:
        year: Calendar year to predict for.

    Returns:
        The predicted PCE value as a plain Python float.
    """
    # model.predict returns a one-element NumPy array for a single sample;
    # unwrap it so the result matches the declared FloatType of the UDF.
    return float(model.predict([[year]])[0])

# Local sanity check before registering the function in Snowflake.
predict_pce(2021)

# Register the function as a named UDF. replace=True lets this cell be
# re-run in the same session without an "already exists" failure.
predict_pce_udf = udf(
    predict_pce,
    return_type=FloatType(),
    input_types=[IntegerType()],
    packages=["pandas", "scikit-learn"],
    name='predict_pce_udf',
    replace=True,
)

In [None]:
# Invoke the registered UDF through plain SQL and display the result.
udf_query = session.sql("select predict_pce_udf(2024)")
udf_query.show()