## Apache Airflow (OPTIONAL) 

In this section of the hands-on lab, we will use Snowpark's Python client-side DataFrame API, the Snowpark server-side runtime, and Apache Airflow to create an operational pipeline. We will take the functions created by the ML Ops team and create a directed acyclic graph (DAG) of operations to run each month when new data is available.

Note: This code requires the ability to run Docker containers locally. If you do not have Docker Desktop, you can run the same pipeline from a Python kernel via the 04_ML_Ops.ipynb notebook.

We will use the Astro CLI from Astronomer to run Airflow locally: https://docs.astronomer.io/astro/cli/get-started#step-1-install-the-astro-cli

Follow the instructions to install the `astro` CLI for your particular local setup.

In [None]:
%%writefile dags/airflow_tasks.py

from airflow.decorators import task

@task.virtualenv(python_version=3.8)
def snowpark_database_setup(state_dict:dict)-> dict:
    """Reset the demo database, then create the model stage and tag.

    Opens a session with ``snowpark_connect('./include/state.json')``, calls
    ``reset_database`` with ``prestaged=True``, and then issues
    ``CREATE STAGE`` for ``state_dict['model_stage_name']`` and
    ``CREATE TAG model_id_tag``.

    Args:
        state_dict: Pipeline state. ``model_stage_name`` is read here and
            the whole dict is forwarded to ``reset_database``.

    Returns:
        The ``state_dict`` object that was passed in, for downstream tasks.
    """
    # Imports stay inside the body so they resolve where the task executes.
    from dags.snowpark_connection import snowpark_connect
    from dags.elt import reset_database

    # The second value returned by snowpark_connect is not needed here.
    session, _ = snowpark_connect('./include/state.json')
    try:
        reset_database(session=session, state_dict=state_dict, prestaged=True)

        session.sql('CREATE STAGE '+state_dict['model_stage_name']).collect()
        session.sql('CREATE TAG model_id_tag').collect()
    finally:
        # Release the connection even when the reset or a DDL statement fails.
        session.close()

    return state_dict

