In [None]:
# nuclio: ignore
import nuclio

In [None]:
import json
import os

from cloudpickle import dump, load

from sklearn import metrics
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn import metrics

from typing import List
from mlrun.execution import MLClientCtx
from mlrun.datastore import DataItem
from mlrun.artifacts import PlotArtifact

from utils import get_model_configs

import warnings

# Ignore every FutureWarning raised after this point (applies process-wide).
warnings.simplefilter(action="ignore", category=FutureWarning)

def _create_class(pkg_class: str):
    """Import and return the class named by a dotted ``module.Class`` path.

    :param pkg_class:   fully qualified class name, e.g. "xgboost.XGBClassifier"
    :raises ValueError: if ``pkg_class`` contains no module part
    """
    # function-scope import keeps this helper self-contained
    from importlib import import_module

    module_name, _, class_name = pkg_class.rpartition(".")
    if not module_name:
        raise ValueError(
            f"expected a dotted 'module.Class' path, got {pkg_class!r}")
    return getattr(import_module(module_name), class_name)


def _select_rows(frame: pd.DataFrame, sample: int, rng: int) -> pd.DataFrame:
    """Return all rows, the first ``sample`` rows, or a random sample.

    :param frame:       the full data set
    :param sample:      -1 for every row, n >= 1 for the first n rows,
                        n < -1 for a random sample of abs(n) rows
    :param rng:         seed used when drawing the random sample
    :raises ValueError: if ``sample`` is 0
    """
    if sample == -1:
        return frame
    if sample >= 1:
        return frame.iloc[:sample, :]
    if sample < -1:
        return frame.sample(-sample, random_state=rng)
    raise ValueError(
        "sample must be -1 (all rows), >= 1 (first n rows) or < -1 (random)")


