In [16]:
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np

from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt
from sklearn.utils import Bunch

from fairlearn.postprocessing._constants import LABEL_KEY, SCORE_KEY
from fairlearn.postprocessing._threshold_operation import ThresholdOperation

import xgboost as xgb

from sklearn.metrics import f1_score
from itertools import product
from scipy.spatial.distance import cdist

from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

import warnings

from sklearn.metrics import balanced_accuracy_score, roc_auc_score, accuracy_score, f1_score, roc_curve

from fairlearn.metrics import equalized_odds_ratio

from ACSIncomeDataLoader import  load_synthetic_data

from scipy.spatial.distance import cdist

from typing import Callable

import itertools

import plotly.express as px
import plotly.graph_objects as go

from sklearn.dummy import DummyClassifier



In [17]:
def _extend_confusion_matrix(*, true_positives, false_positives, true_negatives, false_negatives):
    return Bunch(
        true_positives=true_positives,
        false_positives=false_positives,
        true_negatives=true_negatives,
        false_negatives=false_negatives,
        predicted_positives=(true_positives + false_positives),
        predicted_negatives=(true_negatives + false_negatives),
        positives=(true_positives + false_negatives),
        negatives=(true_negatives + false_positives),
        n=(true_positives + true_negatives + false_positives + false_negatives),
    )


METRIC_DICT = {
    "selection_rate": lambda x: x.predicted_positives / x.n if x.n > 0 else 0,
    
    "false_positive_rate": lambda x: (
        x.false_positives / x.negatives if x.negatives > 0 else 0
    ),
    
    "false_negative_rate": lambda x: (
        x.false_negatives / x.positives if x.positives > 0 else 0
    ),
    
    "true_positive_rate": lambda x: (
        x.true_positives / x.positives if x.positives > 0 else 0
    ),
    
    "true_negative_rate": lambda x: (
        x.true_negatives / x.negatives if x.negatives > 0 else 0
    ),
    
    "accuracy_score": lambda x: (
        (x.true_positives + x.true_negatives) / x.n if x.n > 0 else 0
    ),
    
    "balanced_accuracy_score": lambda x: (
        0.5 * (x.true_positives / x.positives if x.positives > 0 else 0) +
        0.5 * (x.true_negatives / x.negatives if x.negatives > 0 else 0)
    ),
    
    "negative_predictive_value": lambda x: (
        x.true_negatives / (x.true_negatives + x.false_negatives) 
        if (x.true_negatives + x.false_negatives) > 0 else 0
    ),
    
    "precision": lambda x: (
        x.true_positives / (x.true_positives + x.false_positives)
        if (x.true_positives + x.false_positives) > 0 else 0
    ),
    
    "recall": lambda x: (
        x.true_positives / (x.true_positives + x.false_negatives)
        if (x.true_positives + x.false_negatives) > 0 else 0
    ),
    
    "f1_score": lambda x: (
        2 * (METRIC_DICT["precision"](x) * METRIC_DICT["recall"](x)) /
        (METRIC_DICT["precision"](x) + METRIC_DICT["recall"](x))
        if (METRIC_DICT["precision"](x) + METRIC_DICT["recall"](x)) > 0 else 0
    )
}



In [18]:
def plot_ROC(data: dict, calculate_ROC_points: Callable[[dict], list]):
    plt.figure(figsize=(8, 6))
    for group_id, group_data in data.groupby('group'):
        x, y, operation_list, objective_list, actual_counts_list, metrics_list = calculate_ROC_points(group_data)
        plt.plot(x, y, marker='o', linestyle='-', label=f'Group {group_id}')
    
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
    plt.xlabel('ADJ False Positive Rate')
    plt.ylabel('ADJ True Positive Rate')
    plt.title('ROC Curve by Group')
    plt.legend()
    plt.grid(True)
    plt.show()
    return operation_list, objective_list, actual_counts_list, metrics_list


def  plot_ROC_plotly(data: dict, calculate_ROC_points: Callable[[dict], list]):

    fig = go.Figure()
    for group_id, group_data in data.groupby('group'):
        x, y, operation_list, objective_list, actual_counts_list, metrics_list = calculate_ROC_points(
            group_data, 
            x_metric="false_positive_rate", 
            y_metric="true_positive_rate", 
            obj_metric="balanced_accuracy_score"
        )
        fig.add_trace(go.Scatter(x=x, y=y, mode='lines+markers', name=f'Group {group_id}'))
    
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random', line=dict(dash='dash')))
    fig.update_layout(title='ROC Curve by Group', xaxis_title='ADJ False Positive Rate', yaxis_title='ADJ True Positive Rate')
    fig.show()
    return operation_list, objective_list, actual_counts_list, metrics_list


