In [1]:
# nuclio: ignore
import nuclio

In [None]:
# nuclio magics: set the function kind to "job" and its image to "mlrun/ml-models"
%nuclio config kind = "job"
%nuclio config spec.image = "mlrun/ml-models"

In [2]:
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)

In [3]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import dask
import dask.dataframe as dd
from dask.distributed import Client

from mlrun.execution import MLClientCtx
from mlrun.datastore import DataItem
from mlrun.artifacts import PlotArtifact, TableArtifact
from mlrun.mlutils import gcf_clear

from yellowbrick import ClassBalance

from typing import List

In [4]:
# Render floats in displayed pandas objects with two decimal places.
pd.set_option("display.float_format", "{:.2f}".format)

def summarize(
    context: MLClientCtx,
    dask_key: str = "dask_key",
    label_column: str = "labels",
    class_labels: List[str] = [],
    plot_hist: bool = True,
    plots_dest: str = "plots",
    alt_scheduler: str = None
) -> None:
    """Summarize a table published on a dask cluster

    Connects to dask client through the function context, or through an optional
    user-supplied scheduler. The dataset published under ``dask_key`` is
    computed into local memory, after which the following are logged on the
    context:

    * ``histograms``  artifact - seaborn pairplot colored by the label column
      (only when ``plot_hist`` is True)
    * ``imbalance``   artifact - yellowbrick ``ClassBalance`` figure
    * ``scale_pos_weight`` result - first entry of the fitted ``support_``
      divided by the second, formatted with two decimals (skipped when fewer
      than two classes are found)
    * ``correlation`` artifact - heatmap of ``table.corr()`` over the remaining
      columns (the label column is popped from the table first)

    :param context:         the function context
    :param dask_key:        key of dataframe in dask client "datasets" attribute
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots, when empty
                            the unique values of the label column are used
    :param plot_hist:       (True) set this to False for large tables, the
                            pairplot is then skipped
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param alt_scheduler:   (None) an alternative scheduler file to connect with

    :raises Exception:      when neither ``alt_scheduler`` nor
                            ``context.dask_client`` is available, or when
                            ``dask_key`` is not published on the cluster
    """
    # NOTE: the ``[]`` default of class_labels is never mutated below, the name
    # is only rebound, so the shared default list stays empty between calls.

    # A user-supplied scheduler file takes precedence over the one on the context.
    if alt_scheduler:
        scheduler_file = str(alt_scheduler)
    elif hasattr(context, "dask_client"):
        scheduler_file = str(context.dask_client)
    else:
        raise Exception("out of luck, no dask_client or scheduler file!")
    dask_client = Client(scheduler_file=scheduler_file)

    if dask_key not in dask_client.datasets:
        # list() yields the published dataset names; interpolating the mapping
        # object itself would not show which datasets are available.
        context.logger.info(f"only these datasets are available {list(dask_client.datasets)} in client {dask_client}")
        raise Exception("dataset not found on dask cluster")

    # Materialize the whole table locally, everything below works on the
    # computed result rather than on the lazy dask collection.
    table = dask_client.get_dataset(dask_key).compute()

    if plot_hist:
        gcf_clear(plt)
        sns.pairplot(table, hue=label_column, diag_kws={'bw': 1.5})
        context.log_artifact(PlotArtifact('histograms',  body=plt.gcf()),
                             local_path=f"{plots_dest}/hist.html")

    gcf_clear(plt)
    # pop() removes the label column so the correlation below covers features only
    labels = table.pop(label_column)
    # explicit emptiness test so array-like class_labels do not raise on truth-testing
    if class_labels is None or len(class_labels) == 0:
        class_labels = labels.unique()
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)
    support = class_balance_model.support_
    if len(support) >= 2:
        scale_pos_weight = support[0] / support[1]
        context.log_result("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    else:
        # a single-class label column has no second entry to divide by
        context.logger.info(
            f"cannot compute scale_pos_weight, found {len(support)} class(es) in column {label_column}")
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()),
                         local_path=f"{plots_dest}/imbalance.html")

    gcf_clear(plt)
    tblcorr = table.corr()
    ax = plt.axes()
    sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
    ax.set_title("features correlation")
    context.log_artifact(PlotArtifact("correlation",  body=plt.gcf()),
                         local_path=f"{plots_dest}/corr.html")
    # otherwise shows last plot:
    gcf_clear(plt)