@task.virtualenv(python_version=3.8)
def incremental_elt_task(state_dict: dict, files_to_download:list)-> dict:
    """Ingest the listed files by calling ``incremental_elt``.

    Args:
        state_dict: Pipeline state. Reads
            ``connection_parameters.download_base_url`` and
            ``compute_parameters.load_warehouse``; the whole dict is also
            forwarded to ``incremental_elt``.
        files_to_download: File names handed to ``incremental_elt`` as
            ``files_to_ingest``.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.ingest import incremental_elt
    from dags.snowpark_connection import snowpark_connect

    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    print('Ingesting '+str(files_to_download))
    base_url = state_dict['connection_parameters']['download_base_url']

    load_warehouse = state_dict['compute_parameters']['load_warehouse']
    session.use_warehouse(load_warehouse)

    elt_kwargs = {
        'session': session,
        'state_dict': state_dict,
        'files_to_ingest': files_to_download,
        'download_base_url': base_url,
        'use_prestaged': True,
    }
    incremental_elt(**elt_kwargs)

    # Suspending the load warehouse afterwards is currently disabled:
    #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['load_warehouse']+\
    #                ' SUSPEND').collect()

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def initial_bulk_load_task(state_dict:dict)-> dict:
    """Run the initial bulk ingest by calling ``bulk_elt``.

    Args:
        state_dict: Pipeline state. Reads
            ``compute_parameters.load_warehouse`` and
            ``connection_parameters.download_base_url``; the whole dict is
            also forwarded to ``bulk_elt``.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.ingest import bulk_elt
    from dags.snowpark_connection import snowpark_connect

    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    session.use_warehouse(state_dict['compute_parameters']['load_warehouse'])

    connection_parameters = state_dict['connection_parameters']
    print('Running initial bulk ingest from '+connection_parameters['download_base_url'])

    bulk_elt(
        session=session,
        state_dict=state_dict,
        download_base_url=state_dict['connection_parameters']['download_base_url'],
        use_prestaged=True,
    )

    # Suspending the load warehouse afterwards is currently disabled:
    #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['load_warehouse']+\
    #                ' SUSPEND').collect()

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def materialize_holiday_task(state_dict: dict)-> dict:
    """Call ``materialize_holiday_table`` for the configured holiday table.

    Args:
        state_dict: Pipeline state. ``holiday_table_name`` is read.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import materialize_holiday_table

    print('Materializing holiday table.')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    holiday_table = state_dict['holiday_table_name']
    materialize_holiday_table(session=session, holiday_table_name=holiday_table)

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def subscribe_to_weather_data_task(state_dict: dict)-> dict:
    """Call ``subscribe_to_weather_data`` with the configured names.

    Args:
        state_dict: Pipeline state. ``weather_database_name`` and
            ``weather_listing_id`` are read.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import subscribe_to_weather_data

    print('Subscribing to weather data')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    subscription_kwargs = {
        'weather_database_name': state_dict['weather_database_name'],
        'weather_listing_id': state_dict['weather_listing_id'],
    }
    subscribe_to_weather_data(session=session, **subscription_kwargs)

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def create_weather_view_task(state_dict: dict)-> dict:
    """Call ``create_weather_view`` with the configured table and view names.

    Args:
        state_dict: Pipeline state. ``weather_table_name`` and
            ``weather_view_name`` are read.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import create_weather_view

    print('Creating weather view')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    view_kwargs = {
        'weather_table_name': state_dict['weather_table_name'],
        'weather_view_name': state_dict['weather_view_name'],
    }
    create_weather_view(session=session, **view_kwargs)

    session.close()
    return state_dict
    
@task.virtualenv(python_version=3.8)
def deploy_model_udf_task(state_dict:dict)-> dict:
    """Ensure the model stage exists, then call ``deploy_pred_train_udf``.

    Args:
        state_dict: Pipeline state. ``model_stage_name``,
            ``train_udf_name`` and ``train_func_name`` are read.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import deploy_pred_train_udf

    print('Deploying station model')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    # IF NOT EXISTS keeps this statement safe when the stage is already there.
    create_stage_sql = 'CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']
    session.sql(create_stage_sql).collect()

    deploy_kwargs = {
        'udf_name': state_dict['train_udf_name'],
        'function_name': state_dict['train_func_name'],
        'model_stage_name': state_dict['model_stage_name'],
    }
    deploy_pred_train_udf(session=session, **deploy_kwargs)

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def deploy_eval_udf_task(state_dict:dict)-> dict:
    """Ensure the model stage exists, then call ``deploy_eval_udf``.

    Args:
        state_dict: Pipeline state. ``model_stage_name``,
            ``eval_udf_name`` and ``eval_func_name`` are read.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    # Imports stay inside the body so they resolve where the task executes.
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import deploy_eval_udf

    # Distinct message so task logs are not confused with the training-UDF
    # deployment task, which prints 'Deploying station model'.
    print('Deploying model evaluation UDF')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()
    try:
        # IF NOT EXISTS keeps this statement safe when the stage is already there.
        session.sql('CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']).collect()

        deploy_eval_udf(session=session,
                        udf_name=state_dict['eval_udf_name'],
                        function_name=state_dict['eval_func_name'],
                        model_stage_name=state_dict['model_stage_name'])
    finally:
        # Release the connection even when deployment fails.
        session.close()

    return state_dict

@task.virtualenv(python_version=3.8)
def generate_feature_table_task(state_dict:dict, 
                                holiday_state_dict:dict, 
                                weather_state_dict:dict)-> dict:
    """Clone the trips table, build the feature table from it, and tag both.

    Args:
        state_dict: Pipeline state. Reads
            ``compute_parameters.fe_warehouse``, ``clone_table_name``,
            ``trips_table_name``, ``model_id``, ``holiday_table_name``,
            ``weather_view_name`` and ``feature_table_name``.
        holiday_state_dict: Not read in this body; presumably accepted so
            the task is ordered after the holiday task.
        weather_state_dict: Not read in this body; presumably accepted so
            the task is ordered after the weather task.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import create_feature_table

    print('Generating features for all stations.')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    session.use_warehouse(state_dict['compute_parameters']['fe_warehouse'])

    # Clone the trips table, then tag the clone with this run's model id.
    clone_sql = "CREATE OR REPLACE TABLE "+state_dict['clone_table_name']+" CLONE "+state_dict['trips_table_name']
    session.sql(clone_sql).collect()

    tag_clone_sql = "ALTER TABLE "+state_dict['clone_table_name']+" SET TAG model_id_tag = '"+state_dict['model_id']+"'"
    session.sql(tag_clone_sql).collect()

    # Features are built from the clone, not from the live trips table.
    feature_kwargs = {
        'trips_table_name': state_dict['clone_table_name'],
        'holiday_table_name': state_dict['holiday_table_name'],
        'weather_view_name': state_dict['weather_view_name'],
        'feature_table_name': state_dict['feature_table_name'],
    }
    create_feature_table(session, **feature_kwargs)

    tag_feature_sql = "ALTER TABLE "+state_dict['feature_table_name']+" SET TAG model_id_tag = '"+state_dict['model_id']+"'"
    session.sql(tag_feature_sql).collect()

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def generate_forecast_table_task(state_dict:dict, 
                                 holiday_state_dict:dict, 
                                 weather_state_dict:dict)-> dict: 
    """Build the forecast table with ``create_forecast_table`` and tag it.

    Args:
        state_dict: Pipeline state. Reads ``trips_table_name``,
            ``holiday_table_name``, ``weather_view_name``,
            ``forecast_table_name``, ``forecast_steps`` and ``model_id``.
        holiday_state_dict: Not read in this body; presumably accepted so
            the task is ordered after the holiday task.
        weather_state_dict: Not read in this body; presumably accepted so
            the task is ordered after the weather task.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import create_forecast_table

    print('Generating forecast features.')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    forecast_kwargs = {
        'trips_table_name': state_dict['trips_table_name'],
        'holiday_table_name': state_dict['holiday_table_name'],
        'weather_view_name': state_dict['weather_view_name'],
        'forecast_table_name': state_dict['forecast_table_name'],
        'steps': state_dict['forecast_steps'],
    }
    create_forecast_table(session, **forecast_kwargs)

    # Tag the new table with this run's model id.
    tag_sql = "ALTER TABLE "+state_dict['forecast_table_name']+" SET TAG model_id_tag = '"+state_dict['model_id']+"'"
    session.sql(tag_sql).collect()

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def bulk_train_predict_task(state_dict:dict, 
                            feature_state_dict:dict, 
                            forecast_state_dict:dict)-> dict: 
    """Run ``train_predict`` and tag the prediction table.

    Args:
        state_dict: Replaced by ``feature_state_dict`` before any use, so
            the value passed here is never read.
        feature_state_dict: Pipeline state actually used. Reads
            ``compute_parameters.train_warehouse``, ``train_udf_name``,
            ``feature_table_name``, ``forecast_table_name``,
            ``pred_table_name`` and ``model_id``.
        forecast_state_dict: Not read in this body; presumably accepted so
            the task is ordered after the forecast task.

    Returns:
        The ``feature_state_dict`` object (rebound to ``state_dict``).
    """
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import train_predict

    # From here on the feature task's state is the working state.
    state_dict = feature_state_dict

    print('Running bulk training and forecast.')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    session.use_warehouse(state_dict['compute_parameters']['train_warehouse'])

    train_kwargs = {
        'station_train_pred_udf_name': state_dict['train_udf_name'],
        'feature_table_name': state_dict['feature_table_name'],
        'forecast_table_name': state_dict['forecast_table_name'],
        'pred_table_name': state_dict['pred_table_name'],
    }
    # The returned table name is not used; the configured name is used below.
    train_predict(session, **train_kwargs)

    tag_sql = "ALTER TABLE "+state_dict['pred_table_name']+" SET TAG model_id_tag = '"+state_dict['model_id']+"'"
    session.sql(tag_sql).collect()

    # Suspending the training warehouse afterwards is currently disabled:
    #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['train_warehouse']+\
    #                ' SUSPEND').collect()

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def eval_station_models_task(state_dict:dict, 
                             pred_state_dict:dict,
                             run_date:str)-> dict:
    """Run ``evaluate_station_model`` and tag the evaluation table.

    Args:
        state_dict: Pipeline state. Reads ``eval_udf_name``,
            ``pred_table_name``, ``eval_table_name`` and ``model_id``.
        pred_state_dict: Not read in this body; presumably accepted so the
            task is ordered after the prediction task.
        run_date: Forwarded unchanged to ``evaluate_station_model``.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import evaluate_station_model

    print('Running eval UDF for model output')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()

    eval_kwargs = {
        'run_date': run_date,
        'eval_model_udf_name': state_dict['eval_udf_name'],
        'pred_table_name': state_dict['pred_table_name'],
        'eval_table_name': state_dict['eval_table_name'],
    }
    # The returned table name is not used; the configured name is used below.
    evaluate_station_model(session, **eval_kwargs)

    tag_sql = "ALTER TABLE "+state_dict['eval_table_name']+" SET TAG model_id_tag = '"+state_dict['model_id']+"'"
    session.sql(tag_sql).collect()

    session.close()
    return state_dict

