In [1]:
import sys
sys.path.append("../scripts")

import data
import utils
import models
import experiments

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
import optuna

## Testing metric with one fold

In [61]:
args = {
    "dataset" : "german",
    "alpha" : 0.1,
    "n_groups" : 4,
    "model_name" : "M2FGB_eod",
    "n_trials" : 100,
    "thresh" : "ks",
    "fairness_metric" : "min_bal_acc"
}

In [62]:
X_train, Y_train, X_val, Y_val, X_test, Y_test = data.get_fold("german", 0)
A_train, A_val, A_test = experiments.get_subgroup_feature(
    args["dataset"], X_train, X_val, X_test, args["n_groups"]
)

In [63]:
col_trans = ColumnTransformer(
    [
        ("numeric", StandardScaler(), data.NUM_FEATURES[args["dataset"]]),
        (
            "categorical",
            OneHotEncoder(
                drop="if_binary", sparse_output=False, handle_unknown="ignore"
            ),
            data.CAT_FEATURES[args["dataset"]],
        ),
    ],
    verbose_feature_names_out=False,
)
col_trans.set_output(transform="pandas")

preprocess = Pipeline([("preprocess", col_trans)])
preprocess.fit(X_train)
X_train = preprocess.transform(X_train)
X_val = preprocess.transform(X_val)
X_test = preprocess.transform(X_test)

In [64]:
model_class = experiments.get_model(args["model_name"])

In [65]:
scorer = utils.get_combined_metrics_scorer(
    alpha=args["alpha"], performance_metric="bal_acc", fairness_metric="eod"
)

In [None]:
study = optuna.create_study(direction="maximize")
objective = lambda trial: experiments.run_trial(
    trial,
    scorer,
    X_train,
    Y_train,
    A_train,
    X_val,
    Y_val,
    A_val,
    model_class,
    experiments.get_param_spaces(args["model_name"]),
    args,
)
study.optimize(objective, n_trials=args["n_trials"])
best_params = study.best_params.copy()

In [None]:
model = model_class(**study.best_params)
model.fit(X_train, Y_train, A_train)
y_val_score = model.predict_proba(X_val)[:, 1]
thresh = utils.get_best_threshold(Y_val, y_val_score)
y_val_pred = y_val_score > thresh
y_test_score = model.predict_proba(X_test)[:, 1]
y_test_pred = y_test_score > thresh

In [68]:
eod = utils.equal_opportunity_score(Y_val, y_val_pred, A_val)
min_bal_acc = utils.min_balanced_accuracy(Y_val, y_val_pred, A_val)

print(f"EOD: {eod:.2f}")
print(f"Min. Bal. Acc.: {min_bal_acc:.2f}")
print(f"P(\hat Y = 1): {y_test_pred.mean():.2f}")

EOD: 0.00
Min. Bal. Acc.: 1.00
P(\hat Y = 1): 0.00




In [None]:
scorer = utils.get_combined_metrics_scorer(
    alpha=args["alpha"], performance_metric="bal_acc", fairness_metric="min_bal_acc"
)
study = optuna.create_study(direction="maximize")
objective = lambda trial: experiments.run_trial(
    trial,
    scorer,
    X_train,
    Y_train,
    A_train,
    X_val,
    Y_val,
    A_val,
    model_class,
    experiments.get_param_spaces(args["model_name"]),
    args,
)
study.optimize(objective, n_trials=args["n_trials"])
best_params = study.best_params.copy()

In [None]:
model = model_class(**study.best_params)
model.fit(X_train, Y_train, A_train)
y_val_score = model.predict_proba(X_val)[:, 1]
thresh = utils.get_best_threshold(Y_val, y_val_score)
y_val_pred = y_val_score > thresh
y_test_score = model.predict_proba(X_test)[:, 1]
y_test_pred = y_test_score > thresh

In [71]:
eod = utils.equal_opportunity_score(Y_val, y_val_pred, A_val)
min_bal_acc = utils.min_balanced_accuracy(Y_val, y_val_pred, A_val)

print(f"EOD: {eod:.2f}")
print(f"Min. Bal. Acc.: {min_bal_acc:.2f}")
print(f"P(\hat Y = 1): {y_test_pred.mean():.2f}")

EOD: 0.40
Min. Bal. Acc.: 0.21
P(\hat Y = 1): 0.61




## Tuning M²FGB and FairGBM with min_bal_acc

In [83]:
dataset = "german"
alpha = 0.7
model_name = "M2FGB"
fairness_metric = "min_bal_acc"
thresh = "ks"

args = {
    "dataset": dataset,
    "alpha": alpha,
    "output_dir": f"../results/comparing_metrics/{dataset}/{model_name}_{alpha}_{fairness_metric}",
    "model_name": model_name,
    "n_trials": 10,
    "n_groups": 4,
    "thresh" : thresh,
    "fairness_metric": fairness_metric,
}
experiments.run_subgroup_experiment(args)


dataset = "german"
alpha = 0.7
model_name = "FairGBMClassifier"
fairness_metric = "min_bal_acc"
thresh = "ks"

