Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
skadio committed Sep 7, 2023
1 parent 467d7e8 commit 60645cf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 26 deletions.
31 changes: 5 additions & 26 deletions jurity/utils_proba.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def check_inputs_proba(predictions: Union[List, np.ndarray, pd.Series],
memberships: Union[List[List], np.ndarray, pd.Series, pd.DataFrame],
surrogates: Union[List, np.ndarray, pd.Series],
membership_labels: Union[int, float, str, List[int]],
membership_names: List[str] = None,
must_have_labels: bool = False,
labels: Union[List, np.ndarray, pd.Series] = None,
membership_names = None):
labels: Union[List, np.ndarray, pd.Series] = None):
check_input_type(surrogates)

len_surrogate_class = len(surrogates)
Expand All @@ -101,28 +101,6 @@ def check_inputs_proba(predictions: Union[List, np.ndarray, pd.Series],
check_true(len(membership_labels) < len_likelihoods,
ValueError("Protected label must be less than number of classes"))

# Check that our arrays are all the same length
if must_have_labels:
check_true(labels is not None, ValueError("Metric must have labels"))

check_input_type(labels)
check_input_1d(labels)
check_binary(labels)
check_elementwise_input_type(labels)

check_true(len(labels) == len(predictions) == len(memberships),
InputShapeError("",
f"Shapes of inputs do not match. "
f"you supplied lengths of labels: "
f"{len(labels)}, predictions: {len(predictions)}"
f", is_member: {len(memberships)}."))
else:
check_true(len(predictions) == len(memberships),
InputShapeError("",
f"Shapes of inputs do not match. "
f"You supplied array lengths "
f"predictions: {len(predictions)}, is_member: {len(memberships)}."))


def get_bootstrap_results(predictions: Union[List, np.ndarray, pd.Series],
memberships: Union[List, np.ndarray, pd.Series, List[List], pd.DataFrame],
Expand Down Expand Up @@ -177,10 +155,11 @@ def get_bootstrap_results(predictions: Union[List, np.ndarray, pd.Series],
membership_names = ["A", "B"]

if labels is None:
check_inputs_proba(predictions, memberships, surrogates, membership_labels)
check_inputs_proba(predictions, memberships, surrogates, membership_labels,
membership_names=membership_names)
else:
check_inputs_proba(predictions, memberships, surrogates, membership_labels,
must_have_labels=True, labels=labels)
membership_names=membership_names, must_have_labels=True, labels=labels)

summary_df = SummaryData.summarize(predictions, memberships, surrogates, labels, membership_names)

Expand Down
3 changes: 3 additions & 0 deletions tests/test_utils_proba.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,9 @@ def test_membership_as_df(self):
"""
results = get_bootstrap_results(self.test_data["prediction"], self.surrogate_df.set_index("surrogate"),
self.test_data["surrogate"], [1, 2], self.test_data["label"])

print(results)

self.assertTrue(isinstance(results, pd.DataFrame), "get_bootstrap_results does not return a Pandas DataFrame.")
self.assertTrue(
{Constants.FPR, Constants.FNR, Constants.TNR, Constants.TPR, Constants.ACC}.issubset(set(results.columns)),
Expand Down

0 comments on commit 60645cf

Please sign in to comment.