In [1]:
import yaml
import json
import pandas as pd
from mlflow import start_run, log_metric, log_param, set_experiment
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Load the run configuration (data paths, MLflow experiment/run names, metrics
# output path) from params.yaml; safe_load avoids constructing arbitrary objects.
with open("params.yaml") as params_file:
    config = yaml.safe_load(params_file)

In [2]:
# Training-split ground truth and model predictions, read in that order from the
# CSV paths configured in params.yaml.
train_labels, train_predictions = (
    pd.read_csv(config[path_key])
    for path_key in ("train_y_dataset_path", "train_predictions_path")
)

In [3]:
# Validation-split ground truth and model predictions, read in that order from
# the CSV paths configured in params.yaml.
validation_labels, validation_predictions = (
    pd.read_csv(config[path_key])
    for path_key in ("validation_y_dataset_path", "validation_predictions_path")
)

In [4]:
# Log accuracy / recall / precision / F1 for both splits to a named MLflow run.
# NOTE(review): recall/precision/f1 use sklearn's default average="binary",
# which presumes binary labels with positive class 1 — confirm for this dataset.
set_experiment(config["experiment_name"])

with start_run(run_name=config["run_name"]):
    # (metric key, scorer, y_true, y_pred) — listed in the exact order logged.
    metric_specs = (
        ("train accuracy", accuracy_score, train_labels, train_predictions),
        ("train recall", recall_score, train_labels, train_predictions),
        ("train precision", precision_score, train_labels, train_predictions),
        ("train f1_score", f1_score, train_labels, train_predictions),
        ("validation accuracy", accuracy_score, validation_labels, validation_predictions),
        ("validation recall", recall_score, validation_labels, validation_predictions),
        ("validation precision", precision_score, validation_labels, validation_predictions),
        ("validation f1_score", f1_score, validation_labels, validation_predictions),
    )
    for metric_key, scorer, y_true, y_pred in metric_specs:
        log_metric(metric_key, scorer(y_true, y_pred))

In [5]:
# Persist the train/validation classification metrics as a JSON object at the
# configured metrics path.
#
# All metrics are computed BEFORE the file is opened: open(..., "w") truncates
# the file immediately, so a scoring error (e.g. mismatched label/prediction
# lengths) would otherwise leave an empty or clobbered metrics file behind.
metrics = {
    "train accuracy": accuracy_score(train_labels, train_predictions),
    "train recall": recall_score(train_labels, train_predictions),
    "train precision": precision_score(train_labels, train_predictions),
    "train f1_score": f1_score(train_labels, train_predictions),

    "validation accuracy": accuracy_score(validation_labels, validation_predictions),
    "validation recall": recall_score(validation_labels, validation_predictions),
    "validation precision": precision_score(validation_labels, validation_predictions),
    "validation f1_score": f1_score(validation_labels, validation_predictions)
}

with open(config["metrics_path"], 'w') as metrics_file:
    # json.dump with default arguments emits the same text as json.dumps + write.
    json.dump(metrics, metrics_file)