In [1]:
import mlrun



In [2]:
# configure the function spec that mlrun builds from this notebook's code
# (see code_to_function below): run it as a "job", in the mlrun/ml-models image
%nuclio config kind = "job"
%nuclio config spec.image = "mlrun/ml-models"

%nuclio: setting kind to 'job'
%nuclio: setting spec.image to 'mlrun/ml-models'


In [3]:
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)

In [4]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import dask
import dask.dataframe as dd
from dask.distributed import Client

from mlrun.artifacts import PlotArtifact, TableArtifact
from mlrun.mlutils.plots import gcf_clear

from yellowbrick import ClassBalance

from typing import List

In [5]:
# show floats with two decimals in pandas tables displayed by this notebook
pd.set_option("display.float_format", "{:.2f}".format)

def summarize(
    context,
    dask_key: str = "dask_key",
    dataset: mlrun.DataItem = None,
    label_column: str = "label",
    class_labels: List[str] = [],
    plot_hist: bool = True,
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """Summarize a labeled table with plots and a class-balance result.

    Connects to a dask client through an mlrun dask function (``dask_function``),
    or through an optional user-supplied client object (``dask_client``).
    The table is fetched from the client's published datasets under
    ``dask_key`` when present, otherwise it is read from ``dataset`` as a
    dask dataframe. It is then computed into memory for plotting.

    Logged artifacts/results:
      - ``histograms``: seaborn pair plot colored by ``label_column``
        (skipped when ``plot_hist`` is False)
      - ``scale_pos_weight``: support of the first class divided by support of
        the second class (only when at least two classes exist and the second
        has non-zero support)
      - ``imbalance``: class balance plot
      - ``correlation``: heatmap of the feature correlation matrix

    :param context:         the function context
    :param dask_key:        key of dataframe in dask client "datasets" attribute
    :param dataset:         input dataset, used when ``dask_key`` is not
                            published on the dask client
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots
                            (default: unique values of ``label_column``)
    :param plot_hist:       (True) set this to False for large tables to skip
                            the pair plot
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param dask_function:   dask function url (db://..)
    :param dask_client:     dask client object

    :raises ValueError: if neither ``dask_function`` nor ``dask_client`` is given
    :raises Exception:  if ``dask_key`` is not on the client and no ``dataset``
                        was provided
    """
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError('dask client was not provided')

    if dask_key in client.datasets:
        table = client.get_dataset(dask_key)
    elif dataset:
        table = dataset.as_df(df_module=dd)
    else:
        context.logger.info(f"only these datasets are available {client.datasets} in client {client}")
        raise Exception("dataset not found on dask cluster")

    # all plotting below works on an in-memory pandas frame
    table = table.compute()

    # the pair plot is the expensive part on large tables; honor plot_hist
    if plot_hist:
        gcf_clear(plt)
        # NOTE(review): `bw` is deprecated in newer seaborn (use `bw_adjust`);
        # kept unchanged for the image this function is built on
        sns.pairplot(table, hue=label_column, diag_kws={'bw': 1.5})
        context.log_artifact(PlotArtifact('histograms', body=plt.gcf()),
                             local_path=f"{plots_dest}/hist.html")

    gcf_clear(plt)
    # remove the label column so only features remain for the correlation below
    labels = table.pop(label_column)
    # len() check (not truthiness) so numpy arrays are accepted as class_labels
    if class_labels is None or len(class_labels) == 0:
        class_labels = labels.unique()
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)
    support = class_balance_model.support_
    # ratio of first to second class; presumably meant as a binary-imbalance
    # weight. Guard against single-class data and an empty second class.
    if len(support) >= 2 and support[1]:
        scale_pos_weight = support[0] / support[1]
        context.log_result("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    else:
        context.logger.info(
            f"scale_pos_weight not computed: need at least two classes with "
            f"non-zero support, got support {list(support)}"
        )
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()),
                         local_path=f"{plots_dest}/imbalance.html")

    gcf_clear(plt)
    tblcorr = table.corr()
    ax = plt.axes()
    sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
    ax.set_title("features correlation")
    context.log_artifact(PlotArtifact("correlation", body=plt.gcf()),
                         local_path=f"{plots_dest}/corr.html")
    # otherwise shows last plot:
    gcf_clear(plt)

