In [1]:
# nuclio: ignore
import nuclio

In [2]:
%nuclio config kind = "job"
%nuclio config spec.image = "mlrun/ml-models"

%nuclio: setting kind to 'job'
%nuclio: setting spec.image to 'mlrun/ml-models'


In [3]:
# Install an "ignore" filter for FutureWarning. This cell runs before the
# library imports in the next cell, so the filter is already in place when
# those libraries are loaded and used.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [4]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from mlrun.execution import MLClientCtx
from mlrun.datastore import DataItem
from mlrun.artifacts import PlotArtifact, TableArtifact
from mlrun.mlutils import gcf_clear

from typing import List

In [14]:
def _format_two_decimals(value):
    """Render a number with exactly two digits after the decimal point."""
    return "%.2f" % value


# Use the two-decimal formatter whenever pandas displays floating point values.
pd.set_option("display.float_format", _format_two_decimals)

def summarize(
    context: MLClientCtx,
    table: DataItem,
    label_column: str = None,
    class_labels: List[str] = [],
    plot_hist: bool = True,
    plots_dest: str = "plots",
    update_dataset = False,
) -> None:
    """Summarize a table

    Loads the table as a dataframe and logs descriptive artifacts under
    ``plots_dest``: a seaborn pair plot ("histograms"), one violin plot per
    column ("violin"), the correlation matrix as a csv table and as a heatmap
    ("correlation") and, when ``label_column`` is given, a class-proportion bar
    chart ("imbalance") together with its weights table.  Each plot is created
    on a best-effort basis: a failure is logged through ``context.logger`` and
    the function carries on with the remaining artifacts.

    :param context:         the function context
    :param table:           MLRun input pointing to pandas dataframe (csv/parquet file path)
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots
                            (accepted but not used by this function)
    :param plot_hist:       (True) set this to False for large tables, the
                            pair plot is then skipped
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param update_dataset:  when the table is a registered dataset update the charts in-place
    """
    df = table.as_df()
    # Captured before the label column is popped below, so the violin plots
    # still cover every column of the original table.
    header = df.columns.values
    extra_data = {}

    # The pair plot draws one panel per pair of columns, which is why callers
    # can switch it off for large tables.
    if plot_hist:
        try:
            gcf_clear(plt)
            sns.pairplot(df, hue=label_column)
            extra_data["histograms"] = context.log_artifact(
                PlotArtifact("histograms", body=plt.gcf()),
                local_path=f"{plots_dest}/hist.html",
                db_key=False,
            )
        except Exception as e:
            context.logger.error(f'Failed to create pairplot histograms due to: {e}')

    try:
        gcf_clear(plt)
        plot_cols = 3
        plot_rows = int((len(header) - 1) / plot_cols) + 1
        # squeeze=False keeps `ax` two-dimensional even when there is a single
        # row; otherwise ax[row][col] fails for tables with three columns or less.
        fig, ax = plt.subplots(plot_rows, plot_cols, figsize=(15, 4), squeeze=False)
        fig.tight_layout(pad=2.0)
        for i in range(plot_rows * plot_cols):
            cell = ax[i // plot_cols][i % plot_cols]
            if i < len(header):
                sns.violinplot(x=df[header[i]], ax=cell,
                               orient='h', width=0.7, inner="quartile")
            else:
                # Remove the unused trailing axes of the grid.
                fig.delaxes(cell)
        extra_data["violin"] = context.log_artifact(
            PlotArtifact("violin", body=plt.gcf(), title='Violin Plot'),
            local_path=f"{plots_dest}/violin.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(f'Failed to create violin distribution plots due to: {e}')

    if label_column:
        # Removing the label column here also keeps it out of the correlation
        # matrix computed below.
        labels = df.pop(label_column)
        imbtable = labels.value_counts(normalize=True).sort_index()
        try:
            gcf_clear(plt)
            balancebar = imbtable.plot(kind='bar', title='class imbalance - labels')
            balancebar.set_xlabel('class')
            balancebar.set_ylabel("proportion of total")
            # NOTE(review): unlike the other plots this one is logged without
            # db_key=False -- confirm that this is intended.
            extra_data["imbalance"] = context.log_artifact(
                PlotArtifact("imbalance", body=plt.gcf()),
                local_path=f"{plots_dest}/imbalance.html",
            )
        except Exception as e:
            context.logger.warn(f'Failed to create class imbalance plot due to: {e}')
        context.log_artifact(
            TableArtifact("imbalance-weights-vec",
                          df=pd.DataFrame({"weights": imbtable})),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
            db_key=False,
        )

    tblcorr = df.corr()
    # Mask the upper triangle (diagonal included) so the heatmap shows each
    # pair of features once.  The builtin `bool` replaces the `np.bool` alias,
    # which newer NumPy releases no longer provide.
    mask = np.zeros_like(tblcorr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    context.log_artifact(
        TableArtifact("correlation-matrix", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
        db_key=False,
    )

    try:
        gcf_clear(plt)
        ax = plt.axes()
        sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds)
        ax.set_title("features correlation")
        extra_data["correlation"] = context.log_artifact(
            PlotArtifact("correlation", body=plt.gcf(), title='Correlation Matrix'),
            local_path=f"{plots_dest}/corr.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(f'Failed to create features correlation plot due to: {e}')

    gcf_clear(plt)
    if update_dataset and table.meta and table.meta.kind == 'dataset':
        # Attach the logged charts to the registered dataset's metadata.
        from mlrun.artifacts import update_dataset_meta
        update_dataset_meta(table.meta, extra_data=extra_data)
        

In [15]:
# nuclio: end-code

### mlconfig

In [16]:
from mlrun import mlconf
import os
# Keep an already configured DB path; otherwise fall back to 'http://mlrun-api:8080'.
mlconf.dbpath = mlconf.dbpath or 'http://mlrun-api:8080'
# Keep an already configured artifact path; otherwise use the absolute path of
# the current working directory.
mlconf.artifact_path = mlconf.artifact_path or os.path.abspath('./')

### save

In [18]:
from mlrun import code_to_function 
# Create a job function object named "describe" from the notebook code, with
# `summarize` as its handler.
# NOTE(review): code_output='.' presumably also writes the extracted code into
# the current directory -- confirm against the mlrun `code_to_function` docs.
fn = code_to_function("describe", handler="summarize",
                      description="describe and visualizes dataset stats",
                      categories=["analysis"],
                      labels = {"author": "yjb"},
                      code_output='.')

# Save the function spec; the recorded output below reports function.yaml.
fn.export()

> 2020-07-23 07:46:39,543 [info] function spec saved to path: function.yaml


<mlrun.runtimes.kubejob.KubejobRuntime at 0x7f56f9032b38>

## tests

In [9]:
from mlrun.platforms import auto_mount
# NOTE(review): presumably attaches the platform's default storage mount to the
# function so the remote job can reach the artifact paths -- confirm against
# the mlrun.platforms.auto_mount documentation.
fn.apply(auto_mount())

<mlrun.runtimes.kubejob.KubejobRuntime at 0x7f57271d8278>

In [10]:
from mlrun import NewTask, run_local

# Sample csv used as the `table` input of the test runs below.
DATA_URL = 'https://raw.githubusercontent.com/mlrun/functions/master/describe/iris_dataset.csv'

In [11]:
# Task definition shared by the local and the remote run: it calls `summarize`
# on the sample csv, treats the 'label' column as ground truth and asks for the
# dataset metadata to be updated.
task = NewTask(
    name="tasks-describe", 
    handler=summarize, 
    inputs={"table": DATA_URL}, params={'update_dataset': True, 'label_column': 'label'})

### run locally

In [12]:
# Execute the task locally; the recorded run below reports the Jupyter host and
# lists the artifacts that were logged.
run = run_local(task)

> 2020-07-22 09:00:32,582 [debug] Validating field against patterns: {'field_name': 'run.metadata.name', 'field_value': 'tasks-describe', 'pattern': ['^.{0,63}$', '^(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?$']}
> 2020-07-22 09:00:32,598 [info] starting run tasks-describe uid=f30656601819462c892a9365dd175f72  -> http://mlrun-api:8080
> 2020-07-22 09:00:37,475 [debug] log artifact histograms at /User/functions/describe/plots/hist.html, size: 140127, db: N
> 2020-07-22 09:00:38,377 [debug] log artifact violin at /User/functions/describe/plots/violin.html, size: 54096, db: N
> 2020-07-22 09:00:38,680 [debug] log artifact imbalance at /User/functions/describe/plots/imbalance.html, size: 10045, db: Y
> 2020-07-22 09:00:38,697 [debug] log artifact imbalance-weights-vec at /User/functions/describe/plots/imbalance-weights-vec.csv, size: 65, db: N
> 2020-07-22 09:00:38,702 [debug] log artifact correlation-matrix at /User/functions/describe/plots/correlation-matrix.csv, size: 324, db: N
> 2020-

project,uid,iter,start,state,name,labels,inputs,parameters,results,artifacts
default,...dd175f72,0,Jul 22 09:00:32,completed,tasks-describe,v3io_user=adminkind=handlerowner=adminhost=jupyter-558bf7fbc8-kt6x9,table,update_dataset=Truelabel_column=label,,histogramsviolinimbalanceimbalance-weights-veccorrelation-matrixcorrelation


to track results use .show() or .logs() or in CLI: 
!mlrun get run f30656601819462c892a9365dd175f72 --project default , !mlrun logs f30656601819462c892a9365dd175f72 --project default
> 2020-07-22 09:00:39,121 [info] run executed, status=completed


### run remotely

In [13]:
# Run the same task through the job function; the recorded log below shows it
# executing in a pod. `inputs` repeats the table input the task already carries.
fn.run(task, inputs={"table": DATA_URL})

> 2020-07-22 09:00:39,154 [debug] Validating field against patterns: {'field_name': 'run.metadata.name', 'field_value': 'tasks-describe', 'pattern': ['^.{0,63}$', '^(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?$']}
> 2020-07-22 09:00:39,161 [info] starting run tasks-describe uid=d8edd5e4b8004437927f0810d2ad1658  -> http://mlrun-api:8080
> 2020-07-22 09:00:39,287 [info] Job is running in the background, pod: tasks-describe-vmgv8
> 2020-07-22 09:00:45,175 [debug] Validating field against patterns: {'field_name': 'run.metadata.name', 'field_value': 'tasks-describe', 'pattern': ['^.{0,63}$', '^(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?$']}
> 2020-07-22 09:00:45,291 [debug] starting local run: main.py # summarize
> 2020-07-22 09:00:50,598 [debug] log artifact histograms at /User/functions/describe/plots/hist.html, size: 238319, db: N
> 2020-07-22 09:00:51,652 [debug] log artifact violin at /User/functions/describe/plots/violin.html, size: 86708, db: N
> 2020-07-22 09:00:51,908 [debug] log ar

project,uid,iter,start,state,name,labels,inputs,parameters,results,artifacts
default,...d2ad1658,0,Jul 22 09:00:46,completed,tasks-describe,v3io_user=adminkind=jobowner=adminhost=tasks-describe-vmgv8,table,update_dataset=Truelabel_column=label,,histogramsviolinimbalanceimbalance-weights-veccorrelation-matrixcorrelation


to track results use .show() or .logs() or in CLI: 
!mlrun get run d8edd5e4b8004437927f0810d2ad1658 --project default , !mlrun logs d8edd5e4b8004437927f0810d2ad1658 --project default
> 2020-07-22 09:00:54,614 [info] run executed, status=completed


<mlrun.model.RunObject at 0x7f5725802d30>