In [None]:
#!pip -q install "apache-airflow[celery]==2.2.3" --constraint https://raw.githubusercontent.com/apache/airflow/constraints-2.2.3/constraints-no-providers-3.8.txt

In [None]:
from datetime import datetime, timedelta
#from webbrowser import get
#from xml.etree.ElementInclude import include
#import uuid

from airflow.decorators import dag, task

# Default operator arguments; handed to the @dag decorator via default_args=.
default_args = dict(
    owner='airflow',
    depends_on_past=False,
    email_on_failure=False,
    email_on_retry=False,
    retries=1,
    retry_delay=timedelta(minutes=5),
)

# Archive names passed to the ingest step.
files_to_download = [
    '202003-citibike-tripdata.csv.zip',
]


In [None]:
%%writefile snowpark_connection.py

def snowpark_connect(creds_path='creds.json'):
    """Create a Snowpark session from a JSON credentials file.

    Args:
        creds_path: Path of the JSON credentials file. Defaults to
            ``'creds.json'``, which is resolved against the current working
            directory. Pass an explicit path when the caller's working
            directory differs (e.g. an Airflow ``include`` directory).

    The file must hold a JSON object with the keys ``account``, ``username``,
    ``password``, ``role``, ``warehouse``, ``database`` and ``schema``.

    Returns:
        The session produced by
        ``snowflake.snowpark.Session.builder.configs(...).create()``.

    Raises:
        OSError: if the credentials file cannot be opened.
        KeyError: if the JSON object lacks one of the required keys.
    """
    import json
    import snowflake.snowpark as snp

    # Read the credentials first so the file handle is released before the
    # (potentially slow) connection attempt starts.
    with open(creds_path) as f:
        data = json.load(f)

    connection_parameters = {
        'account': data['account'],
        # The file stores the login under "username"; Snowpark expects "user".
        'user': data['username'],
        'password': data['password'],
        'role': data['role'],
        'warehouse': data['warehouse'],
        'database': data['database'],
        'schema': data['schema'],
    }
    return snp.Session.builder.configs(connection_parameters).create()

