
Commit f8acba9
update
skadio committed Sep 7, 2023
1 parent 1475e9a commit f8acba9
Showing 3 changed files with 15 additions and 3 deletions.
5 changes: 3 additions & 2 deletions jurity/fairness/__init__.py
@@ -100,7 +100,8 @@ def get_all_scores(labels: Union[List, np.ndarray, pd.Series],
         membership_labels = [1]

     if bootstrap_results is None:
-        bootstrap_results = get_bootstrap_results(predictions, memberships, surrogates, membership_labels)
+        bootstrap_results = get_bootstrap_results(predictions, memberships, surrogates,
+                                                  membership_labels, labels)

     # Output df
     df = pd.DataFrame(columns=["Metric", "Value", "Ideal Value", "Lower Bound", "Upper Bound"])
@@ -144,7 +145,7 @@ def _get_score_logic(metric, name,
         score = metric.get_score(labels, predictions, memberships, membership_labels)
     else:
         if name == "StatisticalParity":
-            score = metric.get_score(predictions, memberships, membership_labels, bootstrap_results)
+            score = metric.get_score(predictions, memberships, surrogates, membership_labels, bootstrap_results)
         elif name in ["AverageOdds", "EqualOpportunity", "FNRDifference", "PredictiveEquality"]:
             score = metric.get_score(labels, predictions, memberships, surrogates,
                                      membership_labels, bootstrap_results)
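For reference, a minimal sketch (not part of this commit) of how the updated path could be exercised end to end, using the same toy data as the tests below. It assumes get_all_scores is called as a static method with (labels, predictions, memberships, surrogates, membership_labels), which is inferred from the diff context rather than shown here:

    # Sketch only: the get_all_scores call signature is assumed from the diff context above.
    from jurity.fairness import BinaryFairnessMetrics

    labels = [1, 1, 1, 0, 1, 0, 1, 1, 1, 1]
    predictions = [0, 0, 0, 1, 0, 0, 0, 0, 0, 1]
    memberships = [[0.2, 0.8], [0.4, 0.6], [0.2, 0.8], [0.9, 0.1], [0.3, 0.7],
                   [0.8, 0.2], [0.6, 0.4], [0.8, 0.2], [0.1, 0.9], [0.7, 0.3]]
    surrogates = [0, 2, 0, 1, 3, 0, 0, 1, 1, 2]
    membership_labels = [1]

    # With bootstrap_results left as None, get_all_scores now builds the bootstrap
    # internally and forwards labels to get_bootstrap_results, per the change above.
    df = BinaryFairnessMetrics.get_all_scores(labels, predictions, memberships,
                                              surrogates, membership_labels)
    print(df)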
1 change: 0 additions & 1 deletion jurity/fairness/average_odds.py
@@ -124,7 +124,6 @@ def get_score(labels: Union[List, np.ndarray, pd.Series],
                                          must_have_labels=True, labels=labels)
             bootstrap_results = get_bootstrap_results(predictions, memberships, surrogates, membership_labels,
                                                       labels)
-            print(bootstrap_results)

         tpr_group_1, tpr_group_2 = unpack_bootstrap(bootstrap_results, Constants.TPR, membership_labels)
         fpr_group_1, fpr_group_2 = unpack_bootstrap(bootstrap_results, Constants.FPR, membership_labels)
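For orientation only: Average Odds is conventionally the average of the group differences in false positive rate and true positive rate, so the values unpacked above would combine roughly as in this sketch (the exact expression used in get_score is not shown in this diff):

    # Sketch only: assumed combination of the unpacked bootstrap estimates.
    average_odds = ((fpr_group_1 - fpr_group_2) + (tpr_group_1 - tpr_group_2)) / 2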
12 changes: 12 additions & 0 deletions tests/test_fairness_proba.py
@@ -23,6 +23,18 @@ def test_quick_start(self):
         metric = BinaryFairnessMetrics.StatisticalParity()
         score = metric.get_score(predictions, memberships, surrogates)

+    def test_quick_start_avg_odds(self):
+        # Data
+        labels = [1, 1, 1, 0, 1, 0, 1, 1, 1, 1]
+        predictions = [0, 0, 0, 1, 0, 0, 0, 0, 0, 1]
+        memberships = [[0.2, 0.8], [0.4, 0.6], [0.2, 0.8], [0.9, 0.1], [0.3, 0.7],
+                       [0.8, 0.2], [0.6, 0.4], [0.8, 0.2], [0.1, 0.9], [0.7, 0.3]]
+        surrogates = [0, 2, 0, 1, 3, 0, 0, 1, 1, 2]
+        # membership_labels = [1]
+        metric = BinaryFairnessMetrics.AverageOdds()
+        score = metric.get_score(labels, predictions, memberships, surrogates)
+        print(score)
+
     def test_all_scores(self):
         labels = [1, 1, 1, 0, 1, 0, 1, 1, 1, 1]
         predictions = [0, 0, 0, 1, 0, 0, 0, 0, 0, 1]
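The new quick-start test can be run in isolation; an example invocation, selecting by keyword since the test class name is not visible in this diff (-s keeps the print(score) output visible):

    python -m pytest tests/test_fairness_proba.py -k test_quick_start_avg_odds -s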