@task.virtualenv(python_version=3.8)
def flatten_tables_task(pred_state_dict:dict, state_dict:dict)-> dict:
    """Flatten the prediction, forecast and evaluation tables and tag them.

    Calls ``flatten_tables``, records the three returned table names in
    ``state_dict`` under ``flat_pred_table``, ``flat_forecast_table`` and
    ``flat_eval_table``, and sets ``model_id_tag`` on each of them.

    Args:
        pred_state_dict: Not read in this body; presumably accepted so the
            task is ordered after the prediction task.
        state_dict: Pipeline state. Reads ``pred_table_name``,
            ``forecast_table_name``, ``eval_table_name`` and ``model_id``;
            the three ``flat_*`` keys are added to it.

    Returns:
        ``state_dict`` with the three ``flat_*`` entries added.
    """
    # Imports stay inside the body so they resolve where the task executes.
    from dags.snowpark_connection import snowpark_connect
    from dags.mlops_pipeline import flatten_tables

    print('Flattening tables for end-user consumption.')
    # Only the session is used; the second return value is discarded.
    session, _ = snowpark_connect()
    try:
        flat_pred_table, flat_forecast_table, flat_eval_table = flatten_tables(
            session,
            pred_table_name=state_dict['pred_table_name'],
            forecast_table_name=state_dict['forecast_table_name'],
            eval_table_name=state_dict['eval_table_name'])
        state_dict['flat_pred_table'] = flat_pred_table
        state_dict['flat_forecast_table'] = flat_forecast_table
        state_dict['flat_eval_table'] = flat_eval_table

        # Tag every flattened table with this run's model id.
        for flat_table in (flat_pred_table, flat_forecast_table, flat_eval_table):
            session.sql("ALTER TABLE "+flat_table+" SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect()
    finally:
        # The session was previously never closed in this task; release it on
        # both the success and the failure path.
        session.close()

    return state_dict


In [None]:
%%writefile dags/airflow_setup_pipeline.py

from datetime import datetime, timedelta

from airflow.decorators import dag, task
from dags.airflow_tasks import snowpark_database_setup
from dags.airflow_tasks import incremental_elt_task
from dags.airflow_tasks import initial_bulk_load_task
from dags.airflow_tasks import materialize_holiday_task
from dags.airflow_tasks import subscribe_to_weather_data_task
from dags.airflow_tasks import create_weather_view_task
from dags.airflow_tasks import deploy_model_udf_task
from dags.airflow_tasks import deploy_eval_udf_task
from dags.airflow_tasks import generate_feature_table_task
from dags.airflow_tasks import generate_forecast_table_task
from dags.airflow_tasks import bulk_train_predict_task
from dags.airflow_tasks import eval_station_models_task 
from dags.airflow_tasks import flatten_tables_task

# Arguments handed to @dag(default_args=...) for the setup DAG.
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    # E-mail notifications are switched off for both failures and retries.
    'email_on_failure': False,
    'email_on_retry': False,
    # One retry, attempted five minutes after the failure.
    'retries': 1,
    'retry_delay': timedelta(seconds=300),
}

