# This notebook creates a function that ingests data from Snowflake with a Dask cluster

The Dask framework enables users to parallelize their Python code and run it as a distributed process on an Iguazio cluster, dramatically accelerating performance. <br>
In this notebook we'll create an MLRun function running as a Dask client to ingest data from Snowflake. <br>
It also demonstrates how to run parallelized queries against Snowflake, using the Dask Delayed option to query a large data set. <br>
The function will be published on the function marketplace. <br>
For more information on Dask over Kubernetes: https://kubernetes.dask.org/en/latest/

### Set up the environment

In [None]:
import mlrun
import os
import warnings
import yaml

project_name = "snowflake-dask"
dask_cluster_name = "snowflake-dask-cluster"

# Artifact location handed to MLRun: /v3io/projects/<project_name>
project_artifact_dir = os.path.join(os.path.abspath('/v3io/projects/'), project_name)
artifact_path = mlrun.set_environment(
    project=project_name,
    artifact_path=project_artifact_dir,
)

# Silence warnings for the remainder of the notebook
warnings.filterwarnings("ignore")

print(f'artifact_path = {artifact_path}')

### Load Snowflake configuration from a config file
This is for demo purposes only. In real production code, you would put the Snowflake connection info into secrets and use those secrets in the running pod to connect to Snowflake.

In [None]:
# Read the Snowflake connection settings from the local YAML config file
with open(".config.yaml") as config_file:
    connection_info = yaml.safe_load(config_file)

# Sanity check: show the configured account
print(connection_info['account'])

### Create a python function

This function queries data from Snowflake using the Snowflake Python connector for parallel processing of the query results. <br>
With the Snowflake Python connector, when you execute a query, the cursor returns the result batches. <br>
Using Dask Delayed, the result batches are returned and processed in parallel. <br>

#### write the function to a py file

In [None]:
%%writefile snowflake_dask.py
"""Snowflake Dask - Ingest Snowflake data with Dask"""
import warnings
import mlrun
from mlrun.execution import MLClientCtx
import snowflake.connector as snow
from dask.distributed import Client
from dask.dataframe import from_delayed
from dask import delayed
from dask import dataframe as dd
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

warnings.filterwarnings("ignore")

@delayed
def load(batch):
    """Convert one result batch to a pandas dataframe (as a delayed task).

    :param batch: a result batch exposing ``to_pandas()``
    :returns:     the dataframe produced by ``batch.to_pandas()``

    Any failure is printed together with the offending batch and re-raised.
    """
    try:
        print("BATCHING")
        return batch.to_pandas()
    except Exception as err:
        # Report which batch failed, then let the error propagate to the caller
        print(f"Failed on {batch} for {err}")
        raise

def load_results(context: MLClientCtx,
                 dask_client: str,
                 connection_info: str,
                 query: str,
                 parquet_out_dir = None,
                 publish_name = None
                ) -> None:

    """Snowflake Dask - Ingest Snowflake data with Dask

    Executes ``query`` on Snowflake, wraps every non-empty result batch in a
    delayed ``load`` task and assembles the tasks into a Dask dataframe.
    The dataframe's ``describe()`` summary and its row count are logged to the
    context; optionally the dataframe is published to the Dask cluster and/or
    written out as parquet files.

    Authentication uses the context secrets: a private key (``pkPath`` +
    ``pkPassword``) takes precedence over a password (``sfPassword``).

    :param context:           the function context
    :param dask_client:       dask cluster function name
                              (when empty, a new local dask client is created)
    :param connection_info:   Snowflake database connection info (this will be in a secret later);
                              it is expanded as keyword arguments of the Snowflake
                              connection, so it must be a mapping despite the ``str``
                              annotation. The caller's mapping is not modified.
    :param query:             query to run on Snowflake
    :param parquet_out_dir:   directory path for the output parquet files
                              (default None, not write out)
    :param publish_name:      name of the dask dataframe to publish to the dask cluster
                              (default None, not publish)

    :raises Exception:        if neither the private key secrets nor the password
                              secret are available
    """
    # NOTE(review): this replaces the context passed in by the caller, and the
    # context name looks like a typo of "snowflake" - confirm before changing.
    context = mlrun.get_or_create_ctx('snawflake-dask-cluster')
    connection_info = _with_snowflake_credentials(context, connection_info)

    # setup dask client from the MLRun dask cluster function
    if dask_client:
        client = mlrun.import_function(dask_client).client
        context.logger.info(f'Existing dask client === >>> {client}\n')
    else:
        client = Client()
        context.logger.info(f'\nNewly created dask client === >>> {client}\n')

    conn = snow.connect(**connection_info)
    try:
        cur = conn.cursor()
        try:
            cur.execute(query)
            # guard against a missing batch list so len() / iteration stay safe
            batches = cur.get_result_batches() or []
            _ingest_batches(context, client, batches, query,
                            parquet_out_dir, publish_name)
        finally:
            cur.close()
    finally:
        # always release the Snowflake connection, even when ingestion fails
        conn.close()


def _with_snowflake_credentials(context, connection_info):
    """Return a copy of ``connection_info`` completed with credentials from the context secrets."""
    sf_password = context.get_secret('sfPassword')
    pk_path = context.get_secret('pkPath')
    pk_password = context.get_secret('pkPassword')

    # work on a copy so the caller's mapping is left untouched
    conn_kwargs = dict(connection_info)

    if pk_path and pk_password:
        with open(pk_path, "rb") as key:
            p_key = serialization.load_pem_private_key(
                key.read(),
                password=str(pk_password).encode(),
                backend=default_backend()
            )
        # the connection receives the key as unencrypted DER/PKCS8 bytes
        pkb = p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )
        # key-pair authentication replaces any password from the config
        conn_kwargs.pop('password', None)
        conn_kwargs['private_key'] = pkb
    elif sf_password:
        conn_kwargs['password'] = sf_password
    else:
        raise Exception("\nPlease set up the secret for Snowflake in your project!\n")

    return conn_kwargs


