# A train task

In the cell below, uncomment the `MY_MODEL_CONFIG` assignment for the model type you want to run:

1.  **Logistic Regression** using Scikit-Learn defaults, loaded directly from the Scikit-Learn package installed on your workstation
2.  **XGBoost** using a local JSON model configuration
3.  **LightGBM** using a local JSON model configuration

Please see 
**[Logistic Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn-linear-model-logisticregression)**, **[XGBoost Parameters](https://xgboost.readthedocs.io/en/latest/parameter.html?highlight=parameters#xgboost-parameters)**, and **[LightGBM Parameters](https://lightgbm.readthedocs.io/en/latest/Parameters.html#parameters)** for more details on all configurable parameters.

```python
MY_MODEL_CONFIG = "/User/functions/sklearn_classifier/test/LGBMClassifier.json"
# OR
# MY_MODEL_CONFIG = "/User/functions/sklearn_classifier/test/XGBClassifier.json"
# OR
# MY_MODEL_CONFIG = "sklearn.linear_model.LogisticRegression"
```
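`MY_MODEL_CONFIG` is either a path to a JSON file or a dotted class path. The sketch below shows one common way a dotted path such as `sklearn.linear_model.LogisticRegression` can be turned into a class with `importlib`. It is an illustration only, not the `sklearn_classifier` function's actual code, and the `.json` suffix rule in the hypothetical `is_json_config` helper is an assumption.

```python
import importlib


def resolve_class(dotted_path: str) -> type:
    """Turn a string like 'pkg.module.ClassName' into the class object it names."""
    # Everything before the last dot is the module; the last part is the class name.
    module_name, _, class_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def is_json_config(model_config: str) -> bool:
    """Hypothetical rule (assumption): values ending in '.json' are config file paths."""
    return model_config.lower().endswith(".json")


# Stdlib classes keep the sketch runnable without Scikit-Learn installed;
# "sklearn.linear_model.LogisticRegression" resolves the same way when it is installed.
cls = resolve_class("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
print(is_json_config("/User/functions/sklearn_classifier/test/XGBClassifier.json"))  # True
```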

To override parameters passed to the model class's `__init__` or to its `fit` method, edit the dictionaries `class_params_updates` and `fit_params_updates` in the cell below:

```python
class_params_updates = {
    "random_state": 1
}
fit_params_updates = {}
```
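Judging by their names, these dictionaries are overrides layered on top of the parameters that come from `MY_MODEL_CONFIG`; exactly how the function merges them is an assumption here. A minimal sketch of a merge where keys in the update dictionary win, using a copy so neither the defaults nor the caller's dictionary is modified:

```python
from typing import Any, Dict


def apply_updates(defaults: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict: defaults overridden by updates, with neither input modified."""
    merged = dict(defaults)  # shallow copy so the caller's defaults stay intact
    merged.update(updates)   # keys present in updates replace the defaults
    return merged


# Hypothetical defaults standing in for whatever the model config provides.
default_class_params = {"random_state": None, "n_estimators": 100}
updates = {"random_state": 1}

class_params = apply_updates(default_class_params, updates)
print(class_params)  # {'random_state': 1, 'n_estimators': 100}
```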

```python
import mlrun

# Point the client at the MLRun API service
mlrun.mlconf.dbpath = "http://mlrun-api:8080"

FUNCTION = "sklearn_classifier"

train_params = {
    "name": "train",
    "params": {
        # CHOOSE YOUR MODEL
        "model_pkg_class": MY_MODEL_CONFIG,
        "model_key": "models",

        # POINT THIS TO YOUR DATA
        "data_key": "/User/functions/load_dataset/iris.pqt",  # multiclass, small dataset
        # "data_key": "/User/functions/arc_to_parquet/artifacts/higgs.pqt",  # binary, large dataset
        "sample": -1,
        "label_column": "labels",
        "test_size": 0.10,
        "train_val_split": 0.75,
        "rng": 1,

        # CHANGE ANY class params here
        "class_params_updates": class_params_updates,

        # CHANGE ANY fit params here
        "fit_params_updates": fit_params_updates,
    },
    "artifact_path": f"/User/functions/{FUNCTION}/artifacts",
}

# Load the function spec, mount the v3io volume into the job pod, and run it
train_fn = mlrun.import_function(f"/User/functions/{FUNCTION}/function.yaml").apply(mlrun.mount_v3io())
tf = train_fn.run(**train_params)
```
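`train_fn.run(**train_params)` unpacks the dictionary into keyword arguments, so each top-level key (`name`, `params`, `artifact_path`) has to be a parameter of `run`; assuming `run` has no catch-all `**kwargs`, a misspelled key fails with a `TypeError`. A small stdlib check with `inspect.signature(...).bind` can catch this before a job is submitted. `fake_run` below is a hypothetical stand-in copied from the visible part of the signature printed by `train_fn.run?`, not mlrun itself.

```python
import inspect


def fake_run(runspec=None, handler=None, name: str = "", project: str = "",
             params: dict = None, inputs: dict = None, out_path: str = "",
             workdir: str = "", artifact_path: str = ""):
    """Hypothetical stand-in mirroring part of train_fn.run's signature."""
    return name


def accepts_kwargs(func, kwargs: dict) -> bool:
    """True if kwargs could be passed to func as keyword arguments."""
    try:
        inspect.signature(func).bind(**kwargs)
    except TypeError:
        return False
    return True


# Same top-level keys as train_params (values abbreviated for the sketch).
params_ok = {"name": "train", "params": {"rng": 1}, "artifact_path": "/tmp/artifacts"}
params_typo = {"name": "train", "artifacts_path": "/tmp/artifacts"}

print(accepts_kwargs(fake_run, params_ok))    # True
print(accepts_kwargs(fake_run, params_typo))  # False
```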

```text
[mlrun] 2020-03-22 20:57:14,467 starting run train uid=5238b770c04b480f8d71b1b20586e482  -> http://mlrun-api:8080
[mlrun] 2020-03-22 20:57:14,606 Job is running in the background, pod: train-wzfch
Intel(R) Data Analytics Acceleration Library (Intel(R) DAAL) solvers for sklearn enabled: https://intelpython.github.io/daal4py/sklearn.html
No handles with labels found to put in legend.
[mlrun] 2020-03-22 20:57:27,753 log artifact test_set at /User/functions/sklearn_classifier/artifacts/test_set.pqt, size: None, db: Y
[mlrun] 2020-03-22 20:57:27,913 log artifact models at /User/functions/sklearn_classifier/artifacts/models, size: None, db: Y
[mlrun] 2020-03-22 20:57:27,914 y_score.shape (34, 3)
[mlrun] 2020-03-22 20:57:27,914 yvalidb.shape (34, 4)
[mlrun] 2020-03-22 20:57:28,143 log artifact roc at /User/functions/sklearn_classifier/artifacts/plots/roc.png, size: 32470, db: Y
[mlrun] 2020-03-22 20:57:28,395 log artifact confusion at /User/functions/sklearn_classifier/artifacts/plots/confusi
```
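The log line `y_score.shape (34, 3)` reports 34 scored rows with one column per iris class. Assuming the file holds the standard 150-row iris dataset, and that the function first holds out `test_size` of the rows and then splits the remainder with `train_val_split` as the training fraction (for example with two calls to Scikit-Learn's `train_test_split`, which rounds a float test size up and a float train size down), the row counts work out as in this sketch. The two-stage order is an assumption, not confirmed by the source.

```python
import math


def split_sizes(n_rows: int, test_size: float, train_val_split: float):
    """Row counts (train, validation, test) for an assumed two-stage split.

    Assumed order: hold out test_size first, then split the remainder into
    train (train_val_split) and validation (the rest), rounding like
    Scikit-Learn does for float sizes.
    """
    n_test = math.ceil(test_size * n_rows)          # float test size: round up
    n_rest = n_rows - n_test
    n_train = math.floor(train_val_split * n_rest)  # float train size: round down
    n_val = n_rest - n_train
    return n_train, n_val, n_test


print(split_sizes(150, 0.10, 0.75))  # (101, 34, 15)
```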

| Field | Value |
|---|---|
| uid | ...86e482 |
| iter | 0 |
| start | Mar 22 20:57:27 |
| state | completed |
| name | train |
| labels | host=train-wzfch, kind=job, owner=admin |
| inputs | (none) |
| parameters | class_params_updates={'random_state': 1}, data_key=/User/functions/load_dataset/iris.pqt, fit_params_updates={'X': [[5.2, 4.1, 1.5, 0.1], [5.3, 3.7, 1.5, 0.2], …], 'y': [0.0, 0.0, 0.0, 2.0, …]}, label_column=labels, model_key=models, model_pkg_class=/User/functions/sklearn_classifier/test/LGBMClassifier.json, rng=1, sample=-1, test_size=0.1, train_val_split=0.75 |
| results | accuracy=0.9411764705882353, avg_precscore=0.9719080049216138, f1_score=0.9411764705882353, rocauc=0.9712373737373737 |
| artifacts | test_set, models, roc, confusion |
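In the results, `accuracy` and `f1_score` are identical (0.9411764705882353, which equals 32/34); if both are computed on the 34 validation rows reported in the log, that corresponds to 32 correct predictions. For single-label multiclass classification, micro-averaged F1 always equals accuracy, because every misclassified sample counts once as a false positive and once as a false negative. Whether the function uses micro averaging is an assumption, but it would explain the match. The sketch computes both from a confusion matrix; the matrix is made up for illustration and is not this run's data.

```python
from typing import List

Matrix = List[List[int]]  # rows: true class, columns: predicted class


def accuracy(cm: Matrix) -> float:
    """Fraction of samples on the diagonal (predicted class == true class)."""
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total


def micro_f1(cm: Matrix) -> float:
    """F1 from true positives, false positives and false negatives pooled over classes."""
    k = len(cm)
    tp = sum(cm[i][i] for i in range(k))
    # False positive for class c: predicted c (column c) but the true class differs.
    fp = sum(sum(cm[r][c] for r in range(k)) - cm[c][c] for c in range(k))
    # False negative for class r: true class r (row r) but predicted otherwise.
    fn = sum(sum(cm[r]) - cm[r][r] for r in range(k))
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Made-up 3-class matrix: 34 samples, 32 correct (NOT the confusion matrix from this run).
cm = [
    [12, 0, 0],
    [0, 10, 1],
    [0, 1, 10],
]
print(accuracy(cm), micro_f1(cm))  # both approximately 0.9412
```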


```text
to track results use .show() or .logs() or in CLI:
!mlrun get run 5238b770c04b480f8d71b1b20586e482  , !mlrun logs 5238b770c04b480f8d71b1b20586e482
[mlrun] 2020-03-22 20:57:33,846 run executed, status=completed
```


```python
train_fn.run?
```

[0;31mSignature:[0m
[0mtrain_fn[0m[0;34m.[0m[0mrun[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mrunspec[0m[0;34m:[0m [0mmlrun[0m[0;34m.[0m[0mmodel[0m[0;34m.[0m[0mRunObject[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mhandler[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mname[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mproject[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mparams[0m[0;34m:[0m [0mdict[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minputs[0m[0;34m:[0m [0mdict[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mout_path[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mworkdir[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0martifact_path[0m[0