FEAT export InterpolatedThresholder as a public object and update its API (e.g., rename an argument) #918

Open · wants to merge 11 commits into main
3 changes: 3 additions & 0 deletions CHANGES.md
@@ -5,6 +5,9 @@
* Relaxed checks made on `X` in `_validate_and_reformat_input()` since that
is the concern of the underlying estimator and not Fairlearn
* Add support for Python 3.9
* Make `InterpolatedThresholder` more visible by directly including it in
`fairlearn.postprocessing` and rename its `interpolation_dict` argument
to `threshold_interpolation`.

### v0.7.0

6 changes: 4 additions & 2 deletions fairlearn/postprocessing/__init__.py
@@ -7,10 +7,12 @@
learn how to adjust the predictor's output from the training data.
"""

from ._interpolated_thresholder import InterpolatedThresholder # noqa: F401
from ._threshold_optimizer import ThresholdOptimizer # noqa: F401
from ._plotting import plot_threshold_optimizer # noqa: F401

__all__ = [
"ThresholdOptimizer",
"plot_threshold_optimizer"
"InterpolatedThresholder",
Member: You'll also want to export ThresholdOperation. That's another structured object, where we should review whether we like its current API enough to make it external.
"plot_threshold_optimizer",
"ThresholdOptimizer"
]
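
With the export above in place, the estimator can be imported from the public namespace instead of its private module. A minimal sketch, assuming a fairlearn build that includes this change:

```python
# InterpolatedThresholder is the newly exported name; ThresholdOptimizer and
# plot_threshold_optimizer were already part of fairlearn.postprocessing.
from fairlearn.postprocessing import (
    InterpolatedThresholder,
    ThresholdOptimizer,
    plot_threshold_optimizer,
)

print(InterpolatedThresholder.__name__)  # InterpolatedThresholder
```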
11 changes: 5 additions & 6 deletions fairlearn/postprocessing/_interpolated_thresholder.py
@@ -21,14 +21,13 @@ class InterpolatedThresholder(BaseEstimator, MetaEstimatorMixin):

At prediction time, the predictor takes as input both standard and sensitive features.
Based on the values of sensitive features, it then applies a randomized thresholding
transformation according to the provided `interpolation_dict`.
transformation according to the provided `threshold_interpolation`.
Member: Why this change? If you don't like the original, we should go for something more concise, maybe something like threshold_info or thresholding_info, or even thresholds (but that is a bit confusing, because it sounds like the structure would be simpler).

Member (Author): Info isn't a great name. Everything is "info", yet we don't call "width" variables "width_info". It's a stylistic improvement as well as a description improvement: interpolation and threshold are the two important things worth mentioning, IMO. If you don't think it's an improvement I don't mind scrapping it, though (not the hill I choose to die on). It was originally part of an older PR about documenting postprocessing (which I'm actively working on) and got refactored out into its own change for simplicity.

Member: How about threshold_weights, since basically each value of a threshold gets a certain probabilistic weight?

Member (Author): I don't dislike it, but that's just p0 and p1 🙂 It's really not a bad name though...

Parameters
----------
estimator :
base estimator

interpolation_dict : dict
threshold_interpolation : dict
Member: See my comment above re. naming. By the way, do we like the structure of this argument? The main reason I originally excluded this estimator from the module was that I wasn't sure we all liked this deep structure and I didn't want to commit to it... I usually prefer "shallower" input arguments.
maps sensitive feature values to `Bunch` that describes the
interpolation transformation via the following fields:
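
The field list itself is collapsed in this view, but judging from how `_threshold_optimizer.py` builds the mapping later in this diff, a value of `threshold_interpolation` could look roughly like the sketch below. The group names and numbers are made up, and plain callables stand in for fairlearn's (still private) `ThresholdOperation` objects:

```python
from sklearn.utils import Bunch

# Hypothetical mapping: for each sensitive-feature value, the positive
# probability is p0 * operation0(score) + p1 * operation1(score).
threshold_interpolation = {
    "group_a": Bunch(
        p0=0.3, operation0=lambda scores: (scores > 0.4) * 1.0,
        p1=0.7, operation1=lambda scores: (scores > 0.6) * 1.0,
    ),
    "group_b": Bunch(
        p0=1.0, operation0=lambda scores: (scores > 0.5) * 1.0,
        p1=0.0, operation1=lambda scores: (scores > 0.5) * 1.0,
    ),
}
```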

@@ -78,10 +77,10 @@ class InterpolatedThresholder(BaseEstimator, MetaEstimatorMixin):
[Online]. Available: https://arxiv.org/abs/1610.02413.
"""

def __init__(self, estimator, interpolation_dict, prefit=False,
def __init__(self, estimator, threshold_interpolation, prefit=False,
predict_method='deprecated'):
self.estimator = estimator
self.interpolation_dict = interpolation_dict
self.threshold_interpolation = threshold_interpolation
self.prefit = prefit
self.predict_method = predict_method

@@ -141,7 +140,7 @@ def _pmf_predict(self, X, *, sensitive_features):
enforce_binary_labels=False)

positive_probs = 0.0*base_predictions_vector
for a, interpolation in self.interpolation_dict.items():
for a, interpolation in self.threshold_interpolation.items():
interpolated_predictions = \
interpolation.p0 * interpolation.operation0(base_predictions_vector) + \
interpolation.p1 * interpolation.operation1(base_predictions_vector)
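
As a rough numeric illustration of what this loop computes for a single group (the thresholds and weights are made up):

```python
import numpy as np

# p0 = 0.3 with threshold 0.4, p1 = 0.7 with threshold 0.6 (hypothetical values).
# A base score of 0.5 clears the first threshold but not the second, so its
# positive probability is 0.3 * 1 + 0.7 * 0 = 0.3.
base_predictions_vector = np.array([0.2, 0.5, 0.9])
p0, p1 = 0.3, 0.7
positive_probs = (
    p0 * (base_predictions_vector > 0.4)
    + p1 * (base_predictions_vector > 0.6)
)
print(positive_probs)  # [0.  0.3 1. ]
```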
12 changes: 6 additions & 6 deletions fairlearn/postprocessing/_threshold_optimizer.py
@@ -459,12 +459,12 @@ def _threshold_optimization_for_simple_constraints(

# Create the solution as interpolation of multiple points with a separate
# interpolation per sensitive feature value.
interpolation_dict = {}
threshold_interpolation = {}
for sensitive_feature_value in self._tradeoff_curve.keys():
best_interpolation = self._tradeoff_curve[
sensitive_feature_value
].transpose()[i_best]
interpolation_dict[sensitive_feature_value] = Bunch(
threshold_interpolation[sensitive_feature_value] = Bunch(
p0=best_interpolation.p0,
operation0=best_interpolation.operation0,
p1=best_interpolation.p1,
@@ -484,7 +484,7 @@

return InterpolatedThresholder(
self.estimator_,
interpolation_dict,
threshold_interpolation,
prefit=True,
predict_method=self._predict_method,
).fit(None, None)
@@ -578,7 +578,7 @@ def _threshold_optimization_for_equalized_odds(

# create the solution as interpolation of multiple points with a separate
# interpolation per sensitive feature
interpolation_dict = {}
threshold_interpolation = {}
for sensitive_feature_value in self._tradeoff_curve.keys():
roc_result = self._tradeoff_curve[
sensitive_feature_value
@@ -599,7 +599,7 @@
/ vertical_distance_from_diagonal
)

interpolation_dict[sensitive_feature_value] = Bunch(
threshold_interpolation[sensitive_feature_value] = Bunch(
p_ignore=p_ignore,
prediction_constant=self._x_best,
p0=roc_result.p0,
@@ -621,7 +621,7 @@

return InterpolatedThresholder(
self.estimator_,
interpolation_dict,
threshold_interpolation,
prefit=True,
predict_method=self._predict_method,
).fit(None, None)
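
The `p_ignore` and `prediction_constant` fields set in the equalized-odds Bunch above are not consumed in any hunk shown here; as far as I understand the estimator, they blend the interpolated probability with a group-independent constant, roughly as in this hypothetical sketch:

```python
import numpy as np

p_ignore, prediction_constant = 0.2, 0.5   # made-up values
interpolated = np.array([0.0, 0.3, 1.0])   # per-sample interpolated probabilities
positive_probs = p_ignore * prediction_constant + (1 - p_ignore) * interpolated
print(positive_probs)  # [0.1  0.34 0.9 ]
```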
2 changes: 1 addition & 1 deletion test/unit/postprocessing/test_threshold_optimization.py
@@ -648,7 +648,7 @@ def test_constraints_objective_pairs(constraints, objective):
assert str(error_info.value).startswith(expected)
else:
thr_optimizer.fit(X, y, sensitive_features=sf)
res = thr_optimizer.interpolated_thresholder_.interpolation_dict
res = thr_optimizer.interpolated_thresholder_.threshold_interpolation
for key in [0, 1]:
assert res[key]['p0'] == pytest.approx(expected[key]['p0'], PREC)
assert res[key]['operation0']._operator == expected[key]['op0']
@@ -117,11 +117,11 @@ def test_threshold_optimizer_multiple_sensitive_features():
metricframe_multi.by_group.loc[(a2, a4)]).all()

# comparing string representations of interpolation dicts is sufficient
assert str(postprocess_est_combined.interpolated_thresholder_.interpolation_dict[a1+a3]) == \
str(postprocess_est_multi.interpolated_thresholder_.interpolation_dict[a13])
assert str(postprocess_est_combined.interpolated_thresholder_.interpolation_dict[a1+a4]) == \
str(postprocess_est_multi.interpolated_thresholder_.interpolation_dict[a14])
assert str(postprocess_est_combined.interpolated_thresholder_.interpolation_dict[a2+a3]) == \
str(postprocess_est_multi.interpolated_thresholder_.interpolation_dict[a23])
assert str(postprocess_est_combined.interpolated_thresholder_.interpolation_dict[a2+a4]) == \
str(postprocess_est_multi.interpolated_thresholder_.interpolation_dict[a24])
combined_interpolation = \
postprocess_est_combined.interpolated_thresholder_.threshold_interpolation
multi_interpolation = \
postprocess_est_multi.interpolated_thresholder_.threshold_interpolation
assert str(combined_interpolation[a1+a3]) == str(multi_interpolation[a13])
assert str(combined_interpolation[a1+a4]) == str(multi_interpolation[a14])
assert str(combined_interpolation[a2+a3]) == str(multi_interpolation[a23])
assert str(combined_interpolation[a2+a4]) == str(multi_interpolation[a24])