In [19]:
def create_interpolation_dict(thresholds_dict):
    interpolation_dict = {}
    for group, threshold in thresholds_dict.items():
        op = ThresholdOperation(">", threshold)
        # read from the repo 
        interpolation_dict[group] = Bunch(p0=1.0, operation0=op, p1=0.0, operation1=op)
    return interpolation_dict


def print_metrics(y_test, preds, group_test):
    print('Balanced Accuracy: ', balanced_accuracy_score(y_test, preds))
    print('ROC AUC: ', roc_auc_score(y_test, preds))
    print('Accuracy: ', accuracy_score(y_test, preds))
    print('F1 Score: ', f1_score(y_test, preds))
    print('Equalized Odds Ratio: ', equalized_odds_ratio(y_test, preds, sensitive_features=group_test))
    print()


def _get_counts(labels):
    """Return the overall, positive, and negative counts of the labels.

    :param labels: the labels of the samples
    :type labels: list
    :return: a tuple containing the overall, positive, and negative counts of the labels
    :rtype: tuple of int, int, int
    """
    n = len(labels)
    n_positive = sum(labels)
    n_negative = n - n_positive
    return n, n_positive, n_negative


In [20]:
def calculate_ROC_points(data, 
                         x_metric="false_positive_rate", 
                         y_metric="true_positive_rate", 
                         obj_metric="balanced_accuracy_score"):
    scores = data['score'].tolist()
    labels = data['label'].tolist()

    data_sorted = data.sort_values(by=SCORE_KEY, ascending=False)

    scores = list(data_sorted[SCORE_KEY])
    labels = list(data_sorted[LABEL_KEY])

    n, n_positive, n_negative = _get_counts(labels)
    
    scores.append(-np.inf)
    labels.append(np.nan)
    
    thresholds = np.unique(scores)
    thresholds = np.append(thresholds, thresholds[-1] + 1) 
    
    i = 0
    # count[0] -> false_positives 
    # count[1] -> true_positives 
    count = [0, 0]

    x_list, y_list, operation_list, objective_list, actual_counts_list, metrics_list = [], [], [], [], [], []
    while i < n:
        # special handling of the initial point
        if x_list == []:
            threshold = np.inf
        else:
            threshold = scores[i]
            
            # calculate tp and fp for every threshold value
            while scores[i] == threshold:
                count[labels[i]] += 1
                i += 1
            threshold = (threshold + scores[i]) / 2

        # calculate the rest of the metrics and store it
        actual_counts = _extend_confusion_matrix(
            false_positives=count[0],
            true_positives=count[1],
            true_negatives=(n_negative - count[0]),
            false_negatives=(n_positive - count[1]),
        )
        actual_counts_list.append(actual_counts)

        operations = [(">", actual_counts)]

        '''
        x = fpr , y = tpr , obj = balanced accuracy score
        operation: a 'function' which compares given argument with threshold using operation_string
        metrics: a dictionary of answers for all the metrics in METRIC_DICT
        ''' 
        for operation_string, counts in operations:
            if isinstance(x_metric, str) and isinstance(y_metric, str) and isinstance(obj_metric, str):
                x = METRIC_DICT[x_metric](counts)
                y = METRIC_DICT[y_metric](counts)
                obj = METRIC_DICT[obj_metric](counts)
                metrics = {metric: METRIC_DICT[metric](counts) for metric in METRIC_DICT}

                operation = ThresholdOperation(operation_string, threshold)

                x_list.append(x)
                y_list.append(y)
                operation_list.append(operation)
                objective_list.append(obj)
                metrics_list.append(metrics)
            else:
                raise ValueError("Metrics must be specified as strings corresponding to keys in METRIC_DICT.")
    
    return x_list, y_list, operation_list, objective_list, actual_counts_list, metrics_list


In [21]:

