In [None]:
# nuclio: ignore
import nuclio

In [None]:
import os
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from yellowbrick import ClassBalance

from mlrun.artifacts import PlotArtifact, TableArtifact
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx

In [None]:
# Render every float in displayed frames with two decimals.
pd.set_option("display.float_format", "{:.2f}".format)

def _gcf_clear(plt):
    plt.cla()
    plt.clf()
    plt.close() 

def summarize(
    context: MLClientCtx,
    table: str,
    label_column: str = 'labels',
    class_labels: Optional[List[str]] = None,
    plot_hist: bool = True,
    plots_dest: str = 'plots'
) -> None:
    """Summarize a table

    Reads the table and logs to the run context: a pairwise histogram plot
    (optional, see `plot_hist`), a class-balance plot, a feature-correlation
    heatmap and the ``scale_pos_weight`` result (count of the first class
    divided by count of the second class).

    :param context:         the function context
    :param table:           path/url of the table; a value whose ``str()`` ends
                            in '.csv' is read as csv, anything else as parquet
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots; when not
                            given (None or empty), the sorted unique values of
                            `label_column` are used
    :param plot_hist:       (True) set this to False for large tables; skips the
                            pairwise histogram plot
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    """
    table = str(table)
    if table.endswith('.csv'):
        table = pd.read_csv(table)
    else:
        table = pd.read_parquet(table)

    if plot_hist:
        # the pairwise grid grows quadratically with the number of columns,
        # which is why callers can switch it off
        _gcf_clear(plt)
        # NOTE(review): `bw` is deprecated since seaborn 0.11 in favour of
        # `bw_adjust`/`bw_method`; confirm against the seaborn version in the image.
        sns.pairplot(table, hue=label_column, diag_kws={'bw': 1.5})
        context.log_artifact(PlotArtifact('histograms', body=plt.gcf()),
                             local_path=f"{plots_dest}/hist.html")

    _gcf_clear(plt)
    # split off the ground truth; from here on `table` holds only features
    labels = table.pop(label_column)
    if class_labels is None or len(class_labels) == 0:
        # sorted so display labels line up with ClassBalance's per-class counts
        # (assumed to be ordered like np.unique -- TODO confirm for pinned yellowbrick)
        class_labels = np.unique(labels)
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)
    support = class_balance_model.support_
    if len(support) >= 2:
        # a single-class target has no second count to divide by
        scale_pos_weight = support[0] / support[1]
        context.log_result("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()),
                         local_path=f"{plots_dest}/imbalance.html")

    _gcf_clear(plt)
    # correlate numeric/bool features only: pandas < 2.0 silently dropped other
    # dtypes, pandas >= 2.0 raises on them
    tblcorr = table.select_dtypes(include=["number", "bool"]).corr()
    if not tblcorr.empty:
        ax = plt.axes()
        sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
        ax.set_title("features correlation")
        context.log_artifact(PlotArtifact("correlation", body=plt.gcf()),
                             local_path=f"{plots_dest}/corr.html")
    # otherwise shows last plot:
    _gcf_clear(plt)
    print(context.artifacts)

In [None]:
# nuclio: end-code

In [None]:
from mlrun import run_local, code_to_function, NewTask, mlconf

# Deployment-specific settings: take them from the MLRUN_* environment
# variables when set, otherwise fall back to the in-cluster defaults.
mlconf.dbpath = os.environ.get("MLRUN_DBPATH", "http://mlrun-api:8080")
mlconf.artifact_path = os.environ.get("MLRUN_ARTIFACT_PATH", "/User/artifacts")

In [None]:
# Wrap this notebook's code as an MLRun job whose handler is `summarize`.
fn = code_to_function(
    'describe',
    kind='job',
    with_doc=True,
    handler=summarize,
    image='mlrun/ml-models',
)

# Metadata used when the function is reused as a template.
fn.metadata.categories = ['models', 'visualization']
fn.metadata.labels = {'author': 'yjb'}
fn.spec.description = "describe and visualizes dataset stats"
fn.spec.default_handler = 'summarize'

# Store the function, then write its spec to function.yaml in the working directory.
fn.save()
fn.export('function.yaml')

In [None]:
# Run `summarize` in-process on iris.parquet from the configured artifact path.
table_path = os.path.join(mlconf.artifact_path, "iris.parquet")
task = NewTask(
    name="tasks describe",
    handler=summarize,
    inputs={'table': table_path},
)
run = run_local(task)

In [None]:
# Peek at the table that was just summarized (`pd` comes from the imports cell).
df = pd.read_parquet(table_path)
df.head()

In [None]:
from mlrun import import_function, mount_v3io, NewTask, mlconf

# Run the hub version of `describe` on a CSV table; paths are derived from the
# configured artifact path rather than repeated as literals.
rfn = import_function('hub://describe').apply(mount_v3io())

params = {
    "name": "tasks describe",
    "params": {"table": os.path.join(mlconf.artifact_path, "classifier-data.csv")},
}

task_run = rfn.run(NewTask(**params), artifact_path=mlconf.artifact_path)