In [33]:
import joblib
import mlflow
import pandas as pd
import numpy as np
from pickle import dump
from kedro.io import PickleLocalDataSet
from sklearn.ensemble import RandomForestClassifier

from sklearn.metrics import accuracy_score, classification_report, f1_score

In [34]:
# Load the train/test splits (features X, labels y) from the model-input folder.
# NOTE(review): read_pickle executes whatever the pickle contains; only load
# files produced by a trusted pipeline.
X_train, X_test = (
    pd.read_pickle("data/05_model_input/X_train.pkl"),
    pd.read_pickle("data/05_model_input/X_test.pkl"),
)
y_train, y_test = (
    pd.read_pickle("data/05_model_input/y_train.pkl"),
    pd.read_pickle("data/05_model_input/y_test.pkl"),
)

In [35]:
def _log_random_forest_run(X_train, params, f1):
    """Record one training run in MLflow as a child of the 'random_forest' run.

    Points MLflow at the "databricks" tracking URI, selects the project
    experiment, re-opens the parent run named 'random_forest' (creating it if
    the search finds none) and logs tags, parameters, the f1 metric and the
    test-set pickle inside a fresh nested run.

    Args:
        X_train: Training features; only ``X_train.columns.values`` is read,
            to record the feature names as a tag.
        params: Mapping of hyper-parameter name to value to log.
        f1: F1 score to log under the metric name "f1".
    """
    mlflow.set_tracking_uri("databricks")
    mlflow.set_experiment("/Users/firefly.eugene@gmail.com/twitter-bot-detection")

    # NOTE(review): the experiment ID is hard-coded separately from the
    # experiment name above; presumably both refer to the same experiment,
    # confirm. run_view_type=1 is mlflow.entities.ViewType.ACTIVE_ONLY.
    parent_runs = mlflow.search_runs(
        experiment_ids=["3889491181315524"],
        filter_string="tags.`mlflow.runName`='random_forest'",
        run_view_type=1,
    )
    if parent_runs.empty:
        # No parent run yet: create it by name instead of failing on an
        # index lookup into an empty result.
        parent_run = mlflow.start_run(run_name="random_forest", nested=False)
    else:
        parent_run = mlflow.start_run(
            run_id=parent_runs["run_id"].iloc[0], nested=False
        )

    # Context managers guarantee both runs are ended even if logging raises;
    # otherwise the parent stays active and the next call in this kernel fails.
    with parent_run:
        with mlflow.start_run(nested=True):
            mlflow.set_tags({
                "lib": "sklearn",
                # NOTE(review): this is a NumPy array, not a string; check the
                # stored tag value looks as intended for wide feature sets.
                "features": X_train.columns.values,
            })
            mlflow.log_params(params)
            mlflow.log_metric("f1", f1, step=1)
            # NOTE(review): uploads the pickle from disk, not the X_test
            # argument; the two differ if the caller passed other data.
            mlflow.log_artifact('data/05_model_input/X_test.pkl')


def run_random_forest(
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    y_train,
    y_test,
    log: bool = False,
    max_depth: int = 20,
) -> RandomForestClassifier:
    """Train, evaluate and persist a random-forest classifier.

    Fits a ``RandomForestClassifier`` (``random_state=0``) on the training
    split, predicts on the test split, prints a classification report and
    writes the fitted model to 'data/06_models/random_forest.pkl'.

    Args:
        X_train: Training features. Must expose ``.columns`` when ``log`` is
            true, because the column names are logged as a tag.
        X_test: Test features passed to ``model.predict``.
        y_train: Training labels passed to ``model.fit``.
        y_test: True test labels used for the f1 score and the report.
        log: When true, also log the run to MLflow (needs access to the
            "databricks" tracking server).
        max_depth: Maximum tree depth for the forest. Defaults to 20, the
            value that was previously hard-coded.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    params = {"max_depth": max_depth}

    model = RandomForestClassifier(max_depth=max_depth, random_state=0)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    # Weighted average: per-class f1 weighted by each class's support.
    f1 = f1_score(y_test, y_pred, average="weighted")

    joblib.dump(model, 'data/06_models/random_forest.pkl')

    print(classification_report(y_test, y_pred, digits=5))

    if log:
        _log_random_forest_run(X_train, params, f1)
    return model

In [36]:
%%time
# Fit on the training split, score on the test split, and log the run (log=True).
m = run_random_forest(
    X_train=X_train,
    X_test=X_test,
    y_train=y_train,
    y_test=y_test,
    log=True,
);

              precision    recall  f1-score   support

           0    0.92011   0.97716   0.94777      4903
           1    0.94677   0.82724   0.88298      2408

    accuracy                        0.92778      7311
   macro avg    0.93344   0.90220   0.91538      7311
weighted avg    0.92889   0.92778   0.92643      7311

CPU times: user 9.68 s, sys: 47.7 ms, total: 9.73 s
Wall time: 22.1 s
