In [18]:
import joblib
import mlflow
import pandas as pd
import numpy as np
from pickle import dump
from kedro.io import PickleLocalDataSet
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import ExtraTreesClassifier

from sklearn.metrics import accuracy_score, classification_report, f1_score

In [19]:
# Read the four pickled model-input splits in one pass; the tuple order below
# matches the order of the names being assigned.
X_train, X_test, y_train, y_test = map(
    pd.read_pickle,
    (
        "data/05_model_input/X_train.pkl",
        "data/05_model_input/X_test.pkl",
        "data/05_model_input/y_train.pkl",
        "data/05_model_input/y_test.pkl",
    ),
)

In [22]:
def run_extra_trees(X_train: "pd.DataFrame", X_test: "pd.DataFrame", y_train: "pd.Series | pd.DataFrame", y_test: "pd.Series | pd.DataFrame", log=False) -> "ExtraTreesClassifier":
    """Fit an ExtraTreesClassifier, report its test performance and persist it.

    The fitted model is always written to ``data/06_models/extra_trees.pkl``
    and a classification report for the test split is printed. When ``log``
    is truthy, the hyperparameters, the weighted F1 score and the pickled test
    features are additionally recorded in MLflow as a child run nested under
    the ``extra_trees`` parent run.

    Args:
        X_train: Training features; its column labels are stored as the
            ``features`` tag of the MLflow run.
        X_test: Test features used for prediction.
        y_train: Training labels passed to ``fit``.
        y_test: Test labels the predictions are scored against.
        log: Whether to record the run in MLflow (Databricks tracking URI).

    Returns:
        The fitted ExtraTreesClassifier.
    """
    # Single source of truth for the hyperparameters: the same dict builds the
    # model and is logged to MLflow, so the logged values cannot drift from
    # (or, as before, be missing for) the model that was actually trained.
    params = {
        "n_estimators": 100,
        "random_state": 0,
        "max_depth": 20,
    }
    model = ExtraTreesClassifier(**params)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    f1 = f1_score(y_test, y_pred, average="weighted")

    joblib.dump(model, 'data/06_models/extra_trees.pkl')

    print(classification_report(y_test, y_pred, digits=5))

    if log:
        mlflow.set_tracking_uri("databricks")
        # NOTE(review): `context` is not defined in the visible notebook cells;
        # presumably it is injected by the Kedro Jupyter/IPython session - confirm.
        mlflow.set_experiment(context.params.databricks)

        # NOTE(review): the experiment id is hard-coded while the active
        # experiment comes from `context.params` - confirm they refer to the
        # same experiment.
        parent_runs = mlflow.search_runs(
            experiment_ids=["3889491181315524"],
            filter_string="tags.`mlflow.runName`='extra_trees'",
            run_view_type=1,  # presumably ViewType.ACTIVE_ONLY - see MLflow docs
        )
        if parent_runs.empty:
            # No parent run yet: create it by name instead of failing with an
            # opaque lookup error on the empty search result.
            parent_run = mlflow.start_run(run_name="extra_trees", nested=False)
        else:
            parent_run = mlflow.start_run(run_id=parent_runs["run_id"].iloc[0], nested=False)

        # Both runs are used as context managers so they are ended even when
        # logging raises; a parent run left active would make the next
        # non-nested start_run call fail.
        with parent_run:
            with mlflow.start_run(nested=True):
                mlflow.set_tags({
                    "lib": "sklearn",
                    "features": X_train.columns.values,
                })

                mlflow.log_params(params)
                mlflow.log_metric("f1", f1, step=1)
                mlflow.log_artifact('data/05_model_input/X_test.pkl')
    return model

In [23]:
%%time
# Train and evaluate the extra-trees model on the loaded splits; log=True also
# takes the MLflow logging branch of run_extra_trees. The fitted model is kept in `m`.
m = run_extra_trees(X_train, X_test, y_train, y_test, log=True);

              precision    recall  f1-score   support

           0    0.90412   0.97512   0.93828      4903
           1    0.93969   0.78945   0.85805      2408

    accuracy                        0.91397      7311
   macro avg    0.92191   0.88228   0.89816      7311
weighted avg    0.91584   0.91397   0.91185      7311

CPU times: user 2.64 s, sys: 133 ms, total: 2.78 s
Wall time: 16.9 s