#local_airflow_path = '/usr/local/airflow/'

@dag(default_args=default_args, schedule_interval=None, start_date=datetime(2020, 3, 1), catchup=False, tags=['setup'])
def citibikeml_setup_taskflow(run_date:str):
    """
    Setup initial Snowpark / Astronomer ML Demo

    Loads ``./include/state.json``, adds a freshly generated ``model_id``
    plus the table, stage and UDF names used by the tasks, and wires the
    one-time setup tasks together.
    """
    import uuid
    import json

    with open('./include/state.json') as state_file:
        state_dict = json.load(state_file)

    # uuid1 with dashes replaced by underscores is used as a table-name suffix.
    model_id = str(uuid.uuid1()).replace('-', '_')
    weather_database_name = 'WEATHER_NYC'

    state_dict.update({
        'model_id': model_id,
        'run_date': run_date,
        'weather_database_name': weather_database_name,
        'load_table_name': 'RAW_',
        'trips_table_name': 'TRIPS',
        'load_stage_name': 'LOAD_STAGE',
        'model_stage_name': 'MODEL_STAGE',
        'weather_table_name': weather_database_name+'.ONPOINT_ID.HISTORY_DAY',
        'weather_view_name': 'WEATHER_NYC_VW',
        'holiday_table_name': 'HOLIDAYS',
        'clone_table_name': 'CLONE_'+model_id,
        'feature_table_name': 'FEATURE_'+model_id,
        'pred_table_name': 'PRED_'+model_id,
        'eval_table_name': 'EVAL_'+model_id,
        'forecast_table_name': 'FORECAST_'+model_id,
        'forecast_steps': 30,
        'train_udf_name': 'station_train_predict_udf',
        'train_func_name': 'station_train_predict_func',
        'eval_udf_name': 'eval_model_output_udf',
        'eval_func_name': 'eval_model_func',
    })

    # Task order - one-time setup. Each task receives the output of the
    # task(s) it has to run after.
    setup_state = snowpark_database_setup(state_dict)
    load_state = initial_bulk_load_task(setup_state)
    holiday_state = materialize_holiday_task(setup_state)
    subscribe_state = subscribe_to_weather_data_task(setup_state)
    weather_state = create_weather_view_task(subscribe_state)
    model_udf_state = deploy_model_udf_task(setup_state)
    eval_udf_state = deploy_eval_udf_task(setup_state)
    feature_state = generate_feature_table_task(load_state, holiday_state, weather_state)
    forecast_state = generate_forecast_table_task(load_state, holiday_state, weather_state)
    pred_state = bulk_train_predict_task(model_udf_state, feature_state, forecast_state)
    eval_state = eval_station_models_task(eval_udf_state, pred_state, run_date)
    state_dict = flatten_tables_task(pred_state, eval_state)

    return state_dict

