In [None]:
# nuclio: ignore
import nuclio

In [None]:
# Configure the function built from this notebook: kind "job",
# running in the "mlrun/ml-models" container image.
%nuclio config kind = "job"
%nuclio config spec.image = "mlrun/ml-models"

In [None]:
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)

In [None]:
import os
import pandas as pd
from mlrun.datastore import DataItem
from mlrun.artifacts import get_model
from cloudpickle import load
from mlrun.mlutils import eval_class_model

def cox_test(
    context,
    models_path: DataItem,
    test_set: DataItem,
    label_column: str,
    plots_dest: str = "plots",
    model_evaluator=None,
) -> None:
    """Test a pickled model against a held-out dataset.

    Using held-out test features, evaluates the performance of the estimated
    model with ``eval_class_model``. Plots and tables returned by the evaluator
    are logged as artifacts under ``plots_dest``; the remaining metrics are
    logged as run results.

    If evaluation fails, or a custom ``model_evaluator`` is requested (not
    implemented yet), the model's ``summary`` attribute is logged instead as
    the ``cox-test-summary`` dataset.

    Can be part of a kubeflow pipeline as a test step that is run post EDA and
    training/validation cycles.

    :param context:         the function context
    :param models_path:     location of the model artifact to be tested
                            (resolved with ``get_model`` using suffix ``.pkl``)
    :param test_set:        test features and labels
    :param label_column:    column name for ground truth labels
    :param plots_dest:      dir for test plots
    :param model_evaluator: WIP: specific method to generate eval, passed in as
                            string or available in this folder; not implemented,
                            any truthy value goes straight to the summary fallback
    """
    xtest = test_set.as_df()
    ytest = xtest.pop(label_column)

    # get_model returns a path to the model file; the object is deserialized here.
    # NOTE(review): cloudpickle.load can execute arbitrary code -- only load
    # models from trusted artifact stores.
    model_file, _, _ = get_model(models_path.url, suffix=".pkl")
    with open(str(model_file), "rb") as fh:
        model_obj = load(fh)

    try:
        # there could be different eval_models, type of model (xgboost, tfv1, tfv2...)
        if model_evaluator:
            # Custom evaluators are not implemented yet; route explicitly to the
            # summary fallback instead of relying on an unbound variable.
            raise NotImplementedError(
                f"model_evaluator={model_evaluator!r} is not supported"
            )

        # binary and multiclass
        eval_metrics = eval_class_model(context, xtest, ytest, model_obj)

        model_plots = eval_metrics.pop("plots")
        model_tables = eval_metrics.pop("tables")
        for plot in model_plots:
            context.log_artifact(plot, local_path=f"{plots_dest}/{plot.key}.html")
        for tbl in model_tables:
            # each table is saved under its own key
            context.log_artifact(tbl, local_path=f"{plots_dest}/{tbl.key}.csv")

        context.log_results(eval_metrics)
    except Exception as exc:
        # Fallback when the classifier evaluator cannot handle this model:
        # report why, then log the model's summary (expected to be a DataFrame).
        context.logger.info(f"cox test evaluation failed: {exc!r}")
        context.log_dataset("cox-test-summary", df=model_obj.summary, index=True, format="csv")
        context.logger.info("cox tester not implemented")

In [None]:
# nuclio: end-code

### MLRun configuration

In [None]:
import os
from mlrun import mlconf

# Fill in the MLRun API endpoint and the artifact folder only when unset.
mlconf.dbpath = mlconf.dbpath or 'http://mlrun-api:8080'
mlconf.artifact_path = (
    mlconf.artifact_path or f'{os.environ["HOME"]}/artifacts'
)

### Save the function

In [None]:
from mlrun import code_to_function

# Wrap the notebook code as an MLRun function object.
fn = code_to_function("cox_test")

# Template/reuse metadata.
fn.metadata.categories = ["ml", "test"]
fn.metadata.labels = dict(author="yjb", framework="survival")
fn.spec.description = "test a classifier using held-out or new data"
fn.spec.default_handler = "cox_test"

# Persist the function spec.
fn.export("function.yaml")

In [None]:
# Attach storage to the function: a v3io mount when V3IO_HOME is set,
# otherwise the `nfsvol` PVC mounted at /home/jovyan/data.
if "V3IO_HOME" in os.environ:
    from mlrun import mount_v3io
    fn.apply(mount_v3io())
else:
    # if you set up mlrun using the instructions at
    # https://github.com/mlrun/mlrun/blob/master/hack/local/README.md
    from mlrun.platforms import mount_pvc
    fn.apply(mount_pvc('nfsvol', 'nfsvol', '/home/jovyan/data'))

In [None]:
# Run name and handler parameters shared by the local and remote runs below.
task_params = dict(
    name="tasks cox test",
    params=dict(
        label_column="labels",
        plots_dest="churn/test/plots",
    ),
)

In [None]:
# Held-out churn test data, passed to the handler as `test_set`.
DATA_URL = (
    "https://raw.githubusercontent.com/yjb-ds/testdata/master/data/churn-tests.csv"
)

In [None]:
from mlrun import run_local, NewTask

# Run the test locally with the `cox_test` handler defined above. The workdir
# is built with os.path.join, the same way as for the `fn.run` call below.
run = run_local(
    NewTask(**task_params),
    handler=cox_test,
    inputs={
        "test_set": DATA_URL,
        "models_path": "models/cox",
    },
    workdir=os.path.join(mlconf.artifact_path, "churn"),
)

In [None]:
# Repeat the same test through the function object `fn`.
run = fn.run(
    NewTask(**task_params),
    inputs={"test_set": DATA_URL, "models_path": "models/cox"},
    workdir=os.path.join(mlconf.artifact_path, "churn"),
)