In [1]:
# nuclio: ignore
import nuclio

In [2]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from mlrun.execution import MLClientCtx
from mlrun.datastore import DataItem
from mlrun.artifacts import PlotArtifact, TableArtifact

from sklearn.preprocessing import StandardScaler
from yellowbrick import ClassBalance

from typing import IO, AnyStr, Union, List, Optional



In [3]:
# Render every float in displayed DataFrames with exactly two decimals.
pd.set_option("display.float_format", "{:.2f}".format)

def _gcf_clear(plt):
    plt.cla()
    plt.clf()
    plt.close() 

def summarize(
    context: MLClientCtx,
    table: Union[DataItem, str],
    label_column: str = 'labels',
    class_labels: List[str] = None,
    key: str = "table-summary",
    plot_hist: bool = True,
    plots_dest: str = 'plots'
) -> None:
    """Summarize a table and log the summary plots as artifacts.

    Reads the table (csv or parquet) and logs these plot artifacts:
      * "histograms"  - seaborn pairplot coloured by ``label_column``
                        (skipped when ``plot_hist`` is False)
      * "imbalance"   - yellowbrick ClassBalance plot of ``label_column``
      * "correlation" - heatmap of the correlation matrix of the remaining columns
    It also logs the result "scale_pos_weight": the support of the first class
    divided by the support of the second class, formatted to two decimals.

    :param context:         the function context
    :param table:           csv/parquet file path (or a DataItem pointing at one);
                            any suffix other than ``.csv`` is read as parquet
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots
    :param key:             currently unused; kept for interface compatibility
    :param plot_hist:       (True) set this to False for large tables; skips the pairplot
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :raises ValueError:     if ``table`` is empty/None
    """
    if not table:
        raise ValueError('table input must be specified')

    table_path = str(table)
    # Case-insensitive so that e.g. "data.CSV" is not parsed as parquet.
    suffix = os.path.splitext(table_path)[1].lower()
    if suffix == '.csv':
        table = pd.read_csv(table_path)
    else:
        table = pd.read_parquet(table_path)

    if plot_hist:
        _gcf_clear(plt)
        # pairplot builds its own figure (it accepts no ``ax``), so the
        # artifact is taken from the current figure right after plotting.
        sns.pairplot(table, hue=label_column, diag_kws={'bw': 1.5})
        context.log_artifact(PlotArtifact('histograms', body=plt.gcf()),
                             local_path=f"{plots_dest}/hist.html")

    _gcf_clear(plt)
    # Remove the label column so the correlation below covers features only.
    labels = table.pop(label_column)
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)
    # Only the first two class supports are used; with more than two classes
    # this is simply class[0] / class[1].
    scale_pos_weight = class_balance_model.support_[0] / class_balance_model.support_[1]
    context.log_result("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()),
                         local_path=f"{plots_dest}/imbalance.html")

    _gcf_clear(plt)
    tblcorr = table.corr()
    ax = plt.axes()
    sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
    ax.set_title("features correlation")
    context.log_artifact(PlotArtifact("correlation", body=plt.gcf()),
                         local_path=f"{plots_dest}/corr.html")
    _gcf_clear(plt)

In [4]:
# nuclio: end-code

In [5]:
from mlrun import run_local, NewTask
# Dataset to summarize. Set SUMMARY_TABLE_PATH to point at another csv/parquet
# file; the default is the iris parquet under load_dataset's artifacts folder.
table_path = os.environ.get(
    'SUMMARY_TABLE_PATH',
    '/User/functions/load_dataset/artifacts/iris.parquet')
task = NewTask(handler=summarize, inputs={'table': table_path})
run = run_local(task)

[mlrun] 2020-03-23 14:50:03,797 artifact path is not defined or is local, artifacts will not be visible in the UI
[mlrun] 2020-03-23 14:50:03,798 starting run mlrun-ed34c7-summarize uid=76e62c449b27471888aed10bfacb3c5c  -> 


findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.


.parquet
[mlrun] 2020-03-23 14:50:06,919 log artifact histograms at plots/hist.html, size: 152737, db: N
[mlrun] 2020-03-23 14:50:07,309 log artifact imbalance at plots/imbalance.html, size: 7464, db: N
[mlrun] 2020-03-23 14:50:07,485 log artifact correlation at plots/corr.html, size: 20942, db: N



uid,iter,start,state,name,labels,inputs,parameters,results,artifacts
...cb3c5c,0,Mar 23 14:50:03,completed,mlrun-ed34c7-summarize,kind=handlerowner=adminhost=jupyter-6d69dc994d-h9nv2,table,,scale_pos_weight=1.00,histogramsimbalancecorrelation


to track results use .show() or .logs() or in CLI: 
!mlrun get run 76e62c449b27471888aed10bfacb3c5c --project default , !mlrun logs 76e62c449b27471888aed10bfacb3c5c --project default
[mlrun] 2020-03-23 14:50:07,575 run executed, status=completed


In [6]:
# Load the same file the summary ran on, to eyeball its contents.
import pandas as pd
df = pd.read_parquet(path=table_path)
df

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),labels
0,5.10,3.50,1.40,0.20,0.00
1,4.90,3.00,1.40,0.20,0.00
2,4.70,3.20,1.30,0.20,0.00
3,4.60,3.10,1.50,0.20,0.00
4,5.00,3.60,1.40,0.20,0.00
...,...,...,...,...,...
145,6.70,3.00,5.20,2.30,2.00
146,6.30,2.50,5.00,1.90,2.00
147,6.50,3.00,5.20,2.00,2.00
148,6.20,3.40,5.40,2.30,2.00


In [7]:
from mlrun import code_to_function
# create job function object from notebook code
fn = code_to_function('describe', kind='job', with_doc=True,
                      handler=summarize, image='mlrun/ml-models')

# add metadata (for templates and reuse)
fn.spec.default_handler = 'summarize'
fn.spec.description = "this function visualizes dataset stats"
fn.metadata.categories = ['models', 'visualization']
fn.metadata.labels = {'author': 'yjb'}

In [112]:
# Write the function spec built in the previous cell to a YAML file.
spec_file = 'function.yaml'
fn.export(spec_file)

[mlrun] 2020-03-23 11:07:39,241 function spec saved to path: function.yaml


<mlrun.runtimes.kubejob.KubejobRuntime at 0x7f611d757e10>