Skip to content

Commit

Permalink
Use kwargs in metrics (#286)
Browse files Browse the repository at this point in the history
Change `metric_by_group` and `make_group_metric` to understand `**kwargs`. This removes the need for lots of small wrapper functions

Signed-off-by: Richard Edgar <riedgar@microsoft.com>
  • Loading branch information
riedgar-ms committed Feb 5, 2020
1 parent eccb6d0 commit 2ffe87c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 153 deletions.
32 changes: 20 additions & 12 deletions fairlearn/metrics/_metrics_engine.py
Expand Up @@ -9,11 +9,13 @@
_MESSAGE_SIZE_MISMATCH = "Array {0} is not the same size as {1}"


def metric_by_group(metric_function, y_true, y_pred, group_membership, sample_weight=None):
"""Apply a metric to each subgroup of a set of data.
def metric_by_group(metric_function,
y_true, y_pred, group_membership,
sample_weight=None,
**kwargs):
r"""Apply a metric to each subgroup of a set of data.
:param metric_function: Function with signature ``(y_true, y_pred, sample_weight=None)``
which returns a scalar
:param metric_function: Function ``(y_true, y_pred, sample_weight=None, \*\*kwargs)``
:param y_true: Array of ground-truth values
Expand All @@ -23,6 +25,8 @@ def metric_by_group(metric_function, y_true, y_pred, group_membership, sample_we
:param sample_weight: Optional weights to apply to each input value
:param \*\*kwargs: Optional arguments to be passed to the `metric_function`
:return: Object containing the result of applying ``metric_function`` to the entire dataset
and to each group identified in ``group_membership``.
If the ``metric_function`` returns a scalar, then additional fields are populated
Expand All @@ -48,9 +52,9 @@ def metric_by_group(metric_function, y_true, y_pred, group_membership, sample_we
# Evaluate the overall metric with the numpy arrays
# This ensures consistency in how metric_function is called
if s_w is not None:
result.overall = metric_function(y_a, y_p, sample_weight=s_w)
result.overall = metric_function(y_a, y_p, sample_weight=s_w, **kwargs)
else:
result.overall = metric_function(y_a, y_p)
result.overall = metric_function(y_a, y_p, **kwargs)

groups = np.unique(group_membership)
for group in groups:
Expand All @@ -62,9 +66,12 @@ def metric_by_group(metric_function, y_true, y_pred, group_membership, sample_we
group_weight = s_w[group_indices]
result.by_group[group] = metric_function(group_actual,
group_predict,
sample_weight=group_weight)
sample_weight=group_weight,
**kwargs)
else:
result.by_group[group] = metric_function(group_actual, group_predict)
result.by_group[group] = metric_function(group_actual,
group_predict,
**kwargs)

return result

Expand All @@ -73,19 +80,20 @@ def make_group_metric(metric_function):
"""Turn a regular metric into a grouped metric.
:param metric_function: The function to be wrapped. This must have signature
``(y_true, y_pred, sample_weight)``
``(y_true, y_pred, sample_weight, **kwargs)``
:type metric_function: func
:return: A wrapped version of the supplied metric_function. It will have
signature ``(y_true, y_pred, group_membership, sample_weight)``
signature ``(y_true, y_pred, group_membership, sample_weight, **kwargs)``
:rtype: func
"""
def wrapper(y_true, y_pred, group_membership, sample_weight=None):
def wrapper(y_true, y_pred, group_membership, sample_weight=None, **kwargs):
return metric_by_group(metric_function,
y_true,
y_pred,
group_membership,
sample_weight)
sample_weight,
**kwargs)

# Improve the name of the returned function
wrapper.__name__ = "group_{0}".format(metric_function.__name__)
Expand Down
163 changes: 22 additions & 141 deletions fairlearn/metrics/_skm_wrappers.py
Expand Up @@ -6,135 +6,31 @@
from ._metrics_engine import make_group_metric, metric_by_group


def group_accuracy_score(y_true, y_pred, group_membership, *,
normalize=True,
sample_weight=None):
"""Wrap the :py:func:`sklearn.metrics.accuracy_score` routine.
group_accuracy_score = make_group_metric(skm.accuracy_score)
"""A grouped wrapper around the :py:func:`sklearn.metrics.accuracy_score` routine."""

The arguments remain the same, with `group_membership` added.
However, the only positional arguments supported are `y_true`,
`y_pred` and `group_membership`.
All others must be specified by name.
"""
def internal_acc_wrapper(y_true, y_pred, sample_weight=None):
return skm.accuracy_score(y_true, y_pred,
normalize,
sample_weight)
group_confusion_matrix = make_group_metric(skm.confusion_matrix)
"""A grouped wrapper around the :py:func:`sklearn.metrics.confusion_matrix` routine."""

return metric_by_group(internal_acc_wrapper, y_true, y_pred, group_membership, sample_weight)


def group_confusion_matrix(y_true, y_pred, group_membership, *,
labels=None,
sample_weight=None):
"""Wrap the :py:func:`sklearn.metrics.confusion_matrix` routine.
The arguments remain the same, with `group_membership` added.
However, the only positional arguments supported are `y_true`,
`y_pred` and `group_membership`.
All others must be specified by name.
"""
def internal_cm_wrapper(y_true, y_pred, sample_weight=None):
return skm.confusion_matrix(y_true, y_pred,
labels,
sample_weight)

return metric_by_group(internal_cm_wrapper, y_true, y_pred, group_membership, sample_weight)


def group_precision_score(y_true, y_pred, group_membership, *,
labels=None, pos_label=1, average='binary',
sample_weight=None):
"""Wrap the :py:func:`sklearn.metrics.precision_score` routine.
The arguments remain the same, with `group_membership` added.
However, the only positional arguments supported are `y_true`,
`y_pred` and `group_membership`.
All others must be specified by name.
"""
def internal_prec_wrapper(y_true, y_pred, sample_weight=None):
return skm.precision_score(y_true, y_pred,
labels=labels, pos_label=pos_label, average=average,
sample_weight=sample_weight)

return metric_by_group(internal_prec_wrapper, y_true, y_pred, group_membership, sample_weight)


def group_recall_score(y_true, y_pred, group_membership, *,
labels=None, pos_label=1, average='binary',
sample_weight=None):
"""Wrap the :py:func:`sklearn.metrics.recall_score` routine.
The arguments remain the same, with `group_membership` added.
However, the only positional arguments supported are `y_true`,
`y_pred` and `group_membership`.
All others must be specified by name.
"""
def internal_recall_wrapper(y_true, y_pred, sample_weight=None):
return skm.recall_score(y_true, y_pred,
labels=labels, pos_label=pos_label, average=average,
sample_weight=sample_weight)

return metric_by_group(internal_recall_wrapper,
y_true, y_pred, group_membership, sample_weight)


def group_roc_auc_score(y_true, y_pred, group_membership, *,
average='macro', max_fpr=None,
sample_weight=None):
"""Wrap the :py:func:`sklearn.metrics.roc_auc_score` routine.
The arguments remain the same, with `group_membership` added.
However, the only positional arguments supported are `y_true`,
`y_pred` and `group_membership`.
All others must be specified by name.
"""
def internal_ras_wrapper(y_true, y_pred, sample_weight=None):
return skm.roc_auc_score(y_true, y_pred,
average=average, max_fpr=max_fpr,
sample_weight=sample_weight)

return metric_by_group(internal_ras_wrapper,
y_true, y_pred, group_membership, sample_weight)


def group_zero_one_loss(y_true, y_pred, group_membership, *,
normalize=True,
sample_weight=None):
"""Wrap the :py:func:`sklearn.metrics.zero_one_loss` routine.
The arguments remain the same, with `group_membership` added.
However, the only positional arguments supported are `y_true`,
`y_pred` and `group_membership`.
All others must be specified by name.
"""
def internal_zol_wrapper(y_true, y_pred, sample_weight=None):
return skm.zero_one_loss(y_true, y_pred,
normalize=normalize,
sample_weight=sample_weight)

return metric_by_group(internal_zol_wrapper, y_true, y_pred, group_membership, sample_weight)

# --------------------------------------------------------------------------------------
group_precision_score = make_group_metric(skm.precision_score)
"""A grouped wrapper around the :py:func:`sklearn.metrics.precision_score` routine
"""

group_recall_score = make_group_metric(skm.recall_score)
"""A grouped wrapper around the :py:func:`sklearn.metrics.recall_score` routine
"""

def group_mean_squared_error(y_true, y_pred, group_membership, *,
multioutput='uniform_average',
sample_weight=None):
"""Wrap the :py:func:`sklearn.metrics.mean_squared_error` routine.
group_roc_auc_score = make_group_metric(skm.roc_auc_score)
"""A grouped wrapper around the :py:func:`sklearn.metrics.roc_auc_score` routine
"""

The arguments remain the same, with `group_membership` added.
However, the only positional arguments supported are `y_true`,
`y_pred` and `group_membership`.
All others must be specified by name.
"""
def internal_mse_wrapper(y_true, y_pred, sample_weight=None):
return skm.mean_squared_error(y_true, y_pred,
multioutput=multioutput,
sample_weight=sample_weight)
group_zero_one_loss = make_group_metric(skm.zero_one_loss)
"""A grouped wrapper around the :py:func:`sklearn.metrics.zero_one_loss` routine
"""

return metric_by_group(internal_mse_wrapper,
y_true, y_pred, group_membership, sample_weight=sample_weight)
group_mean_squared_error = make_group_metric(skm.mean_squared_error)
"""A grouped wrapper around the :py:func:`sklearn.metrics.mean_squared_error` routine
"""


def group_root_mean_squared_error(y_true, y_pred, group_membership, *,
Expand All @@ -157,24 +53,9 @@ def internal_rmse_wrapper(y_true, y_pred, sample_weight=None):
y_true, y_pred, group_membership, sample_weight=sample_weight)


def group_r2_score(y_true, y_pred, group_membership, *,
multioutput='uniform_average',
sample_weight=None):
"""Wrap the :py:func:`sklearn.metrics.r2_score` routine.
The arguments remain the same, with `group_membership` added.
However, the only positional arguments supported are `y_true`,
`y_pred` and `group_membership`.
All others must be specified by name.
"""
def internal_r2_wrapper(y_true, y_pred, sample_weight=None):
return skm.r2_score(y_true, y_pred,
multioutput=multioutput,
sample_weight=sample_weight)

return metric_by_group(internal_r2_wrapper,
y_true, y_pred, group_membership, sample_weight=sample_weight)

group_r2_score = make_group_metric(skm.r2_score)
"""A grouped wrapper around the :py:func:`sklearn.metrics.r2_score` routine
"""

group_max_error = make_group_metric(skm.max_error)
"""A grouped wrapper around the :py:func:`sklearn.metrics.max_error` routine
Expand Down

0 comments on commit 2ffe87c

Please sign in to comment.