class ApproxThresholdBrute():
    def __init__(self, 
                y_score,
                y_true,
                group_assignments, 
                METRIC_DICT, 
                lambda_=0.5,
                global_metric="f1_score", 
                max_epsilon=1.0,
                subsample=10000):
        
        self.METRIC_DICT = METRIC_DICT 
        self.y_score = y_score
        self.y_true = y_true
        self.group_assignments = group_assignments

        self.lambda_ = lambda_
        self.thresholds_ = None
        self.epsilons_ = None
        self.global_metric = global_metric
        self.max_epsilon = max_epsilon
        self.best_objective_value = None

        self.data = pd.DataFrame({
            'score': y_score,
            'label': y_true,
            'group': group_assignments
        })

        if len(self.data) > 10000 and subsample:
            # raise warning brute force is slow
            warnings.warn("Brute force is slow for large datasets."+\
                        " Subsampling to 10000 samples. \nYou "+\
                        "can disable this by setting subsample=False.")
            self.data = self.data.sample(10000)

        unique_groups = self.data['group'].unique()

        self.metrics_per_group_per_threshold = {}
        all_threshold_lists = []
    
        for g in tqdm(unique_groups):
            group_data = self.data[self.data['group'] == g]
            
            _, _, operation_list, _ , _ , metrics_list = calculate_ROC_points(group_data)

            for all_metrics, operation in zip(metrics_list, operation_list):
                if g not in self.metrics_per_group_per_threshold:
                    self.metrics_per_group_per_threshold[g] = {}

                # Debugger 
                if not all_metrics:  
                    raise ValueError(f"Invalid metrics computed for {g} with thresh {operation.threshold}")
                
                self.metrics_per_group_per_threshold[g][operation.threshold] = all_metrics
                                
            
            all_threshold_lists.append(operation_list)
        
        
        possible_thresholds_combos = self.product_approach(all_threshold_lists)
 
        
        global_metrics_per_thresh_combo = {}
        metric_vectors_per_thresh_combo = {}
        group_size_vector = [len(self.data[self.data['group'] == g]) for g in unique_groups]
        
        for threshold_combo in possible_thresholds_combos:   
            metrics_per_group = {}
            global_metric_per_group = {}

            for i, threshold in enumerate(threshold_combo):
                '''
                for given threshold and its index, find the respective metric values from metrics_per_group_per_threshold
                and assign it to global_metric_per_group

                assign global_metric_per_group to global_metrics_per_thresh_combo[tuple(threshold_combo)] 
                '''

                thresh = float(threshold.threshold)
                group = unique_groups[i]

                if group not in self.metrics_per_group_per_threshold or thresh not in self.metrics_per_group_per_threshold[group]:
                    raise ValueError(f"Threshold {thresh} not found for group {group}")
        
                metrics_per_group[group] = self.metrics_per_group_per_threshold[group][thresh]
                global_metric_per_group[group] = self.metrics_per_group_per_threshold[group][thresh][self.global_metric]

            global_metrics_per_thresh_combo[tuple(threshold_combo)] = global_metric_per_group
            metric_vectors_per_thresh_combo[tuple(threshold_combo)] = metrics_per_group

       
        def get_metric_vectors_for_each_group(threshold_combo, metric_vectors_per_thresh_combo):
            metric_vectors_for_each_group = []

            for metric_dict in metric_vectors_per_thresh_combo[tuple(threshold_combo)].values():
    
                vector = np.array([float(val) for val in metric_dict.values()])
                metric_vectors_for_each_group.append(vector)

            return metric_vectors_for_each_group

        
        all_objective_values = []
        for threshold_combo in possible_thresholds_combos:
            objective_value, thresholds_used, epsilon_diffs = self.objective_func(
                global_metrics_per_thresh_combo[tuple(threshold_combo)],
                group_size_vector,
                get_metric_vectors_for_each_group(threshold_combo,metric_vectors_per_thresh_combo),
                threshold_combo,
                lambda_val=self.lambda_,
                diff_strategy='euclidean'
            )
            all_objective_values.append((objective_value, thresholds_used, epsilon_diffs))


        sorted_objective_values = sorted(all_objective_values, key=lambda x: x[0])
        
        best_objective_value, best_threshold_combo, best_epsilon_diff = sorted_objective_values[0]

        self.best_objective_value = best_objective_value
        self.best_threshold_combo = best_threshold_combo
        self.best_epsilon_diff = best_epsilon_diff

        print("Best Objective Value:", best_objective_value)
        print("Best Threshold Combination:", best_threshold_combo)
        # print("Best Epsilon Difference:", best_epsilon_diff)

    @staticmethod
    def objective_func(global_metric_for_each_group,
                    group_size_vector,
                    metric_vectors_for_each_group, 
                    threshold_combo,
                    lambda_val=0.5,
                    diff_strategy='euclidean'):

        group_metrics = np.array([global_metric_for_each_group[group] for group in global_metric_for_each_group.keys()])
        
        
        global_performance_weighted = np.average(group_metrics, weights=group_size_vector)

        objective = 0
        epsilon_diffs = {}

        # Check and adjust dimensions of vectors
        for i, group_vector in enumerate(metric_vectors_for_each_group):
            for j, other_group_vector in enumerate(metric_vectors_for_each_group):
                if i != j:
                    # Shapes match?
                    if group_vector.shape != other_group_vector.shape:
                        raise ValueError(f"Dimension mismatch: Group {i} ({group_vector.shape}) vs Group {j} ({other_group_vector.shape})")

                    if diff_strategy == 'euclidean':
                        distances = cdist(
                            np.array(group_vector).reshape(1, -1), 
                            np.array(other_group_vector).reshape(1, -1), 
                            metric='euclidean'
                        )
                    elif diff_strategy == 'cosine':
                        distances = np.dot(group_vector, other_group_vector) / (
                            np.linalg.norm(group_vector) * np.linalg.norm(other_group_vector))
                    else:
                        raise ValueError(f"Invalid diff strategy: {diff_strategy}")
                    
                    objective += np.sum(distances)
                    epsilon_diff = np.abs(group_vector - other_group_vector)
                    epsilon_diffs[(i, j)] = epsilon_diff

        
        n_unique_groups = len(metric_vectors_for_each_group)
        A_choose_2 = n_unique_groups * (n_unique_groups - 1) / 2

        
        objective = (1 - lambda_val) * (objective / A_choose_2) + lambda_val * (1 - global_performance_weighted)
        return objective, threshold_combo, epsilon_diffs

    @staticmethod
    def product_approach(all_lists, use_numpy=True):
        def numpy_cartesian_product(lists):
            arrays = [np.array(lst) for lst in lists]
            grid = np.meshgrid(*arrays, indexing='ij')
            return np.stack(grid, axis=-1).reshape(-1, len(arrays))

        if use_numpy:
            return numpy_cartesian_product(all_lists)
        else:
            return list(product(*all_lists))

    @staticmethod
    def _get_counts(labels):
        """Return the overall, positive, and negative counts of the labels.

        :param labels: the labels of the samples
        :type labels: list
        :return: a tuple containing the overall, positive, and negative counts of the labels
        :rtype: tuple of int, int, int
        """
        n = len(labels)
        n_positive = sum(labels)
        n_negative = n - n_positive
        return n, n_positive, n_negative

    @staticmethod
    def _extend_confusion_matrix( 
            *, true_positives, false_positives, true_negatives, false_negatives
        ):
            return Bunch(
                true_positives=true_positives,
                false_positives=false_positives,
                true_negatives=true_negatives,
                false_negatives=false_negatives,
                predicted_positives=(true_positives + false_positives),
                predicted_negatives=(true_negatives + false_negatives),
                positives=(true_positives + false_negatives),
                negatives=(true_negatives + false_positives),
                n=(true_positives + true_negatives + false_positives + false_negatives),
            )

    # haven't used this function 
    @staticmethod
    def calculate_ROC_points(data, METRIC_DICT):
        data = data.copy()
        data_sorted = data.sort_values(by=SCORE_KEY, ascending=False)

        scores = list(data_sorted['score'])
        labels = list(data_sorted['label'])

        n, n_positive, n_negative = ApproxThresholdBrute._get_counts(labels)
        
        # Use actual score values for thresholds
        thresholds = np.unique(scores)
        operation_list, actual_counts_list, metrics_list = [], [], []

        for threshold in thresholds:
            predictions = (scores > threshold).astype(int)
            tp = np.sum((predictions == 1) & (labels == 1))
            fp = np.sum((predictions == 1) & (labels == 0))
            tn = np.sum((predictions == 0) & (labels == 0))
            fn = np.sum((predictions == 0) & (labels == 1))

            actual_counts = ApproxThresholdBrute._extend_confusion_matrix(
                false_positives=fp,
                true_positives=tp,
                true_negatives=tn,
                false_negatives=fn,
            )
            
            metrics = {metric: METRIC_DICT[metric](actual_counts) for metric in METRIC_DICT}
            operation = ThresholdOperation(">", float(threshold))
            
            operation_list.append(operation)
            actual_counts_list.append(actual_counts)
            metrics_list.append(metrics)

        return operation_list, actual_counts_list, metrics_list

    @staticmethod
    def generate_batch_combinations(possible_thresholds, start, end):
        """Generate combinations for a batch of thresholds"""
        first_group_thresholds = possible_thresholds[0][start:end]
        other_groups_thresholds = possible_thresholds[1:]
        
        for combo in itertools.product(first_group_thresholds, *other_groups_thresholds):
            yield combo

    def f1_score(self, metric):
        """
        metric: a dictionary of metrics for each group

        """
        # Extract the dictionary from the NumPy array
        return (metric.item()['f1_score'])
       