# Calling the decorated function at module level creates the DAG object;
# run_date here is the value the DAG is built with.
run_date='2020_01_01'

state_dict = citibikeml_setup_taskflow(run_date=run_date)


In [None]:
%%writefile dags/airflow_incremental_pipeline.py

from datetime import datetime, timedelta

from airflow.decorators import dag, task
from dags.airflow_tasks import snowpark_database_setup
from dags.airflow_tasks import incremental_elt_task
from dags.airflow_tasks import initial_bulk_load_task
from dags.airflow_tasks import materialize_holiday_task
from dags.airflow_tasks import deploy_model_udf_task
from dags.airflow_tasks import deploy_eval_udf_task
from dags.airflow_tasks import generate_feature_table_task
from dags.airflow_tasks import generate_forecast_table_task
from dags.airflow_tasks import bulk_train_predict_task
from dags.airflow_tasks import eval_station_models_task 
from dags.airflow_tasks import flatten_tables_task

# Arguments handed to @dag(default_args=...) for the monthly DAG.
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    # E-mail notifications are switched off for both failures and retries.
    'email_on_failure': False,
    'email_on_retry': False,
    # One retry, attempted five minutes after the failure.
    'retries': 1,
    'retry_delay': timedelta(seconds=300),
}

#local_airflow_path = '/usr/local/airflow/'

@dag(default_args=default_args, schedule_interval=None, start_date=datetime(2020, 4, 1), catchup=False, tags=['monthly'])
def citibikeml_monthly_taskflow(files_to_download:list, run_date:str):
    """
    End to end Snowpark / Astronomer ML Demo

    Loads ``./include/state.json``, adds a freshly generated ``model_id``
    plus the table, stage and UDF names used by the tasks, and wires the
    incremental ingest, feature, forecast, training, evaluation and
    flattening tasks together.
    """
    import uuid
    import json

    with open('./include/state.json') as state_file:
        state_dict = json.load(state_file)

    # uuid1 with dashes replaced by underscores is used as a table-name suffix.
    model_id = str(uuid.uuid1()).replace('-', '_')
    weather_database_name = 'WEATHER_NYC'

    state_dict.update({
        'model_id': model_id,
        'run_date': run_date,
        'weather_database_name': weather_database_name,
        'load_table_name': 'RAW_',
        'trips_table_name': 'TRIPS',
        'load_stage_name': 'LOAD_STAGE',
        'model_stage_name': 'MODEL_STAGE',
        'weather_table_name': weather_database_name+'.ONPOINT_ID.HISTORY_DAY',
        'weather_view_name': 'WEATHER_NYC_VW',
        'holiday_table_name': 'HOLIDAYS',
        'clone_table_name': 'CLONE_'+model_id,
        'feature_table_name': 'FEATURE_'+model_id,
        'pred_table_name': 'PRED_'+model_id,
        'eval_table_name': 'EVAL_'+model_id,
        'forecast_table_name': 'FORECAST_'+model_id,
        'forecast_steps': 30,
        'train_udf_name': 'station_train_predict_udf',
        'train_func_name': 'station_train_predict_func',
        'eval_udf_name': 'eval_model_output_udf',
        'eval_func_name': 'eval_model_func',
    })

    # The ingest output is passed for all three arguments of the feature and
    # forecast tasks; there are no separate holiday/weather tasks in this DAG.
    ingest_state = incremental_elt_task(state_dict, files_to_download)
    feature_state = generate_feature_table_task(ingest_state, ingest_state, ingest_state)
    forecast_state = generate_forecast_table_task(ingest_state, ingest_state, ingest_state)
    pred_state = bulk_train_predict_task(feature_state, feature_state, forecast_state)
    eval_state = eval_station_models_task(pred_state, pred_state, run_date)
    state_dict = flatten_tables_task(pred_state, eval_state)

    return state_dict

