Feature/calibration (#6)
* Added first draft of calibration support

* Added calibration functions and plots and made some evaluation code more generic.

* Tidied up some code and made a more generic get_metrics function instead of duplicating code.

* Clean up formatting

* Modified the label/legend information for some calibration plots

* Allow the location of the config.yaml to be specified in an environment variable

* Allow any prefix for metrics

* Make if statement easier to read

Co-authored-by: Thijs Vereijken <thijs.vereijken@eneco.com>
Co-authored-by: Erik Jan de Vries <erikjandevries@users.noreply.github.com>
3 people committed Nov 5, 2020
1 parent c2a8ea7 commit 148b147
Showing 6 changed files with 177 additions and 35 deletions.
109 changes: 82 additions & 27 deletions src/myautoml/evaluation/binary_classifier.py
@@ -5,26 +5,83 @@
from sklearn.model_selection import cross_validate

from myautoml.visualisation.evaluation.binary_classifier import (
save_roc_curve, save_cum_precision, save_prediction_distribution, save_lift_deciles, save_precision_recall_curve)
save_roc_curve, save_cum_precision, save_prediction_distribution, save_lift_deciles, save_precision_recall_curve,
save_calibration_curve, save_calibration_curve_zoom)

_logger = logging.getLogger(__name__)


def evaluate_binary_classifier(model, data, temp_dir, plots='all'):
_logger.debug(f"Starting evaluation for binary classifier")
def get_metrics(model, data, prefix=None):
_logger.debug(f"Starting computing the metrics")
if prefix is None:
prefix = ''
elif len(prefix) > 0 and (not prefix[-1:] == "_"):
prefix = prefix + "_"

metrics = {}
for label in data.keys():
x = data[label]['x']
y_true = data[label]['y']
y_pred = data[label]['y_pred'] = model.predict(x)
y_pred_proba = data[label]['y_pred_proba'] = model.predict_proba(x)[:, 1]

metrics[f"roc_auc_{label}"] = roc_auc_score(y_true, y_pred_proba)
metrics[f"average_precision_{label}"] = average_precision_score(y_true, y_pred_proba)
metrics[f"accuracy_{label}"] = accuracy_score(y_true, y_pred)
metrics[f"f1_{label}"] = f1_score(y_true, y_pred)
metrics[f"precision_{label}"] = precision_score(y_true, y_pred)
metrics[f"recall_{label}"] = recall_score(y_true, y_pred)
metrics[f"{prefix}roc_auc_{label}"] = roc_auc_score(y_true, y_pred_proba)
metrics[f"{prefix}average_precision_{label}"] = average_precision_score(y_true, y_pred_proba)
metrics[f"{prefix}accuracy_{label}"] = accuracy_score(y_true, y_pred)
metrics[f"{prefix}f1_{label}"] = f1_score(y_true, y_pred)
metrics[f"{prefix}precision_{label}"] = precision_score(y_true, y_pred)
metrics[f"{prefix}recall_{label}"] = recall_score(y_true, y_pred)

return metrics


def get_plots(temp_dir, data, plots, plot_path='evaluation'):
artifacts = {}

# Standard evaluation plots
if 'roc' in plots:
roc_curve_path = save_roc_curve(temp_dir, data)
if roc_curve_path:
artifacts[roc_curve_path] = plot_path

if 'pr' in plots:
pr_curve_path = save_precision_recall_curve(temp_dir, data)
if pr_curve_path:
artifacts[pr_curve_path] = plot_path

if 'lift_deciles' in plots:
lift_deciles_path = save_lift_deciles(temp_dir, data)
if lift_deciles_path:
artifacts[lift_deciles_path] = plot_path

if 'cum_precision' in plots:
cum_precision_path = save_cum_precision(temp_dir, data)
if cum_precision_path:
artifacts[cum_precision_path] = plot_path

if 'distribution' in plots:
distribution_path = save_prediction_distribution(temp_dir, data)
if distribution_path:
artifacts[distribution_path] = plot_path

# Calibration plots
if 'curve' in plots:
calibration_curve_path = save_calibration_curve(temp_dir, data)
if calibration_curve_path:
artifacts[calibration_curve_path] = plot_path

if 'curve_zoom' in plots:
calibration_curve_zoom_path = save_calibration_curve_zoom(temp_dir, data)
if calibration_curve_zoom_path:
artifacts[calibration_curve_zoom_path] = plot_path

return artifacts


def evaluate_binary_classifier(model, data, temp_dir, plots='all'):
_logger.debug(f"Starting evaluation for binary classifier")

metrics = get_metrics(model, data)

_logger.debug(f"Starting cross-validation for binary classifier")
scorers = ['roc_auc', 'accuracy', 'f1', 'average_precision', 'precision', 'recall']
@@ -33,30 +90,28 @@ def evaluate_binary_classifier(model, data, temp_dir, plots='all'):
for scorer in scorers:
metrics[f"{scorer}_cv"] = cv_results[f"test_{scorer}"].mean()

artifacts = {}

if not (plots is None or plots == ""):
if plots is None or plots == "":
artifacts = {}
else:
if plots == 'all':
plots = ['roc', 'pr', 'lift_deciles', 'cum_precision', 'distribution']

if 'roc' in plots:
roc_curve_path = save_roc_curve(temp_dir, data)
artifacts[roc_curve_path] = 'evaluation'
artifacts = get_plots(temp_dir, data, plots, plot_path='evaluation')

return metrics, artifacts


if 'pr' in plots:
pr_curve_path = save_precision_recall_curve(temp_dir, data)
artifacts[pr_curve_path] = 'evaluation'
def evaluate_calibration(model, data, temp_dir, plots='all'):
_logger.debug(f"Starting evaluation calibration for binary classifier")

if 'lift_deciles' in plots:
lift_deciles_path = save_lift_deciles(temp_dir, data)
artifacts[lift_deciles_path] = 'evaluation'
metrics = get_metrics(model, data, prefix='calibration')

if 'cum_precision' in plots:
cum_precision_path = save_cum_precision(temp_dir, data)
artifacts[cum_precision_path] = 'evaluation'
if plots is None or plots == "":
artifacts = {}
else:
if plots == 'all':
plots = ['curve', 'curve_zoom', 'roc', 'pr', 'lift_deciles', 'cum_precision', 'distribution']

if 'distribution' in plots:
distribution_path = save_prediction_distribution(temp_dir, data)
artifacts[distribution_path] = 'evaluation'
artifacts = get_plots(temp_dir, data, plots, plot_path='evaluation_calibration')

return metrics, artifacts
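
For context, a minimal usage sketch of the refactored evaluation helpers above. It assumes a fitted scikit-learn classifier and a data dictionary shaped like the one iterated over in get_metrics (one entry per split, each with 'x' and 'y'); the model, dataset, and temporary directory below are placeholders, not part of this commit.

```python
# Minimal sketch (placeholder model and data): evaluate a fitted binary classifier
# with the refactored helpers introduced in this commit.
from tempfile import TemporaryDirectory

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from myautoml.evaluation.binary_classifier import evaluate_binary_classifier, get_metrics

x, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)
model = RandomForestClassifier(random_state=0).fit(x_train, y_train)

# Data layout expected by get_metrics/get_plots: one dict per split with 'x' and 'y'.
data = {
    'train': {'x': x_train, 'y': y_train},
    'test': {'x': x_test, 'y': y_test},
}

with TemporaryDirectory() as temp_dir:
    metrics, artifacts = evaluate_binary_classifier(model, data, temp_dir, plots='all')

# The new prefix argument prepends e.g. "calibration_" to every metric name.
calibration_metrics = get_metrics(model, data, prefix='calibration')
```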
3 changes: 3 additions & 0 deletions src/myautoml/utils/__init__.py
@@ -1,5 +1,6 @@
from collections.abc import Mapping
import logging.config
import os
from pathlib import Path

from box import Box
@@ -43,6 +44,8 @@ def start_script(dotenv='.env.general',
include_environment=True,
include_default_config=True):
load_dotenv(dotenv)
if Path(config_yaml).suffix not in [".yml", ".yaml"]:
config_yaml = os.getenv(config_yaml)
config = load_config(config_yaml, include_environment=include_environment, include_default_config=include_default_config)
Path(config.logging.handlers.debug_file_handler.filename).parent.mkdir(parents=True, exist_ok=True)
Path(config.logging.handlers.info_file_handler.filename).parent.mkdir(parents=True, exist_ok=True)
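To illustrate the new lookup in start_script: if the value passed for the config location does not end in .yml or .yaml, it is treated as the name of an environment variable holding the path. A minimal sketch, assuming config_yaml is one of the keyword arguments elided from the signature above; the variable name and path are placeholders:

```python
# Sketch (hypothetical variable name and placeholder path): point start_script at a
# config file via an environment variable instead of a literal *.yaml path.
import os

from myautoml.utils import start_script

os.environ["MYAUTOML_CONFIG"] = "/path/to/config.yaml"   # placeholder path

# "MYAUTOML_CONFIG" has no .yml/.yaml suffix, so it is resolved with os.getenv(...)
start_script(config_yaml="MYAUTOML_CONFIG")
```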
2 changes: 1 addition & 1 deletion src/myautoml/utils/mlflow/__init__.py
@@ -1,2 +1,2 @@
from .tracking import track_model, track_model_from_file, track_model_data, log_sk_model
from .tracking import get_model, track_model, track_model_from_file, track_model_data, log_sk_model
from .models import get_registered_model, register_model
5 changes: 3 additions & 2 deletions src/myautoml/utils/mlflow/tracking.py
@@ -82,10 +82,11 @@ def track_model_data(run_id: str,
client.log_artifacts(run_id, local_dir, artifact_path)


def log_sk_model(sk_model, registered_model_name=None, params=None, tags=None, metrics=None, artifacts=None):
def log_sk_model(sk_model, registered_model_name=None, params=None, tags=None, metrics=None, artifacts=None,
artifact_path='model'):
_logger.info("Logging Scikit-Learn model to MLflow")
mlflow.sklearn.log_model(sk_model=sk_model,
artifact_path='model',
artifact_path=artifact_path,
conda_env='./environment.yml',
registered_model_name=registered_model_name)
mlflow.log_params(params)
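A brief sketch of the widened log_sk_model call: the model, params, tags, and metrics below are placeholders, and an environment.yml is assumed to exist in the working directory because of the hard-coded conda_env in the function.

```python
# Sketch (placeholder values): log a fitted scikit-learn model under a custom
# artifact path instead of the previously hard-coded 'model'.
import mlflow
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from myautoml.utils.mlflow import log_sk_model

x, y = make_classification(random_state=0)
sk_model = LogisticRegression().fit(x, y)

with mlflow.start_run():
    log_sk_model(sk_model,
                 params={'C': 1.0},
                 tags={'stage': 'calibration'},
                 metrics={'roc_auc_test': 0.87},
                 artifacts={},
                 artifact_path='calibrated_model')   # new keyword added in this commit
```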
42 changes: 42 additions & 0 deletions src/myautoml/visualisation/evaluation/__init__.py
@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve
from sklearn.calibration import calibration_curve

from myautoml.visualisation.colors import TEST_COLOR, BASELINE_COLOR

@@ -134,3 +135,44 @@ def plot_prediction_distribution(ax, y_pred_proba, *args, **kwargs):
ax.hist(y_pred_proba, bins=20, density=True, *args, **kwargs)
ax.legend(loc='upper right')
return ax


def plot_calibration_curve(ax, y_true, y_pred_proba, label=None, color=TEST_COLOR, legend_loc='best',
strategy='uniform', max_val=1):
fraction_of_positives, mean_predicted_value = calibration_curve(y_true, y_pred_proba, n_bins=20,
strategy=strategy)

ax.set_title('Calibration plots (reliability curve)')
ax.set_ylabel("Fraction of positives")
ax.set_xlabel("Predicted value")
ax.set_ylim([-0.05 * max_val, 1.05 * max_val])

ax.plot(mean_predicted_value, fraction_of_positives, label=label,
marker='+', markeredgecolor='black', color=color)
# Line for the calibration reference
ax.plot([0, max_val], [0, max_val], linestyle='dotted', color=BASELINE_COLOR, label='Perfectly calibrated')
ax.legend(loc=legend_loc)

return ax


def plot_calibration_curve_zoom(ax, y_true, y_pred_proba, label=None, color=TEST_COLOR, legend_loc='best',
strategy='quantile', max_val=None):
fraction_of_positives, mean_predicted_value = calibration_curve(y_true, y_pred_proba, n_bins=20,
strategy=strategy)

if not max_val:
max_val = max(y_pred_proba)

ax.set_title('Calibration plots (reliability curve)')
ax.set_ylabel("Fraction of positives")
ax.set_xlabel("Predicted value")
ax.set_ylim([-0.05 * max_val, 1.05 * max_val])

ax.plot(mean_predicted_value, fraction_of_positives, label=label,
marker='+', markeredgecolor='black', color=color)
# Line for the calibration reference
ax.plot([0, max_val], [0, max_val], linestyle='dotted', color=BASELINE_COLOR, label='Perfectly calibrated')
ax.legend(loc=legend_loc)

return ax
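
The new plotting helpers take a Matplotlib Axes plus the true labels and predicted probabilities; a minimal sketch with synthetic data (the model and dataset are assumptions for illustration only):

```python
# Minimal sketch (synthetic data): draw the reliability curve added above.
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from myautoml.visualisation.evaluation import plot_calibration_curve

x, y = make_classification(n_samples=2000, random_state=0)
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)
y_pred_proba = LogisticRegression().fit(x_train, y_train).predict_proba(x_test)[:, 1]

fig, ax = plt.subplots()
plot_calibration_curve(ax, y_test, y_pred_proba, label='test')
fig.savefig('calibration_curve.png')
plt.close(fig)
```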
51 changes: 46 additions & 5 deletions src/myautoml/visualisation/evaluation/binary_classifier.py
@@ -5,7 +5,8 @@

from myautoml.visualisation.colors import EVALUATION_COLORS, TRAIN_COLOR, TEST_COLOR

from . import plot_roc, plot_cum_precision, plot_lift_deciles, plot_precision_recall, plot_prediction_distribution
from . import plot_roc, plot_cum_precision, plot_lift_deciles, plot_precision_recall, plot_prediction_distribution, \
plot_calibration_curve, plot_calibration_curve_zoom

_logger = logging.getLogger(__name__)

@@ -82,14 +83,54 @@ def save_prediction_distribution(save_dir, data):
save_path = Path(save_dir) / 'prediction_distribution.png'
fig, ax = plt.subplots()
try:
plot_prediction_distribution(ax, (data['train']['y_pred_proba'], data['test']['y_pred_proba']),
label=('train', 'test'),
color=(TRAIN_COLOR, TEST_COLOR))

if 'train' in data.keys():
plot_prediction_distribution(ax, (data['train']['y_pred_proba'], data['test']['y_pred_proba']),
label=('train', 'test'),
color=(TRAIN_COLOR, TEST_COLOR))
else:
plot_prediction_distribution(ax, data['test']['y_pred_proba'],
label='test',
color=TEST_COLOR)
fig.savefig(save_path)
except Exception as e:
_logger.warning(f"Error plotting the prediction distribution: {str(e)}")
save_path = None
finally:
plt.close(fig)
return save_path


def save_calibration_curve(save_dir, data):
_logger.debug("Plotting the calibration curve")
save_path = Path(save_dir) / 'calibration_curve.png'
fig, ax = plt.subplots()
try:
plot_calibration_curve(ax, data['test']['y'], data['test']['y_pred_proba'],
label='test',
color=TEST_COLOR)

fig.savefig(save_path)
except Exception as e:
_logger.warning(f"Error plotting the calibration curve: {str(e)}")
save_path = None
finally:
plt.close(fig)
return save_path


def save_calibration_curve_zoom(save_dir, data):
_logger.debug("Plotting the calibration curve zoom")
save_path = Path(save_dir) / 'calibration_curve_zoom.png'
fig, ax = plt.subplots()
try:
plot_calibration_curve_zoom(ax, data['test']['y'], data['test']['y_pred_proba'],
label='test',
color=TEST_COLOR)

fig.savefig(save_path)
except Exception as e:
_logger.warning(f"Error plotting the calibration curve zoom: {str(e)}")
save_path = None
finally:
plt.close(fig)
return save_path
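
The save_* helpers above take a directory and a data dict whose 'test' entry already holds 'y' and 'y_pred_proba', and return the saved figure path (or None if plotting failed). A short sketch with synthetic, well-calibrated scores (the data is an assumption for illustration):

```python
# Sketch (synthetic data, assumed layout): save the calibration plots to a directory.
from tempfile import TemporaryDirectory

import numpy as np

from myautoml.visualisation.evaluation.binary_classifier import (
    save_calibration_curve, save_calibration_curve_zoom)

rng = np.random.default_rng(0)
y_pred_proba = rng.uniform(size=1000)
y = rng.binomial(1, y_pred_proba)          # labels drawn from the predicted probabilities
data = {'test': {'y': y, 'y_pred_proba': y_pred_proba}}

with TemporaryDirectory() as save_dir:
    curve_path = save_calibration_curve(save_dir, data)
    zoom_path = save_calibration_curve_zoom(save_dir, data)
```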