In [26]:
def run_test():
    X = pd.DataFrame([[1], [1], [0], [0], [1], [1], [0], [0]], columns=["feature"])
    Y = pd.DataFrame([0, 0, 0, 0, 1, 1, 1, 1], columns=["label"])  
    A = pd.DataFrame([1, 1, 0, 0, 0, 0, 1, 1], columns=["group"]) 

    y = Y.values.ravel()
    X_train, X_test, y_train, y_test, A_train, A_test = train_test_split(
        X, y, A, test_size=0.2, random_state=42
    )

    dummy_clf = DummyClassifier(strategy="stratified", random_state=42)
    dummy_clf.fit(X_train, y_train)

    preds_proba = dummy_clf.predict_proba(X_train)[:, 1]
    y_pred = dummy_clf.predict(X_test)

    def compute_fpr_tpr_objective(y_true, y_pred, sensitive_features):
        unique_groups = np.unique(sensitive_features)
        tpr_per_group = {}
        fpr_per_group = {}

        # Compute TPR and FPR for each group
        for group in unique_groups:
            indices = (sensitive_features == group)
            y_true_group = y_true[indices]
            y_pred_group = y_pred[indices]

            tn, fp, fn, tp = confusion_matrix(y_true_group, y_pred_group).ravel()
            tpr_per_group[group] = tp / (tp + fn) if (tp + fn) > 0 else 0
            fpr_per_group[group] = fp / (fp + tn) if (fp + tn) > 0 else 0

        
        tpr_diff = max(tpr_per_group.values()) - min(tpr_per_group.values())
        fpr_diff = max(fpr_per_group.values()) - min(fpr_per_group.values())
        objective_value = max(tpr_diff, fpr_diff)

        return tpr_per_group, fpr_per_group, objective_value

    tpr_per_group, fpr_per_group, dummy_objective_value = compute_fpr_tpr_objective(
        y_train, dummy_clf.predict(X_train), A_train.values.ravel()
    )

    print("Dummy Classifier Results:")
    print(f"TPR per group: {tpr_per_group}")
    print(f"FPR per group: {fpr_per_group}")
    print(f"Objective Value: {dummy_objective_value}")


    data = pd.DataFrame({
        "score": preds_proba,    
        "label": y_train,        
        "group": A_train.values.ravel()  
    })

    at = ApproxThresholdBrute(
        y_score=preds_proba,
        y_true=y_train,
        group_assignments=A_train.values.ravel(),
        METRIC_DICT=METRIC_DICT,
        lambda_=0.5
    )

    best_objective_value = at.best_objective_value
    print("ApproxThresholdBrute Results:")
    print(f"Best Objective Value: {best_objective_value}")

   