def xgboost_train(
    context: MLClientCtx,
    data_key: str,
    sample: int,
    label_column: str,

    # xgboost params:
    num_classes: int = 2,
    class_names: List[str] = None,
    max_depth: int = 50,
    learning_rate: float = 0.1,
    # TODO: choose the objective based on num_classes (e.g. 'multi:softmax')
    # TODO: expose regularisation params

    other_xgb_params: dict = None,
    test_size: float = 0.05,
    train_val_split: float = 0.75,
    rng: int = 1,
    models_dir: str = "models",
    plots_dir: str = "plots",
    score_method: str = "micro",
    model_key: str = "model",
    test_set_key: str = "test_set",
    file_ext: str = "parquet",
    model_pkg_class: str = "xgboost.XGBClassifier",
    fit_params_updates: dict = None,
) -> None:
    """Train a classifier, log a held-out test set, metrics and plots.

    The data is split twice: first a test set is removed and logged as a
    dataset artifact, then the remainder is split into train and validation
    sets. The model is fitted on the train set, pickled, and scored on the
    validation set.

    :param context:            the function context
    :param data_key:           path (or DataItem) of the raw parquet data file
    :param sample:             -1 selects every row, n >= 1 selects the first
                               n rows, n < -1 selects a random sample of
                               abs(n) rows (seeded with ``rng``)
    :param label_column:       name of the ground-truth (y) column
    :param num_classes:        (2) reserved; not used by this implementation
    :param class_names:        (None) reserved; not used by this implementation
    :param max_depth:          (50) ``max_depth`` passed to the model class
    :param learning_rate:      (0.1) ``learning_rate`` passed to the model class
    :param other_xgb_params:   (None) extra model-class keyword arguments as a
                               dict, or a DataItem holding them as json; these
                               override ``max_depth`` / ``learning_rate``
    :param test_size:          (0.05) test set size
    :param train_val_split:    (0.75) once the test set has been removed the
                               training set gets this proportion
    :param rng:                (1) seed for the splits and random sampling
    :param models_dir:         models subfolder on artifact path
    :param plots_dir:          plot subfolder on artifact path
    :param score_method:       averaging method for multiclass scores
    :param model_key:          ("model") name of the model artifact, also the
                               pickle file name
    :param test_set_key:       ("test_set") store the test data set under this
                               key in the artifact store
    :param file_ext:           ("parquet") format of the logged test set
    :param model_pkg_class:    dotted class name, or path to a ``.json`` model
                               config containing "CLASS", "FIT" and "META"
                               sections
    :param fit_params_updates: (None) extra keyword arguments for ``fit`` as a
                               dict, or a DataItem holding them as json
    :raises ValueError:        if ``sample`` is 0
    """
    # extract file name from DataItem
    srcfilepath = str(data_key)

    # TODO: this should be part of data's metadata dealt with in another step
    # get all data or a sample; sample == -1 keeps every row
    raw = _select_rows(pd.read_parquet(srcfilepath).dropna(), sample, rng)
    labels = raw[label_column]
    raw = raw.drop(columns=[label_column])

    # TODO: this should be part of data's metadata dealt with in another step
    header = raw.columns.values
    # kept for backward compatibility with code that reads context.header
    context.header = header

    # TODO: all of this should be part of a splitter component that does cv too
    # One-hot copy of the labels made before the split. Classes are sorted so
    # that binary 0/1 labels come back unchanged and column order does not
    # depend on the order in which labels happen to appear in the data.
    yb = label_binarize(labels, classes=np.unique(labels))
    n_encoded = yb.shape[1]

    # double split to generate 3 data sets: train, validation and test.
    # The encoded labels ride along inside the X matrix so that both the
    # encoded and the plain labels are shuffled identically.
    x, xtest, y, ytest = train_test_split(
        np.concatenate([raw, yb], axis=1), labels,
        test_size=test_size, random_state=rng)
    xtrain, xvalid, ytrain, yvalid = train_test_split(
        x, y, train_size=train_val_split, random_state=rng)

    # strip the encoded label columns back off the feature matrices
    yvalidb = xvalid[:, -n_encoded:].copy()
    xtrain = xtrain[:, :-n_encoded].copy()
    xvalid = xvalid[:, :-n_encoded].copy()
    xtest = xtest[:, :-n_encoded].copy()

    # set-aside test_set
    test_set = pd.concat(
        [pd.DataFrame(data=xtest, columns=header),
         pd.DataFrame(data=ytest.values, columns=[label_column])],
        axis=1)
    context.log_dataset(test_set_key, df=test_set, format=file_ext, index=False)

    # load the model config
    if model_pkg_class.endswith(".json"):
        with open(model_pkg_class, "r") as config_file:
            model_config = json.load(config_file)
    else:
        # NOTE(review): assumes utils.get_model_configs accepts a dotted class
        # name and returns "CLASS" / "FIT" / "META" sections - confirm
        model_config = get_model_configs(model_pkg_class)

    # get update params if any
    if isinstance(other_xgb_params, DataItem):
        other_xgb_params = json.loads(other_xgb_params.get())
    if isinstance(fit_params_updates, DataItem):
        fit_params_updates = json.loads(fit_params_updates.get())

    # explicit arguments first, then user overrides; copies are used so the
    # caller's dicts are never mutated
    class_params = {"max_depth": max_depth, "learning_rate": learning_rate}
    class_params.update(other_xgb_params or {})
    fit_params = dict(fit_params_updates or {})
    # add data to fit params
    fit_params.update({"X": xtrain, "y": ytrain.values})

    model_config["CLASS"].update(class_params)
    model_config["FIT"].update(fit_params)

    # create class and fit
    ClassifierClass = _create_class(model_config["META"]["class"])
    model = ClassifierClass(**model_config["CLASS"])
    model.fit(**model_config["FIT"])

    # save model (best effort: a serialization failure is logged, not raised)
    # NOTE(review): base_path was undefined in the original; the run's
    # artifact path is assumed here - confirm
    base_path = context.artifact_path
    model_dir_path = os.path.join(base_path, models_dir)
    filepath = os.path.join(model_dir_path, f"{model_key}.pkl")
    try:
        os.makedirs(model_dir_path, exist_ok=True)
        with open(filepath, "wb") as model_file:
            dump(model, model_file)
        context.log_artifact(model_key, local_path=models_dir)
    except Exception as e:
        context.logger.error(f"SERIALIZE MODEL ERROR: {e}")

    # compute validation metrics
    ypred = model.predict(xvalid)
    y_score = model.predict_proba(xvalid)
    context.logger.info(f"y_score.shape {y_score.shape}")
    context.logger.info(f"yvalidb.shape {yvalidb.shape}")
    if yvalidb.shape[1] > 1:
        # label encoding was applied: score against every class column
        positive_scores = y_score
    else:
        # binary case: score against the probability of the second class
        positive_scores = y_score[:, 1]
    average_precision = metrics.average_precision_score(
        yvalidb, positive_scores, average=score_method)
    context.log_result("rocauc", metrics.roc_auc_score(yvalidb, positive_scores))

    context.log_result("avg_precscore", average_precision)
    context.log_result("accuracy", float(model.score(xvalid, yvalid)))
    context.log_result("f1_score", metrics.f1_score(yvalid, ypred,
                                                    average=score_method))

    # TODO: missing validation plots, callbacks need to be reintroduced

    plot_roc(context, yvalidb, y_score, plots_dir=plots_dir)
    plot_confusion_matrix(context, yvalid, ypred, key="confusion",
                          plots_dir=plots_dir, fmt="png")

