In [None]:
%pip install -qU experiment-results-manager \
    s3fs \
    gcsfs \
    seaborn scikit-learn

In [None]:
import pandas as pd
import experiment_results_manager as erm
import pickle

from experiment_results_manager.artifact import ArtifactType
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.datasets import load_iris
from IPython.display import display, HTML

In [None]:
# Load the iris dataset and assemble a single DataFrame: one column per
# feature, with the class label appended as the last column, 'target'.
iris = load_iris()
feature_names = iris.feature_names
iris_df = pd.DataFrame(iris.data, columns=feature_names).assign(target=iris.target)
iris_df

In [None]:
# Train/test split. The seed is fixed so that "Restart Kernel -> Run All"
# reproduces the same split, and therefore the same metrics, on every run.
X_train, X_test, y_train, y_test = train_test_split(
    iris_df.drop('target', axis=1),
    iris_df['target'],
    random_state=666,
)

# Let's try some different classifiers. Each one is fitted on the training
# split and immediately used to predict the held-out test split.
rf = RandomForestClassifier(random_state=666)
rf.fit(X_train, y_train)
rf_pred = rf.predict(X_test)

gbt = GradientBoostingClassifier(random_state=666)
gbt.fit(X_train, y_train)
gbt_pred = gbt.predict(X_test)

dt = DecisionTreeClassifier(max_depth = 3, random_state = 1)
dt.fit(X_train, y_train)
dt_pred = dt.predict(X_test)

In [None]:
from matplotlib import pyplot as plt
from numpy import mean
from sklearn.metrics import confusion_matrix
import seaborn as sns

def evaluate_and_log_run(model, y_test, y_pred, variant_id):
    """Score one fitted classifier and bundle the results into an ExperimentRun.

    Args:
        model: Fitted estimator. Its class name is recorded as the ``model``
            param and the pickled object is attached as a binary artifact.
        y_test: True labels of the evaluation set.
        y_pred: Labels predicted by ``model`` for the same samples.
        variant_id: Variant identifier of the run; also used as the title of
            the confusion-matrix plot.

    Returns:
        The populated ``erm.ExperimentRun`` for the ``"iris"`` experiment.
    """
    er = erm.ExperimentRun(
        experiment_id="iris",
        variant_id=variant_id,
        params={
            "model": type(model).__name__
        },
        features=feature_names,
        metrics={
            # Weighted averaging: per-class scores weighted by class support.
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred, average='weighted'),
            "recall": recall_score(y_test, y_pred, average='weighted'),
            "f1": f1_score(y_test, y_pred, average='weighted')
        }
    )

    er.log_dict("custom_dict", data={
        "avg_pred_value": mean(y_pred)
    })

    er.log_artifact(
        pickle.dumps(model),
        artifact_id="model",
        filename="model.pickle",
        artifact_type=ArtifactType.BINARY)

    # Start from an empty current figure so heatmaps from earlier calls do
    # not pile up on the same axes.
    plt.clf()
    # Use this run's own predictions (y_pred), not a global from another
    # model, so each variant logs its own confusion matrix.
    ax = sns.heatmap(
        confusion_matrix(y_test, y_pred),
        annot=True,
        cmap="Blues"
    )
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(variant_id)
    # sns.heatmap returns the matplotlib Axes it drew on; that same object is
    # what gets logged.
    er.log_figure(ax, "confusion_matrix")

    return er

# Evaluate every fitted classifier against the same held-out labels and keep
# one ExperimentRun per variant.
er_rf = evaluate_and_log_run(
    model=rf, y_test=y_test, y_pred=rf_pred, variant_id="random_forest"
)
er_gbt = evaluate_and_log_run(
    model=gbt, y_test=y_test, y_pred=gbt_pred, variant_id="gradient_boosting"
)
er_dt = evaluate_and_log_run(
    model=dt, y_test=y_test, y_pred=dt_pred, variant_id="decision_tree"
)
# Clear the current figure left over from the last heatmap.
plt.clf()

In [None]:
# Render the comparison of the three runs as HTML in the cell output.
comparison_html = erm.compare_runs(er_rf, er_gbt, er_dt)
display(HTML(comparison_html))