args = {
    "dataset": dataset,
    "alpha": alpha,
    "output_dir": f"../results/comparing_metrics/{dataset}/{model_name}_{alpha}_{fairness_metric}",
    "model_name": model_name,
    "n_trials": 10,
    "n_groups": 4,
    "thresh" : thresh,
    "fairness_metric": fairness_metric,
}
experiments.run_subgroup_experiment(args)

  0%|                                                                                                                                                                             | 0/10 [00:00<?, ?it/s]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
Stopped training because there are no more leaves that meet the split requirements
[LightGBM] [Info] Using self-defined objective function


 10%|████████████████▌                                                                                                                                                    | 1/10 [00:08<01:17,  8.58s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 20%|█████████████████████████████████                                                                                                                                    | 2/10 [00:15<01:03,  7.88s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
































[LightGBM] [Info] Using self-defined objective function


 30%|█████████████████████████████████████████████████▌                                                                                                                   | 3/10 [00:21<00:47,  6.79s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 40%|██████████████████████████████████████████████████████████████████                                                                                                   | 4/10 [00:25<00:34,  5.73s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 50%|██████████████████████████████████████████████████████████████████████████████████▌                                                                                  | 5/10 [00:32<00:30,  6.18s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function































 60%|███████████████████████████████████████████████████████████████████████████████████████████████████                                                                  | 6/10 [00:40<00:26,  6.64s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


























 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                 | 7/10 [00:44<00:17,  5.98s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
Stopped training because there are no more leaves that meet the split requirements
No further splits with positive gain, best gain: -inf


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                 | 8/10 [00:57<00:16,  8.31s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                | 9/10 [01:06<00:08,  8.39s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
























No further splits with positive gain, best gain: -inf












[LightGBM] [Info] Using self-defined objective function






100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:13<00:00,  7.38s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:12<00:00,  1.22s/it]


In [85]:
dataset = "adult"
alpha = 0.7
model_name = "M2FGB"
fairness_metric = "min_bal_acc"
thresh = "ks"

args = {
    "dataset": dataset,
    "alpha": alpha,
    "output_dir": f"../results/comparing_metrics/{dataset}/{model_name}_{alpha}_{fairness_metric}",
    "model_name": model_name,
    "n_trials": 10,
    "n_groups": 4,
    "thresh" : thresh,
    "fairness_metric": fairness_metric,
}
experiments.run_subgroup_experiment(args)



dataset = "adult"
alpha = 0.7
model_name = "FairGBMClassifier"
fairness_metric = "min_bal_acc"
thresh = "ks"

args = {
    "dataset": dataset,
    "alpha": alpha,
    "output_dir": f"../results/comparing_metrics/{dataset}/{model_name}_{alpha}_{fairness_metric}",
    "model_name": model_name,
    "n_trials": 10,
    "n_groups": 4,
    "thresh" : thresh,
    "fairness_metric": fairness_metric,
}
experiments.run_subgroup_experiment(args)

  0%|                                                                                                                                                                             | 0/10 [00:00<?, ?it/s]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 10%|████████████████▌                                                                                                                                                    | 1/10 [00:46<07:01, 46.89s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function




[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 30%|█████████████████████████████████████████████████▌                                                                                                                   | 3/10 [02:35<06:15, 53.62s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 40%|██████████████████████████████████████████████████████████████████                                                                                                   | 4/10 [03:25<05:12, 52.16s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 50%|██████████████████████████████████████████████████████████████████████████████████▌                                                                                  | 5/10 [04:01<03:51, 46.24s/it]





[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 60%|███████████████████████████████████████████████████████████████████████████████████████████████████                                                                  | 6/10 [04:59<03:22, 50.53s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                 | 7/10 [05:20<02:02, 40.70s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                 | 8/10 [06:21<01:34, 47.18s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                | 9/10 [07:06<00:46, 46.50s/it]

[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Using self-defined objective function


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [07:45<00:00, 46.54s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:22<00:00,  8.28s/it]


In [86]:
fairness_metric = "min_bal_acc"
alpha = 0.5
for dataset in ["german", "adult"]:
    print(f"Results for {dataset}")
    for model_name in ["FairGBMClassifier", "M2FGB"]:
        df = pd.read_csv(f"../results/comparing_metrics/{dataset}/{model_name}_{alpha}_{fairness_metric}/results.csv")

        print(model_name)

        print(f"\tMean bal_acc: {df['bal_acc'].mean():.3f} +- {df['bal_acc'].std():.3f}")
        print(f"\tMean min bal_acc: {df['bal_acc_min'].mean():.3f} +- {df['bal_acc_min'].std():.3f}")
        print(f"\tMean eod: {df['eod'].abs().mean():.3f} +- {df['eod'].abs().std():.3f}")

Results for german
FairGBMClassifier
	Mean bal_acc: 0.599 +- 0.113
	Mean min bal_acc: 0.398 +- 0.188
	Mean eod: 0.445 +- 0.290
M2FGB
	Mean bal_acc: 0.620 +- 0.054
	Mean min bal_acc: 0.446 +- 0.184
	Mean eod: 0.259 +- 0.270
Results for adult
FairGBMClassifier
	Mean bal_acc: 0.764 +- 0.024
	Mean min bal_acc: 0.727 +- 0.035
	Mean eod: 0.158 +- 0.071
M2FGB
	Mean bal_acc: 0.796 +- 0.009
	Mean min bal_acc: 0.729 +- 0.024
	Mean eod: 0.195 +- 0.064
