In [8]:
# Install the experiment tracker plus the plotting / ML libraries used below.
# NOTE(review): versions are unpinned and -U upgrades them, so a re-run may
# pick up newer releases than the ones that produced the saved outputs.
# NOTE(review): s3fs / gcsfs are not used by any cell in this notebook;
# presumably they enable remote (S3 / GCS) storage for
# experiment_results_manager — TODO confirm, or drop them.
%pip install -qU experiment-results-manager \
    s3fs \
    gcsfs \
    seaborn scikit-learn


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [9]:
import pandas as pd
import experiment_results_manager as erm
import pickle

from experiment_results_manager.artifact import ArtifactType
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.datasets import load_iris
from IPython.display import display, HTML

In [10]:
# Load the iris dataset as one DataFrame: the feature columns (named after
# the dataset's feature_names) followed by the class label in 'target'.
iris = load_iris()
feature_names = iris.feature_names
iris_df = (
    pd.DataFrame(data=iris.data, columns=feature_names)
    .assign(target=iris.target)
)
iris_df

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,2
146,6.3,2.5,5.0,1.9,2
147,6.5,3.0,5.2,2.0,2
148,6.2,3.4,5.4,2.3,2


In [11]:
# Train / test split.
# A fixed random_state makes the split, and therefore every metric logged
# below, reproducible across Restart & Run All. Stratifying on the label keeps
# the three classes in the same proportions in both splits, which matters on a
# dataset this small.
SPLIT_SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    iris_df.drop(columns='target'),
    iris_df['target'],
    random_state=SPLIT_SEED,
    stratify=iris_df['target'],
)


def fit_predict(model, X_fit, y_fit, X_eval):
    """Fit ``model`` on (X_fit, y_fit) and return ``(model, model.predict(X_eval))``."""
    model.fit(X_fit, y_fit)
    return model, model.predict(X_eval)


# Let's try some different classifiers (each with its own fixed seed).
rf, rf_pred = fit_predict(
    RandomForestClassifier(random_state=666), X_train, y_train, X_test
)
gbt, gbt_pred = fit_predict(
    GradientBoostingClassifier(random_state=420), X_train, y_train, X_test
)
dt, dt_pred = fit_predict(
    DecisionTreeClassifier(max_depth=3, random_state=69), X_train, y_train, X_test
)

In [12]:
from matplotlib import pyplot as plt
from numpy import mean
from sklearn.metrics import confusion_matrix
import seaborn as sns

def evaluate_and_log_run(model, y_test, y_pred, variant_id):
    """Score one fitted classifier and package the results as an ExperimentRun.

    Args:
        model: fitted scikit-learn estimator. Its ``get_params()`` are logged
            as run params, and the pickled estimator is logged as the
            "model" artifact.
        y_test: true labels for the held-out rows.
        y_pred: this model's predictions for the same rows as ``y_test``.
        variant_id: label for this model within the "iris" experiment; also
            used as the confusion-matrix title.

    Returns:
        The populated ``erm.ExperimentRun``.
    """
    er = erm.ExperimentRun(
        experiment_id="iris",
        variant_id=variant_id,
        params=model.get_params(),
        metrics={
            "accuracy": accuracy_score(y_test, y_pred),
            # 'weighted' averages the per-class scores by class support,
            # giving one number for this multi-class problem.
            "precision": precision_score(y_test, y_pred, average='weighted'),
            "recall": recall_score(y_test, y_pred, average='weighted'),
            "f1": f1_score(y_test, y_pred, average='weighted'),
        },
    )

    er.log_dict("custom_dict", data={
        "avg_pred_value": mean(y_pred)
    })

    er.log_artifact(
        pickle.dumps(model),
        artifact_id="model",
        filename="model.pickle",
    )

    # Reuse the current pyplot figure, cleared, so the previous run's heatmap
    # (and its colorbar) does not bleed into this one.
    plt.clf()
    # Must be built from this run's y_pred, not a global such as dt_pred;
    # otherwise every run logs the same confusion matrix.
    ax = sns.heatmap(
        confusion_matrix(y_test, y_pred),
        annot=True,
        cmap="Blues",
    )
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(variant_id)
    # sns.heatmap returns a matplotlib Axes, which is what gets logged here.
    er.log_figure(ax, "confusion_matrix")

    return er

# One ExperimentRun per classifier.
# NOTE(review): in the saved output all three runs received the same run id
# (a second-resolution timestamp). If these runs are persisted, confirm that
# variant_id keeps them from overwriting each other.
er_rf = evaluate_and_log_run(rf, y_test, rf_pred, "random_forest")
er_gbt = evaluate_and_log_run(gbt, y_test, gbt_pred, "gradient_boosting")
er_dt = evaluate_and_log_run(dt, y_test, dt_pred, "decision_tree")
# Close the scratch figure rather than just clearing it. A cleared-but-open
# figure is still rendered at the end of the cell as an empty
# "<Figure ... with 0 Axes>" output.
plt.close("all")

<Figure size 640x480 with 0 Axes>

In [13]:
# Side-by-side HTML comparison of the three runs (metadata, params, metrics,
# custom dicts).
comparison_html = erm.compare_runs(er_rf, er_gbt, er_dt)
display(HTML(comparison_html))

Unnamed: 0,Run 1,Run 2,Run 3
Experiment id,iris,iris,iris
Run id,2023_04_28__17_35_03,2023_04_28__17_35_03,2023_04_28__17_35_03
Timestamp (UTC),2023-04-28 17:35:03,2023-04-28 17:35:03,2023-04-28 17:35:03
Variant id,random_forest,gradient_boosting,decision_tree

Unnamed: 0,Run 1,Run 2,Run 3
bootstrap,True,,
ccp_alpha,0.0,0.0,0.0
class_weight,,,
criterion,gini,friedman_mse,gini
init,,,
learning_rate,,0.1,
loss,,log_loss,
max_depth,,3,3
max_features,sqrt,,
max_leaf_nodes,,,

Unnamed: 0,Run 1,Run 2,Run 3
accuracy,0.9473684210526316,0.9473684210526316,0.9473684210526316
f1,0.9470551378446114,0.9470551378446114,0.9470551378446114
precision,0.9543859649122808,0.9543859649122808,0.9543859649122808
recall,0.9473684210526316,0.9473684210526316,0.9473684210526316

Unnamed: 0,Run 1,Run 2,Run 3
avg_pred_value,1.0789473684210529,1.0789473684210529,1.0789473684210529