def plot_roc(
    context,
    y_labels,
    y_probs,
    key="roc",
    plots_dir: str = "plots",
    fmt="png",
    fpr_label: str = "false positive rate",
    tpr_label: str =  "true positive rate",
    title: str = "roc curve",
    legend_loc: str = "best"
):
    """plot roc curves

    One curve is drawn per column of ``y_labels`` when it has more than one
    column; otherwise a single curve is drawn for the positive class.

    TODO:  add averaging method (as string) that was used to create probs,
    display in legend

    :param context:      the function context
    :param y_labels:     ground truth labels, hot encoded for multiclass
                         (1-D or single-column for binary)
    :param y_probs:      model prediction probabilities; for binary input the
                         second column is used, or the array itself if 1-D
    :param key:          ("roc") key of plot in artifact store
    :param plots_dir:    ("plots") destination folder relative path to artifact path
    :param fmt:          ("png") plot format; currently unused, the artifact
                         file name always ends in ".html"
    :param fpr_label:    ("false positive rate") x-axis labels
    :param tpr_label:    ("true positive rate") y-axis labels
    :param title:        ("roc curve") title of plot
    :param legend_loc:   ("best") location of plot legend
    """
    # clear matplotlib current figure
    _gcf_clear(plt)

    # draw 45 degree line
    plt.plot([0, 1], [0, 1], "k--")

    # labelling
    plt.xlabel(fpr_label)
    plt.ylabel(tpr_label)
    plt.title(title)

    y_labels = np.asarray(y_labels)
    y_probs = np.asarray(y_probs)

    # single ROC or multiple
    if y_labels.ndim > 1 and y_labels.shape[1] > 1:
        # one curve per class, including the last one
        for i in range(y_labels.shape[1]):
            fpr, tpr, _ = metrics.roc_curve(y_labels[:, i], y_probs[:, i], pos_label=1)
            plt.plot(fpr, tpr, label=f"class {i}")
    else:
        positive_probs = y_probs[:, 1] if y_probs.ndim > 1 else y_probs
        fpr, tpr, _ = metrics.roc_curve(y_labels.ravel(), positive_probs, pos_label=1)
        plt.plot(fpr, tpr, label="positive class")

    # the legend must be created after the labelled curves exist
    plt.legend(loc=legend_loc)

    fname = f"{plots_dir}/{key}.html"
    context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)
    