In [5]:
# nuclio: end-code

### mlconfig

In [13]:
from mlrun import mlconf
import os

# Keep an already configured API URL, otherwise fall back to the default one.
mlconf.dbpath = mlconf.dbpath or 'http://mlrun-api:8080'
# Keep an already configured artifact path, otherwise use "artifacts" under the
# user's home directory. expanduser("~") still resolves the home directory when
# the HOME variable is unset, where os.environ["HOME"] would raise KeyError.
mlconf.artifact_path = mlconf.artifact_path or f'{os.path.expanduser("~")}/artifacts'

### save

In [2]:
from mlrun import code_to_function 
# Build a job function object from this notebook's code.
fn = code_to_function("describe_dask")

# Metadata for templates and reuse.
fn.metadata.labels = dict(author="yjb")
fn.metadata.categories = ["analysis"]
fn.spec.description = "describe and visualizes dataset stats"
fn.spec.default_handler = "summarize"

# Write the function spec to function.yaml; as the cell's last expression the
# returned object is also displayed.
fn.export("function.yaml")

[mlrun] 2020-05-01 22:04:32,007 function spec saved to path: function.yaml


<mlrun.runtimes.kubejob.KubejobRuntime at 0x7fbc52e61e48>

## tests

In [None]:
# load function from marketplace
from mlrun import import_function

# vcs_branch = 'development'
# base_vcs = f'https://raw.githubusercontent.com/mlrun/functions/{vcs_branch}/'
# mlconf.hub_url = mlconf.hub_url or base_vcs + f'{name}/function.yaml'
# fn = import_function("hub://describe_dask")

In [12]:
from mlrun import import_function, NewTask, run_local

# Attach storage to the function: a PVC mount when V3IO_HOME is not set in the
# environment, the v3io mount otherwise.
if "V3IO_HOME" not in os.environ:
    # if you set up mlrun using the instructions at https://github.com/mlrun/mlrun/blob/master/hack/local/README.md
    from mlrun.platforms import mount_pvc
    fn.apply(mount_pvc("nfsvol", "nfsvol", "/home/joyan/data"))
else:
    from mlrun import mount_v3io
    fn.apply(mount_v3io())

In [13]:
# Task definition: the dataset key and the scheduler file are handed to the
# function through the task inputs.
task = NewTask(
    name="tasks describe dask",
    inputs=dict(
        dask_key="dask_key",
        alt_scheduler="/User/artifacts/scheduler.json",
    ),
)

In [14]:
# Execute the task with the function and keep the returned run object in `run`.
run = fn.run(task)

[mlrun] 2020-04-30 01:59:21,737 starting run tasks describe dask uid=0cd4c350d24647d98e2cf3c0b7f6cdab  -> http://mlrun-api:8080
[mlrun] 2020-04-30 01:59:21,843 Job is running in the background, pod: tasks-describe-dask-jmmvc
[mlrun] 2020-04-30 01:59:31,254 log artifact histograms at /User/artifacts/plots/hist.html, size: 284413, db: Y
[mlrun] 2020-04-30 01:59:31,876 log artifact imbalance at /User/artifacts/plots/imbalance.html, size: 11716, db: Y
[mlrun] 2020-04-30 01:59:32,048 log artifact correlation at /User/artifacts/plots/corr.html, size: 30642, db: Y

[mlrun] 2020-04-30 01:59:32,118 run executed, status=completed
final state: succeeded


project,uid,iter,start,state,name,labels,inputs,parameters,results,artifacts
default,...b7f6cdab,0,Apr 30 01:59:28,completed,tasks describe dask,host=tasks-describe-dask-jmmvckind=jobowner=adminv3io_user=admin,alt_schedulerdask_key,,scale_pos_weight=1.00,histogramsimbalancecorrelation


to track results use .show() or .logs() or in CLI: 
!mlrun get run 0cd4c350d24647d98e2cf3c0b7f6cdab  , !mlrun logs 0cd4c350d24647d98e2cf3c0b7f6cdab 
[mlrun] 2020-04-30 01:59:34,037 run executed, status=completed
