From 2ffe87c71ebbde66f058c758da39258cc59fafc6 Mon Sep 17 00:00:00 2001 From: Richard Edgar Date: Tue, 4 Feb 2020 20:55:24 -0500 Subject: [PATCH] Use kwargs in metrics (#286) Change `metric_by_group` and `make_group_metric` to understand `**kwargs`. This removes the need for lots of small wrapper functions Signed-off-by: Richard Edgar --- fairlearn/metrics/_metrics_engine.py | 32 ++++-- fairlearn/metrics/_skm_wrappers.py | 163 ++++----------------------- 2 files changed, 42 insertions(+), 153 deletions(-) diff --git a/fairlearn/metrics/_metrics_engine.py b/fairlearn/metrics/_metrics_engine.py index 49fc831c0..3fe6273fd 100644 --- a/fairlearn/metrics/_metrics_engine.py +++ b/fairlearn/metrics/_metrics_engine.py @@ -9,11 +9,13 @@ _MESSAGE_SIZE_MISMATCH = "Array {0} is not the same size as {1}" -def metric_by_group(metric_function, y_true, y_pred, group_membership, sample_weight=None): - """Apply a metric to each subgroup of a set of data. +def metric_by_group(metric_function, + y_true, y_pred, group_membership, + sample_weight=None, + **kwargs): + r"""Apply a metric to each subgroup of a set of data. - :param metric_function: Function with signature ``(y_true, y_pred, sample_weight=None)`` - which returns a scalar + :param metric_function: Function ``(y_true, y_pred, sample_weight=None, \*\*kwargs)`` :param y_true: Array of ground-truth values @@ -23,6 +25,8 @@ def metric_by_group(metric_function, y_true, y_pred, group_membership, sample_we :param sample_weight: Optional weights to apply to each input value + :param \*\*kwargs: Optional arguments to be passed to the `metric_function` + :return: Object containing the result of applying ``metric_function`` to the entire dataset and to each group identified in ``group_membership``. If the ``metric_function`` returns a scalar, then additional fields are populated @@ -48,9 +52,9 @@ def metric_by_group(metric_function, y_true, y_pred, group_membership, sample_we # Evaluate the overall metric with the numpy arrays # This ensures consistency in how metric_function is called if s_w is not None: - result.overall = metric_function(y_a, y_p, sample_weight=s_w) + result.overall = metric_function(y_a, y_p, sample_weight=s_w, **kwargs) else: - result.overall = metric_function(y_a, y_p) + result.overall = metric_function(y_a, y_p, **kwargs) groups = np.unique(group_membership) for group in groups: @@ -62,9 +66,12 @@ def metric_by_group(metric_function, y_true, y_pred, group_membership, sample_we group_weight = s_w[group_indices] result.by_group[group] = metric_function(group_actual, group_predict, - sample_weight=group_weight) + sample_weight=group_weight, + **kwargs) else: - result.by_group[group] = metric_function(group_actual, group_predict) + result.by_group[group] = metric_function(group_actual, + group_predict, + **kwargs) return result @@ -73,19 +80,20 @@ def make_group_metric(metric_function): """Turn a regular metric into a grouped metric. :param metric_function: The function to be wrapped. This must have signature - ``(y_true, y_pred, sample_weight)`` + ``(y_true, y_pred, sample_weight, **kwargs)`` :type metric_function: func :return: A wrapped version of the supplied metric_function. It will have - signature ``(y_true, y_pred, group_membership, sample_weight)`` + signature ``(y_true, y_pred, group_membership, sample_weight, **kwargs)`` :rtype: func """ - def wrapper(y_true, y_pred, group_membership, sample_weight=None): + def wrapper(y_true, y_pred, group_membership, sample_weight=None, **kwargs): return metric_by_group(metric_function, y_true, y_pred, group_membership, - sample_weight) + sample_weight, + **kwargs) # Improve the name of the returned function wrapper.__name__ = "group_{0}".format(metric_function.__name__) diff --git a/fairlearn/metrics/_skm_wrappers.py b/fairlearn/metrics/_skm_wrappers.py index 2ab2d192f..623d65083 100644 --- a/fairlearn/metrics/_skm_wrappers.py +++ b/fairlearn/metrics/_skm_wrappers.py @@ -6,135 +6,31 @@ from ._metrics_engine import make_group_metric, metric_by_group -def group_accuracy_score(y_true, y_pred, group_membership, *, - normalize=True, - sample_weight=None): - """Wrap the :py:func:`sklearn.metrics.accuracy_score` routine. +group_accuracy_score = make_group_metric(skm.accuracy_score) +"""A grouped wrapper around the :py:func:`sklearn.metrics.accuracy_score` routine.""" - The arguments remain the same, with `group_membership` added. - However, the only positional arguments supported are `y_true`, - `y_pred` and `group_membership`. - All others must be specified by name. - """ - def internal_acc_wrapper(y_true, y_pred, sample_weight=None): - return skm.accuracy_score(y_true, y_pred, - normalize, - sample_weight) +group_confusion_matrix = make_group_metric(skm.confusion_matrix) +"""A grouped wrapper around the :py:func:`sklearn.metrics.confusion_matrix` routine.""" - return metric_by_group(internal_acc_wrapper, y_true, y_pred, group_membership, sample_weight) - - -def group_confusion_matrix(y_true, y_pred, group_membership, *, - labels=None, - sample_weight=None): - """Wrap the :py:func:`sklearn.metrics.confusion_matrix` routine. - - The arguments remain the same, with `group_membership` added. - However, the only positional arguments supported are `y_true`, - `y_pred` and `group_membership`. - All others must be specified by name. - """ - def internal_cm_wrapper(y_true, y_pred, sample_weight=None): - return skm.confusion_matrix(y_true, y_pred, - labels, - sample_weight) - - return metric_by_group(internal_cm_wrapper, y_true, y_pred, group_membership, sample_weight) - - -def group_precision_score(y_true, y_pred, group_membership, *, - labels=None, pos_label=1, average='binary', - sample_weight=None): - """Wrap the :py:func:`sklearn.metrics.precision_score` routine. - - The arguments remain the same, with `group_membership` added. - However, the only positional arguments supported are `y_true`, - `y_pred` and `group_membership`. - All others must be specified by name. - """ - def internal_prec_wrapper(y_true, y_pred, sample_weight=None): - return skm.precision_score(y_true, y_pred, - labels=labels, pos_label=pos_label, average=average, - sample_weight=sample_weight) - - return metric_by_group(internal_prec_wrapper, y_true, y_pred, group_membership, sample_weight) - - -def group_recall_score(y_true, y_pred, group_membership, *, - labels=None, pos_label=1, average='binary', - sample_weight=None): - """Wrap the :py:func:`sklearn.metrics.recall_score` routine. - - The arguments remain the same, with `group_membership` added. - However, the only positional arguments supported are `y_true`, - `y_pred` and `group_membership`. - All others must be specified by name. - """ - def internal_recall_wrapper(y_true, y_pred, sample_weight=None): - return skm.recall_score(y_true, y_pred, - labels=labels, pos_label=pos_label, average=average, - sample_weight=sample_weight) - - return metric_by_group(internal_recall_wrapper, - y_true, y_pred, group_membership, sample_weight) - - -def group_roc_auc_score(y_true, y_pred, group_membership, *, - average='macro', max_fpr=None, - sample_weight=None): - """Wrap the :py:func:`sklearn.metrics.roc_auc_score` routine. - - The arguments remain the same, with `group_membership` added. - However, the only positional arguments supported are `y_true`, - `y_pred` and `group_membership`. - All others must be specified by name. - """ - def internal_ras_wrapper(y_true, y_pred, sample_weight=None): - return skm.roc_auc_score(y_true, y_pred, - average=average, max_fpr=max_fpr, - sample_weight=sample_weight) - - return metric_by_group(internal_ras_wrapper, - y_true, y_pred, group_membership, sample_weight) - - -def group_zero_one_loss(y_true, y_pred, group_membership, *, - normalize=True, - sample_weight=None): - """Wrap the :py:func:`sklearn.metrics.zero_one_loss` routine. - - The arguments remain the same, with `group_membership` added. - However, the only positional arguments supported are `y_true`, - `y_pred` and `group_membership`. - All others must be specified by name. - """ - def internal_zol_wrapper(y_true, y_pred, sample_weight=None): - return skm.zero_one_loss(y_true, y_pred, - normalize=normalize, - sample_weight=sample_weight) - - return metric_by_group(internal_zol_wrapper, y_true, y_pred, group_membership, sample_weight) - -# -------------------------------------------------------------------------------------- +group_precision_score = make_group_metric(skm.precision_score) +"""A grouped wrapper around the :py:func:`sklearn.metrics.precision_score` routine +""" +group_recall_score = make_group_metric(skm.recall_score) +"""A grouped wrapper around the :py:func:`sklearn.metrics.recall_score` routine +""" -def group_mean_squared_error(y_true, y_pred, group_membership, *, - multioutput='uniform_average', - sample_weight=None): - """Wrap the :py:func:`sklearn.metrics.mean_squared_error` routine. +group_roc_auc_score = make_group_metric(skm.roc_auc_score) +"""A grouped wrapper around the :py:func:`sklearn.metrics.roc_auc_score` routine +""" - The arguments remain the same, with `group_membership` added. - However, the only positional arguments supported are `y_true`, - `y_pred` and `group_membership`. - All others must be specified by name. - """ - def internal_mse_wrapper(y_true, y_pred, sample_weight=None): - return skm.mean_squared_error(y_true, y_pred, - multioutput=multioutput, - sample_weight=sample_weight) +group_zero_one_loss = make_group_metric(skm.zero_one_loss) +"""A grouped wrapper around the :py:func:`sklearn.metrics.zero_one_loss` routine +""" - return metric_by_group(internal_mse_wrapper, - y_true, y_pred, group_membership, sample_weight=sample_weight) +group_mean_squared_error = make_group_metric(skm.mean_squared_error) +"""A grouped wrapper around the :py:func:`sklearn.metrics.mean_squared_error` routine +""" def group_root_mean_squared_error(y_true, y_pred, group_membership, *, @@ -157,24 +53,9 @@ def internal_rmse_wrapper(y_true, y_pred, sample_weight=None): y_true, y_pred, group_membership, sample_weight=sample_weight) -def group_r2_score(y_true, y_pred, group_membership, *, - multioutput='uniform_average', - sample_weight=None): - """Wrap the :py:func:`sklearn.metrics.r2_score` routine. - - The arguments remain the same, with `group_membership` added. - However, the only positional arguments supported are `y_true`, - `y_pred` and `group_membership`. - All others must be specified by name. - """ - def internal_r2_wrapper(y_true, y_pred, sample_weight=None): - return skm.r2_score(y_true, y_pred, - multioutput=multioutput, - sample_weight=sample_weight) - - return metric_by_group(internal_r2_wrapper, - y_true, y_pred, group_membership, sample_weight=sample_weight) - +group_r2_score = make_group_metric(skm.r2_score) +"""A grouped wrapper around the :py:func:`sklearn.metrics.r2_score` routine +""" group_max_error = make_group_metric(skm.max_error) """A grouped wrapper around the :py:func:`sklearn.metrics.max_error` routine