def plot_confusion_matrix(
    context: MLClientCtx,
    labels,
    predictions,
    key: str = "confusion_matrix",
    plots_dir: str = "plots",
    colormap: str = "Blues",
    fmt: str = "png",
    sample_weight=None
):
    """Create a confusion matrix.
    Plot and save a confusion matrix using test data from a
    modelling step.

    See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html

    TODO: fix label alignment
    TODO: consider using another packaged version
    TODO: refactor to take params dict for plot options

    :param context:         function context
    :param labels:          validation data ground-truth labels
    :param predictions:     validation data predictions
    :param key:             ("confusion_matrix") key of plot in artifact store
    :param plots_dir:       relative path of plots in artifact store
    :param colormap:        colourmap for confusion matrix
    :param fmt:             plot format; currently unused, the artifact file
                            name always ends in ".html"
    :param sample_weight:   (None) sample weights forwarded to
                            ``sklearn.metrics.confusion_matrix``
    """
    _gcf_clear(plt)

    # forward the caller's weights instead of always passing None
    cm = metrics.confusion_matrix(labels, predictions, sample_weight=sample_weight)
    sns.heatmap(cm, annot=True, cmap=colormap, square=True)

    fig = plt.gcf()
    fname = f"{plots_dir}/{key}.html"
    context.log_artifact(PlotArtifact(key, body=fig), local_path=fname)

    
def _gcf_clear(plt):
    """Utility to clear matplotlib figure

    Run this inside every plot method before calling any matplotlib
    methods

    :param plot:    matloblib figure object
    """
    plt.cla()
    plt.clf()
    plt.close()

In [None]:
# nuclio: end-code

### save

In [None]:
from mlrun import code_to_function
# create job function object from notebook code; the handler is given by
# name and must match the training function defined in this notebook
fn = code_to_function("xgboost", kind="job", with_doc=True,
                      handler="xgboost_train", image="mlrun/ml-models")

# add metadata (for templates and reuse)
fn.spec.default_handler = "xgboost_train"
fn.spec.description = "train an xgboost model"
fn.metadata.categories = ["models", "xgboost"]
fn.spec.image_pull_policy = "Always"
fn.metadata.labels = {"author": "yjb"}

# persist the function object and write its spec to a local yaml file
fn.save()
fn.export("function.yaml")

### test

In [None]:
from mlrun import import_function, mount_v3io

# load the function from the hub, then apply the v3io mount to it
func = import_function("hub://xgboost")
func = func.apply(mount_v3io())
# alternative: load the locally exported spec instead of the hub copy
# func = import_function("function.yaml").apply(mount_v3io())

In [None]:
# keyword overrides for the classifier's constructor (__init__ function params)
class_params_updates = dict(random_state=1)

# keyword overrides for the classifier's fit call
fit_params_updates = dict()

In [None]:
from mlrun import NewTask

# Task definition: every entry under "params" is a keyword argument of the
# training handler.
task_params = {
    "name" : "tasks train an xgboost model",
    "params" : {
        "model_key"       : "models",

        # POINT THIS TO YOUR DATA
        #"data_key"        : "/User/artifacts/iris.parquet",
        #"data_key"        : "/User/artifacts/wine.parquet",
        "data_key"        : "/User/artifacts/breast_cancer.parquet",
        "sample"          : -1,
        "label_column"    : "labels",
        "test_size"       : 0.10,
        "train_val_split" : 0.75,
        "rng"             : 1,

        # xgboost parameters (50 is the handler's own default; tune as needed)
        "max_depth"       : 50,

        # other xgboost parameters, keyed by the handler's parameter name
        "other_xgb_params" : class_params_updates,
    },
}

run = func.run(NewTask(**task_params), artifact_path="/User/artifacts")