In [1]:
import mlrun
# Point the MLRun client at the MLRun API/DB service used by the runs below.
# NOTE(review): the scheme here is https, while the run log further down reports
# an http:// endpoint on port 8080 -- confirm which scheme the service expects.
mlrun.mlconf.mlrundb = 'https://mlrun-api:8080'

## describe a table

In [12]:
# nuclio: ignore
import nuclio

In [21]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from mlrun.execution import MLClientCtx
from mlrun.datastore import DataItem
from mlrun.artifacts import PlotArtifact, TableArtifact

from sklearn.preprocessing import StandardScaler
from yellowbrick import ClassBalance

from typing import IO, AnyStr, Union, List, Optional

# Render floats with two decimal places whenever pandas displays them.
pd.set_option("display.float_format", "{:.2f}".format)

def _gcf_clear(plt):
    plt.cla()
    plt.clf()
    plt.close() 

def describe(
    context: MLClientCtx,
    table: Union[DataItem, str],
    label_column: str,
    class_labels: List[str] = None,
    key: str = "table-summary",
    plot_hist: bool = True,
    plots_dest: str = 'plots'
) -> None:
    """Summarize a table

    Reads a parquet table and logs, under the run's artifact path: an optional
    pairplot of the columns, a CSV of descriptive statistics (with a trailing
    "nans" row), a class-balance plot together with the ratio between the
    first two class counts, and a feature-correlation heatmap.

    TODO: merge with dask version

    :param context:         the function context
    :param table:           parquet table to summarize (data item or path,
                            converted with ``str()`` before reading)
    :param label_column:    ground truth column label
    :param class_labels:    optional class names passed to the class balance plot
    :param key:             key of table summary in artifact store
    :param plot_hist:       (True) set this to False for large tables
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    """
    base_path = context.artifact_path
    plots_path = os.path.join(base_path, plots_dest)
    os.makedirs(base_path, exist_ok=True)
    # honour plots_dest instead of a hard-coded "plots" folder, otherwise the
    # savefig calls below fail for any non-default destination
    os.makedirs(plots_path, exist_ok=True)

    print(f'TABLE {table}')
    table = pd.read_parquet(str(table))

    # plot histogram (pairplot is expensive, so it can be switched off)
    if plot_hist:
        _gcf_clear(plt)
        snsplt = sns.pairplot(table, hue=label_column)
        snsplt.savefig(os.path.join(plots_path, "hist.png"))
        context.log_artifact(PlotArtifact("histograms",  body=plt.gcf()), local_path=f"{plots_dest}/hist.html")

    # describe table
    sumtbl = table.describe()
    # take the row names from describe() itself so the labels always match the
    # rows actually produced, then add the label for the extra row
    metric_names = list(sumtbl.index) + ["nans"]
    # missing values per column; covers every column of the table, including
    # any that describe() left out
    nan_counts = len(table.index) - table.count()
    # DataFrame.append was removed in pandas 2.0; concat gives the same row
    sumtbl = pd.concat([sumtbl, nan_counts.to_frame().T], ignore_index=True)
    sumtbl.insert(0, "metric", metric_names)

    sumtbl.to_csv(os.path.join(base_path, key+".csv"), index=False)
    context.log_artifact(key, local_path=key+".csv")

    # plot class balance, record relative class weight
    _gcf_clear(plt)

    labels = table.pop(label_column)
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)

    support = class_balance_model.support_
    if len(support) >= 2:
        # ratio between the counts of the first two classes
        scale_pos_weight = support[0]/support[1]
        context.log_artifact("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    else:
        # a single-class label has no ratio to report; skip it rather than
        # failing with an IndexError
        context.logger.warning(
            f"label column '{label_column}' has fewer than 2 classes, "
            "scale_pos_weight was not logged")

    class_balance_model.show(outpath=os.path.join(plots_path, "imbalance.png"))
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()), local_path=f"{plots_dest}/imbalance.html")

    # plot feature correlation
    _gcf_clear(plt)
    # restrict to numeric/bool columns explicitly: recent pandas versions raise
    # on non-numeric columns instead of silently dropping them
    tblcorr = table.select_dtypes(include=["number", "bool"]).corr()
    ax = plt.axes()
    sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
    ax.set_title("features correlation")
    plt.savefig(os.path.join(plots_path, "corr.png"))
    # the path separator was missing here, which put the artifact outside plots_dest
    context.log_artifact(PlotArtifact("correlation",  body=plt.gcf()), local_path=f"{plots_dest}/corr.html")

    _gcf_clear(plt)

In [22]:
# nuclio: end-code

## tests

### local

The following install may be needed when running in your local notebook:

    !python -m pip install yellowbrick

When run remotely this will not be necessary since the `yellowbrick` package is pre-installed on the function's job image.

In [23]:
# Build an MLRun function object named 'summary' of kind 'job' that runs on the
# mlrun/ml-models:0.4.5 image (no source file is passed, so presumably the code
# of this notebook is packaged -- confirm).
fn = mlrun.code_to_function('summary', kind='job', image='mlrun/ml-models:0.4.5')
# `describe` is the handler used when a run does not name one explicitly.
fn.spec.default_handler = 'describe'

In [24]:
# Apply the v3io mount modifier to the function
# (presumably this is what makes the /User/... paths used below visible to the job -- confirm).
fn.apply(mlrun.mount_v3io())

<mlrun.runtimes.kubejob.KubejobRuntime at 0x7fd810adf5f8>

In [25]:
# Task definition for the summary run: `params` and `inputs` map onto the
# handler's arguments (key, label_column, table), and artifacts are written to
# a folder templated with the run uid.
summ_task = mlrun.NewTask(
    "sum", 
    params={"key": "summary", "label_column": "species"},
    inputs={"table": '/User/functions/describe/iris.pqt'},
    artifact_path='/User/functions/describe/artifacts/{{run.uid}}')

In [26]:
# Execute the task with the function defined above and keep the returned run
# object; the log below shows it running as a background pod until completion.
SUM_RUN = fn.run(summ_task)

[mlrun] 2020-03-15 19:48:08,771 starting run sum uid=0e70c48c2b0047239ab007a69f89d57c  -> http://10.194.253.77:8080
[mlrun] 2020-03-15 19:48:08,846 Job is running in the background, pod: sum-hpb7z
Intel(R) Data Analytics Acceleration Library (Intel(R) DAAL) solvers for sklearn enabled: https://intelpython.github.io/daal4py/sklearn.html
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
TABLE /User/functions/describe/iris.pqt
[mlrun] 2020-03-15 19:48:19,115 log artifact histograms at /User/functions/describe/artifacts/0e70c48c2b0047239ab007a69f89d57c/plots/hist.html, size: 288817, db: Y
[mlrun] 2020-03-15 19:48:19,157 log artifact summary at /User/functions/describe/artifacts/0e70c48c2b0047239ab007a69f89d57c/summary.csv, size: None, d

uid,iter,start,state,name,labels,inputs,parameters,results,artifacts
...89d57c,0,Mar 15 19:48:13,completed,sum,host=sum-hpb7zkind=jobowner=admin,table,key=summarylabel_column=species,,histogramssummaryscale_pos_weightimbalancecorrelation


to track results use .show() or .logs() or in CLI: 
!mlrun get run 0e70c48c2b0047239ab007a69f89d57c  , !mlrun logs 0e70c48c2b0047239ab007a69f89d57c 
[mlrun] 2020-03-15 19:48:24,153 run executed, status=completed
