Skip to content

Commit

Permalink
select best model based on fairness (#612)
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed May 5, 2023
1 parent 26bb256 commit cb46370
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
4 changes: 2 additions & 2 deletions examples/scripts/binary_classifier_adult_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
)


automl = AutoML(algorithms=["Xgboost"], #, "LightGBM"],
automl = AutoML(algorithms=["Random Forest"], # "Xgboost", "LightGBM", "Random Forest", "Decision Tree"],
train_ensemble=False,
fairness_metric="demographic_parity_ratio", #
fairness_threshold=0.8,
#privileged_groups = [{"sex": "Male"}],
#unprivileged_groups = [{"sex": "Female"}],
#hill_climbing_steps=1,
#top_models_to_improve=2
#top_models_to_improve=1
)


Expand Down
16 changes: 11 additions & 5 deletions supervised/base_automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,7 @@ def select_and_save_best(self, show_warnings=False):
m
for m in self._models
if m.is_valid()
and m.is_fast_enough(self._max_single_prediction_time)
# and m.is_fast_enough(self._max_single_prediction_time)
and m.is_fair()
]

Expand All @@ -1260,10 +1260,16 @@ def select_and_save_best(self, show_warnings=False):
)
else:
# if no models are fair, we select the most fair model
self._best_model = min(
[m for m in self._models if m.is_valid()],
key=lambda x: x.get_best_fairness(),
)
if "ratio" in self._fairness_metric.lower():
self._best_model = max(
[m for m in self._models if m.is_valid()],
key=lambda x: x.get_best_fairness(),
)
else:
self._best_model = min(
[m for m in self._models if m.is_valid()],
key=lambda x: x.get_best_fairness(),
)

else:
model_list = [
Expand Down
4 changes: 2 additions & 2 deletions supervised/tuner/mljar_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def steps(self):
if self._fairness_metric is not None:
all_steps += ["unfairness_mitigation"]
# up to 10 steps
for i in range(10):
all_steps += [f"unfairness_mitigation_update_{i+1}"]
#for i in range(10):
# all_steps += [f"unfairness_mitigation_update_{i+1}"]

if self._start_random_models > 1:
all_steps += ["not_so_random"]
Expand Down

0 comments on commit cb46370

Please sign in to comment.