In [None]:
@dag(default_args=default_args, schedule_interval=None, start_date=datetime(2022, 1, 24), catchup=False, tags=['test'])
def snowpark_citibike_ml_taskflow(files_to_download:list, top_n:int=5):
    """
    End to end Astronomer / Snowflake ML Demo

    Builds a linear TaskFlow pipeline in which every task receives the shared
    ``state_dict``, adds the names of the objects it produced, and returns it
    to the next task.

    Args:
        files_to_download: Archive file names forwarded to ``incremental_elt``.
        top_n: Forwarded unchanged to ``generate_feature_table``.
            NOTE(review): the default of 5 is a placeholder, confirm the
            intended value for this pipeline.

    Returns:
        The output of the final task (``eval_station_preds_task``).
    """

    import uuid

    # Seed state shared by all tasks. model_id is generated each time this
    # function body runs and is used as a suffix for per-run table names.
    state_dict = {
    "download_base_url":"https://s3.amazonaws.com/tripdata/",
    "load_table_name":"RAW_",
    "trips_table_name":"TRIPS",
    "load_stage_name":"LOAD_STAGE",
    "model_stage_name":"MODEL_STAGE",
    "model_id": str(uuid.uuid1()).replace('-', '_')
    }

    @task()
    def snowpark_database_setup(state_dict:dict)-> dict:
        """Record the STARTTIME range of the trips table and create both stages."""
        from snowflake.snowpark import functions as F
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            start_date, end_date = session.table(state_dict['trips_table_name']) \
                                  .select(F.min('STARTTIME'), F.max('STARTTIME')).collect()[0][0:2]
            # NOTE(review): if STARTTIME is a timestamp column these values are
            # datetime objects; confirm the XCom backend can serialize them.
            state_dict.update({"start_date":start_date})
            state_dict.update({"end_date":end_date})

            # Stage names come from the shared state, not from free variables.
            _ = session.sql('CREATE STAGE IF NOT EXISTS ' + str(state_dict['model_stage_name'])).collect()
            _ = session.sql('CREATE STAGE IF NOT EXISTS ' + str(state_dict['load_stage_name'])).collect()
        finally:
            # Always release the session, even when a query fails.
            session.close()

        return state_dict

    @task()
    def incremental_elt_task(state_dict: dict, files_to_download:list)-> dict:
        """Run ``ingest.incremental_elt`` for the requested files."""
        from ingest import incremental_elt
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            _ = incremental_elt(session=session,
                                load_stage_name=state_dict['load_stage_name'],
                                files_to_download=files_to_download,
                                download_base_url=state_dict['download_base_url'],
                                load_table_name=state_dict['load_table_name'],
                                trips_table_name=state_dict['trips_table_name']
                                )
        finally:
            session.close()
        return state_dict

    @task()
    def deploy_model_udf_task(state_dict:dict)-> dict:
        """Call ``deploy_pred_train_udf`` and store its result as ``model_udf_name``."""
        from mlops_pipeline import deploy_pred_train_udf
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            model_udf_name = deploy_pred_train_udf(session=session,
                                                   function_name='station_train_predict_func',
                                                   model_stage_name=state_dict['model_stage_name']
                                                  )
            state_dict.update({"model_udf_name":model_udf_name})
        finally:
            session.close()
        return state_dict

    @task()
    def materialize_holiday_task(state_dict: dict)-> dict:
        """Call ``materialize_holiday_table`` and store its result as ``holiday_table_name``."""
        from mlops_pipeline import materialize_holiday_table
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            holiday_table_name = materialize_holiday_table(session=session,
                                                           trips_table_name=state_dict['trips_table_name'],
                                                           holiday_table_name='holidays'
                                                          )
            state_dict.update({"holiday_table_name":holiday_table_name})
        finally:
            session.close()
        return state_dict

    @task()
    def materialize_precip_task(state_dict: dict)-> dict:
        """Call ``materialize_precip_table`` and store its result as ``precip_table_name``."""
        from mlops_pipeline import materialize_precip_table
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            precip_table_name = materialize_precip_table(session=session,
                                                         trips_table_name=state_dict['trips_table_name'],
                                                         precip_table_name='weather'
                                                        )
            state_dict.update({"precip_table_name":precip_table_name})
        finally:
            session.close()
        return state_dict

    @task()
    def generate_feature_table_task(state_dict:dict, top_n:int)-> dict:
        """Clone and tag the trips table, then call ``generate_feature_table`` on the clone."""
        from mlops_pipeline import generate_feature_table
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            clone_table_name = 'TRIPS_CLONE_'+state_dict["model_id"]
            state_dict.update({"clone_table_name":clone_table_name})

            # Clone the trips table and tag the clone with the run's model_id.
            _ = session.sql('CREATE OR REPLACE TABLE '+clone_table_name+" CLONE "+state_dict["trips_table_name"]).collect()
            _ = session.sql('CREATE TAG IF NOT EXISTS model_id_tag').collect()
            _ = session.sql("ALTER TABLE "+clone_table_name+" SET TAG model_id_tag = '"+state_dict["model_id"]+"'").collect()

            feature_table_name = generate_feature_table(session=session,
                                                        clone_table_name=state_dict["clone_table_name"],
                                                        feature_table_name='TRIPS_FEATURES_'+state_dict["model_id"],
                                                        holiday_table_name=state_dict["holiday_table_name"],
                                                        precip_table_name=state_dict["precip_table_name"],
                                                        target_column='COUNT',
                                                        top_n=top_n
                                                       )
            state_dict.update({"feature_table_name":feature_table_name})
        finally:
            session.close()
        return state_dict

    @task()
    def bulk_train_predict_task(state_dict:dict)-> dict:
        """Call ``train_predict_feature_table`` and store its result as ``pred_table_name``."""
        from mlops_pipeline import train_predict_feature_table
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            pred_table_name = train_predict_feature_table(session=session,
                                                          station_train_pred_udf_name=state_dict["model_udf_name"],
                                                          feature_table_name=state_dict["feature_table_name"],
                                                          pred_table_name='PRED_'+state_dict["model_id"]
                                                         )
            state_dict.update({"pred_table_name":pred_table_name})
        finally:
            session.close()
        return state_dict

    # Must be a task: without the decorator this body would run while the DAG
    # is being built instead of as a pipeline step.
    @task()
    def deploy_eval_udf_task(state_dict:dict)-> dict:
        """Call ``deploy_eval_udf`` and store its result as ``eval_model_udf_name``."""
        from model_eval import deploy_eval_udf
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            eval_model_udf_name = deploy_eval_udf(session=session,
                                                  function_name='eval_model_output_func',
                                                  model_stage_name=state_dict['model_stage_name']
                                                  )
            state_dict.update({"eval_model_udf_name":eval_model_udf_name})
        finally:
            session.close()
        return state_dict

    @task()
    def eval_station_preds_task(state_dict:dict)-> dict:
        """Call ``evaluate_station_predictions`` and store its result as ``eval_table_name``."""
        from model_eval import evaluate_station_predictions
        from snowpark_connection import snowpark_connect

        session = snowpark_connect()
        try:
            eval_table_name = evaluate_station_predictions(session=session,
                                                           pred_table_name=state_dict['pred_table_name'],
                                                           eval_model_udf_name=state_dict['eval_model_udf_name'],
                                                           eval_table_name='EVAL_'+state_dict["model_id"]
                                                           )
            state_dict.update({"eval_table_name":eval_table_name})
        finally:
            session.close()
        return state_dict

    #Task order: each call feeds the previous task's output to the next one,
    #which also fixes the execution order.
    state_dict = snowpark_database_setup(state_dict)
    state_dict = incremental_elt_task(state_dict, files_to_download)
    state_dict = deploy_model_udf_task(state_dict)
    state_dict = materialize_holiday_task(state_dict)
    state_dict = materialize_precip_task(state_dict)
    state_dict = generate_feature_table_task(state_dict, top_n)
    state_dict = bulk_train_predict_task(state_dict)
    state_dict = deploy_eval_udf_task(state_dict)
    state_dict = eval_station_preds_task(state_dict)

    return state_dict