In [27]:
run_test()

Dummy Classifier Results:
TPR per group: {0: 0.0, 1: 0.5}
FPR per group: {0: 0.5, 1: 1.0}
Objective Value: 0.5


  0%|          | 0/2 [00:00<?, ?it/s]

Best Objective Value: 0.7337684871413404
Best Threshold Combination: [[>-inf] [>-inf]]
ApproxThresholdBrute Results:
Best Objective Value: 0.7337684871413404


In [25]:
X, Y, A = load_synthetic_data()

random_indices = np.random.choice(X.index, size=100, replace=False)
X = X.loc[random_indices] 
Y = Y.loc[random_indices] 
A = A.loc[random_indices]

y = Y.values

X_train, X_test, y_train, y_test, A_train, A_test = train_test_split( X, y, A, test_size=0.2, random_state=42)
model = xgb.XGBClassifier()

model.fit(X_train, y_train)
preds_proba = model.predict_proba(X_train)[:, 1]

data = pd.DataFrame({
        'score': preds_proba,
        'label': y_train,
        'group': A_train.values 
        })

at = ApproxThresholdBrute(y_score=preds_proba, y_true=y_train, group_assignments=A_train.values, METRIC_DICT=METRIC_DICT)

# operation_list, objective_list, actual_counts_list, metrics_list = plot_ROC_plotly(data, calculate_ROC_points)

print_metrics(y_test, model.predict(X_test), A_test)


  0%|          | 0/2 [00:00<?, ?it/s]

Best Objective Value: 0.09595959595959597
Best Threshold Combination: [[>0.46199627220630646] [>0.5055713281035423]]
Balanced Accuracy:  0.7708333333333333
ROC AUC:  0.7708333333333333
Accuracy:  0.75
F1 Score:  0.761904761904762
Equalized Odds Ratio:  0.0