In [6]:
# nuclio: end-code

### Save the function spec

In [20]:
# build an mlrun job function from this notebook's code, with "summarize" as
# its handler, and export its spec (saved as function.yaml under code_output)
mlrun.set_environment(artifact_path="./")
fn = mlrun.code_to_function(
    "describe_dask",
    handler="summarize",
    description="describe and visualizes dataset stats",
    categories=["analysis"],
    labels={"author": "yjb"},
    code_output=".",
)

fn.export()

> 2021-01-26 12:27:25,056 [info] function spec saved to path: function.yaml


<mlrun.runtimes.kubejob.KubejobRuntime at 0x7f50dc189610>

## Test the function on a dask cluster

In [11]:
# apply the platform's automatic mount to the job function (presumably so the
# job pod can read files under /User, such as DATA_URL) — TODO confirm
fn.apply(mlrun.platforms.auto_mount())
# local path of the iris sample CSV used as the test input below
DATA_URL = "/User/iris.csv"

In [None]:
# download the iris sample dataset to DATA_URL (the test run below reads it)
# NOTE(review): this cell is unexecuted (In [None]) in the saved notebook, so
# the file must already exist at DATA_URL for the run below to succeed
!curl -L "https://s3.wasabisys.com/iguazio/data/iris/iris_dataset.csv" > {DATA_URL}

In [18]:
# create a dask test cluster (an mlrun "dask" function) and save it, so the
# run below can reference it as db://default/dask_tests
dask_cluster = mlrun.new_function(
    "dask_tests", kind="dask", image="mlrun/ml-models"
)
dask_cluster.apply(mlrun.mount_v3io())
dask_cluster.spec.remote = True
dask_cluster.with_requests(mem="2G")
dask_cluster.save()

'259335271b01d4b92839f469ca7417be23bce2af'

In [19]:
# run summarize as a job against the iris CSV, using the saved dask cluster
run = fn.run(
    name="tasks-describe",
    handler=summarize,
    inputs={"dataset": DATA_URL},
    params={
        "label_column": "label",
        "dask_function": "db://default/dask_tests",
    },
)

> 2021-01-26 11:19:23,050 [info] starting run tasks-describe uid=d73b134369ce497184425a14fbd43ba7 DB=http://mlrun-api:8080
> 2021-01-26 11:19:23,203 [info] Job is running in the background, pod: tasks-describe-nq7vj
> 2021-01-26 11:19:28,467 [info] using in-cluster config.
> 2021-01-26 11:19:33,713 [info] to get a dashboard link, use NodePort service_type
> 2021-01-26 11:19:33,713 [info] trying dask client at: tcp://mlrun-dask-tests-9e5bc4f3-1.default-tenant:8786
> 2021-01-26 11:19:33,747 [info] using remote dask scheduler (mlrun-dask-tests-9e5bc4f3-1) at: tcp://mlrun-dask-tests-9e5bc4f3-1.default-tenant:8786
> 2021-01-26 11:19:42,561 [info] run executed, status=completed
final state: completed


project,uid,iter,start,state,name,labels,inputs,parameters,results,artifacts
default,...fbd43ba7,0,Jan 26 11:19:28,completed,tasks-describe,v3io_user=adminkind=jobowner=adminhost=tasks-describe-nq7vj,dataset,label_column=labeldask_function=db://default/dask_tests,scale_pos_weight=1.00,histogramsimbalancecorrelation


to track results use .show() or .logs() or in CLI: 
!mlrun get run d73b134369ce497184425a14fbd43ba7 --project default , !mlrun logs d73b134369ce497184425a14fbd43ba7 --project default
> 2021-01-26 11:19:48,510 [info] run executed, status=completed