# Calling the decorated function at module level creates the DAG object;
# the values below are what the DAG is built with.
run_date='2020_02_01'
files_to_download = ['202001-citibike-tripdata.csv.zip']

state_dict = citibikeml_monthly_taskflow(files_to_download=files_to_download,
                                         run_date=run_date)


Now open a new browser tab to localhost:8080

In [None]:
import webbrowser

# Open the local Airflow UI in the default browser. The REST calls in this
# notebook reach the same server at http://localhost:8080, so plain http is
# used here too; the previous https:// URL did not match them.
url = 'http://localhost:8080'
webbrowser.open(url)

Let's run the initial setup, ingest and forecast DAG.

In [None]:
# #This sample code can be used to trigger the Airflow pipeline from a command-line shell.
# !curl -X POST 'http://localhost:8080/api/v1/dags/citibikeml_monthly_taskflow/dagRuns' \
# -H 'Content-Type: application/json' \
# --user "admin:admin" \
# -d '{"conf": {"files_to_download": ["202003-citibike-tripdata.csv.zip"], "run_date": "2020_04_01"}}'

Alternatively, we can call the same Airflow REST API from Python with the `requests` library.

In [None]:
import requests
from requests.auth import HTTPBasicAuth
import time 
import json

# Trigger the setup DAG through the Airflow REST API and wait for the run to
# finish.
dag_url='http://localhost:8080/api/v1/dags/citibikeml_setup_taskflow/dagRuns'
json_payload = {"conf": {"run_date": "2020_01_01"}}
auth = HTTPBasicAuth('admin', 'admin')

response = requests.post(dag_url,
                        json=json_payload,
                        auth=auth)
# Surface HTTP/auth errors directly instead of a KeyError on 'dag_run_id'.
response.raise_for_status()

run_id = response.json()['dag_run_id']
run_url = dag_url+'/'+run_id

# Poll every 10 seconds until the run reaches a terminal state. Stopping on
# 'failed' as well as 'success' prevents an endless loop when the DAG fails.
while True:
    status_response = requests.get(run_url, auth=auth)
    status_response.raise_for_status()
    state = status_response.json()['state']
    if state in ('success', 'failed'):
        break
    print('DAG running...'+state)
    time.sleep(10)

if state != 'success':
    raise RuntimeError('DAG run '+run_id+' finished in state: '+state)

In [None]:
import requests
from requests.auth import HTTPBasicAuth
import time 
import json

# Trigger the monthly DAG through the Airflow REST API and wait for the run
# to finish.
dag_url='http://localhost:8080/api/v1/dags/citibikeml_monthly_taskflow/dagRuns'
json_payload = {"conf": {"files_to_download": ["202001-citibike-tripdata.csv.zip"], "run_date": "2020_02_01"}}
auth = HTTPBasicAuth('admin', 'admin')

response = requests.post(dag_url,
                        json=json_payload,
                        auth=auth)
# Surface HTTP/auth errors directly instead of a KeyError on 'dag_run_id'.
response.raise_for_status()

run_id = response.json()['dag_run_id']
run_url = dag_url+'/'+run_id

# Poll every 10 seconds until the run reaches a terminal state. Stopping on
# 'failed' as well as 'success' prevents an endless loop when the DAG fails.
while True:
    status_response = requests.get(run_url, auth=auth)
    status_response.raise_for_status()
    state = status_response.json()['state']
    if state in ('success', 'failed'):
        break
    print('DAG running...'+state)
    time.sleep(10)

if state != 'success':
    raise RuntimeError('DAG run '+run_id+' finished in state: '+state)