def _ingest_batches(context, client, batches, query, parquet_out_dir, publish_name):
    """Build the Dask dataframe from the result batches, then log, publish and store it."""
    context.logger.info(f'batches len === {len(batches)}\n')

    dfs = [load(batch) for batch in batches if batch.rowcount > 0]
    if not dfs:
        # from_delayed() needs at least one delayed object, so an empty
        # result set is reported instead of crashing there
        context.logger.info(f'query  === >>> {query}\n')
        context.logger.info('query returned no rows, nothing to ingest\n')
        context.log_result('number of rows', 0)
        return
    ddf = from_delayed(dfs)

    # materialize the query results set for some sample compute
    ddf_describe = ddf.describe().compute()

    context.logger.info(f'query  === >>> {query}\n')
    context.logger.info(f'ddf  === >>> {ddf}\n')
    context.log_result('number of rows', len(ddf.index))
    context.log_dataset("ddf_describe", df=ddf_describe)

    if publish_name:
        context.log_result('data_set_name', publish_name)
        # only this name matters: other published datasets must not block publishing
        if publish_name not in client.list_datasets():
            # persist() returns a new collection; that one has to be published
            ddf = client.persist(ddf)
            # publish under the requested name (a plain `publish_name=ddf` keyword
            # would register the dataset under the literal name "publish_name")
            client.publish_dataset(**{publish_name: ddf})

    if parquet_out_dir:
        dd.to_parquet(df=ddf, path=parquet_out_dir)
        context.log_result('parquet directory', parquet_out_dir)

### Convert the code to MLRun function

Use code_to_function to convert the code to MLRun <br>

In [None]:
# Wrap snowflake_dask.py as an MLRun job whose handler is load_results
fn = mlrun.code_to_function(
    name="snowflake-dask",
    kind='job',
    filename='snowflake_dask.py',
    image='mlrun/mlrun',
    requirements='requirements.txt',
    handler="load_results",
    description="Snowflake Dask - Ingest snowflake data in parallel with Dask cluster",
    categories=["data-prep"],
    labels={"author": "xingsheng"},
)

# Apply the auto-mount modifier, then deploy the function
fn.apply(mlrun.platforms.auto_mount())
fn.deploy()

#### Export the function to a local `function.yaml` file for testing
In real usage, we will import the function from the hub.

In [None]:
# Write the function object out to a local function.yaml file
fn.export('function.yaml')
# print(fn.to_yaml())

#### Import the function from the local `function.yaml` for testing (need to change it to import from the hub before the PR)

In [None]:
# Load the function definition from the local function.yaml file
fn = mlrun.import_function("./function.yaml")

In [None]:
# fn = mlrun.import_function("hub://snowflake_dask")

In [None]:
# Apply the auto-mount modifier to the imported function.
# NOTE(review): flagged as essential by the author - presumably the job needs the
# mount to reach the /v3io paths used in this notebook; confirm.
fn.apply(mlrun.platforms.auto_mount()) # this is a very important line

#### create a dask cluster and specify the configuration for the dask process (e.g. replicas, memory etc)

In [None]:
# function URI is db://<project>/<name>
# Built from the project name and the Dask cluster name defined in the setup cell;
# the bare expression on the last line displays the URI as the cell output.
dask_uri = f'db://{project_name}/{dask_cluster_name}'
dask_uri

In [None]:
# Define the Dask cluster as an MLRun function of kind 'dask'; the extra
# requirements add bokeh and the Snowflake connector with pandas support.
dsf = mlrun.new_function(
    name=dask_cluster_name,
    kind='dask',
    image='mlrun/mlrun',
    requirements=["bokeh", "snowflake-connector-python[pandas]"],
)
dsf.apply(mlrun.mount_v3io())

# Cluster settings: remote cluster, 1 to 10 replicas, NodePort service type
dsf.spec.remote = True
dsf.spec.min_replicas = 1
dsf.spec.max_replicas = 10
dsf.spec.service_type = "NodePort"

# Resource requests for the cluster
dsf.with_requests(mem='4G', cpu='2')

# Optional settings, currently disabled:
# dsf.spec.node_port=30088
# dsf.spec.scheduler_timeout = "5 days"

In [None]:
# Deploy the Dask cluster function configured above
dsf.deploy()

In [None]:
# Get the Dask client from the cluster function; it is closed after the run below
client = dsf.client

### Run the function

When running the function you will see a remote dashboard link as part of the result. Clicking this link takes you to the Dask monitoring dashboard.

In [None]:
p = 'my-local-test'
parquet_path = f"/v3io/bigdata/pq_from_sf_dask/{p}"

# Parameters handed to the load_results handler
run_params = {
    "dask_client": dask_uri,
    "connection_info": connection_info,
    "query": "SELECT * FROM SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.CUSTOMER",
    "parquet_out_dir": parquet_path,
    "publish_name": "customer",
}

fn.run(handler='load_results', params=run_params)

In [None]:
# Close the Dask client obtained from the cluster function earlier
client.close()

## Track the progress in the UI

Users can view the progress and detailed information in the MLRun UI by clicking on the uid above. <br>
Also, to track the Dask progress in the Dask UI, click on the "dashboard link" above the "client" section.