Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Label flipping fixes #148

Merged
merged 5 commits into from
Jan 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 51 additions & 28 deletions src/aequitas/flow/methods/preprocessing/label_flipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import inspect
import pandas as pd
import math
from typing import Optional, Tuple, Literal, Union, Callable
import numpy as np
from sklearn.ensemble import BaggingClassifier
Expand All @@ -15,8 +16,8 @@
class LabelFlipping(PreProcessing):
def __init__(
self,
flip_rate: float = 0.1,
disparity_target: Optional[float] = None,
max_flip_rate: float = 0.1,
disparity_target: Optional[float] = 0.05,
score_threshold: Optional[float] = None,
bagging_max_samples: float = 0.5,
bagging_base_estimator: Union[
Expand All @@ -34,8 +35,16 @@ def __init__(

Parameters
----------
flip_rate : float, optional
max_flip_rate : float, optional
Maximum fraction of the training data to flip, by default 0.1
disparity_target : float, optional
The target disparity between the groups (difference between the prevalence
of a group and the mean prevalence). By default 0.05. If set to None,
the method will attempt to equalize the prevalence of the groups.
score_threshold : float, optional
The minimum absolute score an instance must have for its label to be
eligible for flipping. By default None, which is treated as a threshold
of 0 (every instance is eligible).
bagging_max_samples : float, optional
The number of samples to draw from X to train each base estimator of the
bagging classifier (with replacement).
Expand All @@ -45,17 +54,17 @@ def __init__(
bagging_n_estimators : int, optional
The number of base estimators in the ensemble, by default 10.
fair_ordering : bool, optional
Whether to take additional fairness criteria into account when flipping
Whether to take additional fairness criteria into account when flipping
labels, only modifying the labels that contribute to equalizing the
prevalence of the groups. By default True.
ordering_method : str, optional
The method used to calculate the margin of the base estimator. If
"ensemble_margin", calculates the ensemble margins based on the binary
predictions of the classifiers. If "residuals", oreders the missclafied
instances based on the average residuals of the classifiers predictions. By
The method used to calculate the margin of the base estimator. If
"ensemble_margin", calculates the ensemble margins based on the binary
predictions of the classifiers. If "residuals", orders the misclassified
instances based on the average residuals of the classifiers predictions. By
default "ensemble_margin".
unawareness_features : list, optional
The sensitive attributes (or proxies) to ignore when fitting the ensemble
The sensitive attributes (or proxies) to ignore when fitting the ensemble
to enable fairness through unawareness.
seed : int, optional
The seed to use when fitting the ensemble.
Expand All @@ -67,17 +76,17 @@ def __init__(
>>> from aequitas.preprocessing import LabelFlipping
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.datasets import make_classification
>>> X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
>>> X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
n_redundant=0, random_state=42)
>>> lf = LabelFlipping(bagging_base_estimator=DecisionTreeClassifier,
flip_rate=0.1, max_depth=3)
>>> lf = LabelFlipping(bagging_base_estimator=DecisionTreeClassifier,
max_flip_rate=0.1, max_depth=3)
>>> lf.fit(X, y)
>>> X_transformed, y_transformed = lf.transform(X, y)
"""
self.logger = create_logger("methods.preprocessing.LabelFlipping")
self.logger.info("Instantiating a LabelFlipping preprocessing method.")

self.flip_rate = flip_rate
self.max_flip_rate = max_flip_rate

if disparity_target is not None:
if disparity_target < 0 or disparity_target > 1:
Expand Down Expand Up @@ -114,7 +123,7 @@ def __init__(
self.bagging_base_estimator = bagging_base_estimator(**args)
self.logger.info(
f"Created base estimator {self.bagging_base_estimator} with params {args}, "
F"discarded args:{list(set(base_estimator_args.keys()) - set(args.keys()))}"
f"discarded args:{list(set(base_estimator_args.keys()) - set(args.keys()))}"
)
self.bagging_n_estimators = bagging_n_estimators

Expand Down Expand Up @@ -159,8 +168,8 @@ def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None:
def _score_instances(self, X: pd.DataFrame, y: pd.Series) -> pd.Series:
"""Scores the instances based on the predictions of the ensemble of classifiers.

If the ordering method is "ensemble_margin", the scores are the ensemble
margins. If the ordering method is "residuals", the scores are the average
If the ordering method is "ensemble_margin", the scores are the ensemble
margins. If the ordering method is "residuals", the scores are the average
residuals of the classifiers predictions.

Parameters
Expand Down Expand Up @@ -202,19 +211,28 @@ def _score_instances(self, X: pd.DataFrame, y: pd.Series) -> pd.Series:

return scores

def _calculate_prevalence_disparity(self, y: pd.Series, s: pd.Series):
def _calculate_group_flips(self, y: pd.Series, s: pd.Series):
prevalence = y.mean()
group_prevalence = y.groupby(s).mean().to_dict()
group_disparity = {k: v - prevalence for k, v in group_prevalence.items()}
group_prevalences = y.groupby(s).mean()

return group_disparity
min_prevalence = prevalence - self.disparity_target * prevalence
max_prevalence = prevalence + self.disparity_target * prevalence

group_flips = {
group: math.ceil(min_prevalence * len(y[s == group])) - y[s == group].sum()
if group_prevalences[group] < min_prevalence
else math.floor(max_prevalence * len(y[s == group])) - y[s == group].sum()
for group in group_prevalences.index
}

return group_flips

def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Series):
"""Flips the labels of the desired fraction of the training data.

If fair_ordering is True, only flips the labels of the instances that contribute
to equalizing the prevalence of the groups.
Otherwise, the labels of the instances with the largest score values are
Otherwise, the labels of the instances with the largest score values are
flipped.

Parameters
Expand All @@ -236,10 +254,10 @@ def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Serie
ascending=(self.ordering_method == "ensemble_margin")
).index
)
n_flip = int(self.flip_rate * len(y))
n_flip = int(self.max_flip_rate * len(y))

if self.fair_ordering:
disparity = self._calculate_prevalence_disparity(y_flipped, s)
group_flips = self._calculate_group_flips(y_flipped, s)
flip_index = (
y_flipped.index
if self.ordering_method == "residuals"
Expand All @@ -251,12 +269,15 @@ def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Serie
if abs(scores.loc[i]) < self.score_threshold:
break

if (disparity[s.loc[i]] > self.disparity_target and y.loc[i] == 1) or (
disparity[s.loc[i]] < self.disparity_target and y.loc[i] == 0
if (group_flips[s.loc[i]] > 0 and y.loc[i] == 0) or (
group_flips[s.loc[i]] < 0 and y.loc[i] == 1
):
y_flipped.loc[i] = 1 - y.loc[i]
disparity = self._calculate_prevalence_disparity(y_flipped, s)
flip_count += 1
if group_flips[s.loc[i]] > 0:
group_flips[s.loc[i]] -= 1
else:
group_flips[s.loc[i]] += 1

if flip_count == n_flip:
break
Expand All @@ -282,7 +303,9 @@ def transform(
Parameters
----------
X : pd.DataFrame
Feature[s.loc[i]]ector.
Feature matrix.
y : pd.Series
Label vector.
s : pd.Series, optional
Protected attribute vector.

Expand All @@ -295,7 +318,7 @@ def transform(

if s is None and self.fair_ordering:
raise ValueError(
"Sensitive Attribute `s` not passed. Must be passed if `fair_ordering` "
"Sensitive Attribute `s` not passed. Must be passed if `fair_ordering` "
"is True."
)

Expand Down