In [13]:
import sys
sys.path.append('../../../')

## Template - Bias Mitigation Benchmark ([Holistic AI](https://research.holisticai.com))

**Task:** Clustering

**Type:** Preprocessing


This notebook is a template for the Bias Mitigation Benchmark. It can be used to mitigate bias in datasets and models. The notebook is based on the [Holistic AI open source library](https://github.com/holistic-ai/holisticai) and follows the bias mitigation benchmark outlined in [Holistic AI](https://research.holisticai.com).

### Template Structure

The template has the following steps:

1. Setup definition: 
    - select a task: `binary_classification`, `multiclass_classification`, `regression`, `clustering`, `recommender`
    - select a type: `inprocessing`, `preprocessing`, `postprocessing`
2. Mitigator class
    - create a class for your custom mitigator
3. Evaluation
    - evaluate your mitigator and compare it with other mitigators
4. Submission
    - do you have good results? Then submit your mitigator to the Bias Mitigation Benchmark


### Step 1: Setup Definition

In [14]:
from holisticai.benchmark.tasks import task_name, get_task

print(task_name)

['binary_classification', 'multiclass_classification', 'regression', 'clustering', 'recommender']


In [15]:
# load a task
task = get_task("clustering")

In [16]:
# benchmark for the task by type
task.benchmark(type='preprocessing')

| Mitigator | Average Cluster Balance | heart |
| --- | --- | --- |
| FairletClusteringPreprocessing | 0.948859 | 0.948859 |


### Step 2: Mitigator Class

In [17]:
from typing import Optional, Union

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.metrics.pairwise import pairwise_distances_argmin

from holisticai.bias.mitigation.commons.fairlet_clustering.decompositions import (
    DecompositionMixin,
    ScalableFairletDecomposition,
    VanillaFairletDecomposition,
)
from holisticai.utils.models.cluster import KCenters, KMedoids
from holisticai.utils.transformers.bias import BMPreprocessing as BMPre

DECOMPOSITION_CATALOG = {
    "Scalable": ScalableFairletDecomposition,
    "Vanilla": VanillaFairletDecomposition,
}
CLUSTERING_CATALOG = {"KCenter": KCenters, "KMedoids": KMedoids}


class MyPreprocessingMitigator(BaseEstimator, BMPre):

    def __init__(
        self,
        decomposition: str = "Vanilla",
        p: Optional[int] = 1,
        q: Optional[int] = 3,
        seed: Optional[int] = None,
    ):
        """
        Parameters
        ----------
            decomposition : str
                Fairlet decomposition strategy, a key of DECOMPOSITION_CATALOG:
                "Vanilla" or "Scalable"

            p : int
                fairlet decomposition parameter for Vanilla and Scalable strategy

            q : int
                fairlet decomposition parameter for Vanilla and Scalable strategy

            seed : int
                Random seed.
        """
        self.decomposition = DECOMPOSITION_CATALOG[decomposition](p=p, q=q)
        self.p = p
        self.q = q
        self.seed = seed

    def fit_transform(
        self,
        X: np.ndarray,
        group_a: np.ndarray,
        group_b: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ):
        params = self._load_data(
            X=X, sample_weight=sample_weight, group_a=group_a, group_b=group_b
        )
        X = params["X"]
        sample_weight = params["sample_weight"]
        group_a = params["group_a"].astype("int32")
        group_b = params["group_b"].astype("int32")
        np.random.seed(self.seed)
        fairlets, fairlet_centers, fairlet_costs = self.decomposition.fit_transform(
            X, group_a, group_b
        )
        Xt = np.zeros_like(X)
        mapping = np.zeros(len(X), dtype="int32")
        centers = np.array([X[fairlet_center] for fairlet_center in fairlet_centers])
        for i, fairlet in enumerate(fairlets):
            Xt[fairlet] = X[fairlet_centers[i]]
            mapping[fairlet] = i
            sample_weight[fairlet] = len(fairlet) / len(X)

        self.update_estimator_param("sample_weight", sample_weight)
        self.sample_weight = sample_weight
        self.X = X
        self.mapping = mapping
        self.centers = centers
        return Xt

    def transform(self, X):
        fairlets_midxs = pairwise_distances_argmin(X, Y=self.X)
        return self.centers[self.mapping[fairlets_midxs]]
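
The core of `fit_transform` and `transform` above can be followed on toy data without running a decomposition. The sketch below is illustrative only: the `fairlets` and `fairlet_centers` values are made-up stand-ins for what `self.decomposition.fit_transform` returns, and the nearest-point lookup is written in plain NumPy in place of `pairwise_distances_argmin` (Euclidean distance is assumed).

```python
import numpy as np

# Toy data: six 2-D points forming two visually separate groups.
X = np.array([
    [0.0, 0.0], [0.2, 0.1], [0.1, 0.3],
    [5.0, 5.0], [5.2, 4.9], [4.8, 5.1],
])

# Hypothetical decomposition output (assumed for illustration):
# each fairlet is a list of row indices, each center is a row index of X.
fairlets = [[0, 1, 2], [3, 4, 5]]
fairlet_centers = [0, 3]

# Same bookkeeping as in fit_transform: every point is replaced by its
# fairlet center and weighted by the relative size of its fairlet.
Xt = np.zeros_like(X)
mapping = np.zeros(len(X), dtype="int32")
sample_weight = np.ones(len(X))
for i, fairlet in enumerate(fairlets):
    Xt[fairlet] = X[fairlet_centers[i]]
    mapping[fairlet] = i
    sample_weight[fairlet] = len(fairlet) / len(X)

centers = X[fairlet_centers]


def transform(X_new):
    """Map each new point to the center of its nearest training point's fairlet."""
    # Pairwise Euclidean distances, shape (n_new, n_train).
    distances = np.linalg.norm(X_new[:, None, :] - X[None, :, :], axis=2)
    nearest = distances.argmin(axis=1)
    return centers[mapping[nearest]]


X_new = np.array([[0.1, 0.1], [4.9, 5.0]])
Xt_new = transform(X_new)
```

After this runs, every row of `Xt` equals the center of its fairlet, and each new point in `X_new` is mapped to the center of the fairlet that contains its nearest training point.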


### Step 3: Evaluation

In [18]:
my_mitigator = MyPreprocessingMitigator()

task.run_benchmark(mitigator = my_mitigator, type = 'preprocessing')

Clustering Benchmark initialized for MyPreprocessingMitigator


100%|██████████| 1/1 [00:00<00:00,  2.90it/s]


In [19]:
task.evaluate_table()

| Mitigator | Average Cluster Balance | heart |
| --- | --- | --- |
| MyPreprocessingMitigator | 0.95398 | 0.95398 |
| FairletClusteringPreprocessing | 0.948859 | 0.948859 |
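
To build intuition for the scores in this table, here is a minimal sketch of a cluster balance computation. It assumes the definition commonly used in the fairlet clustering literature (per cluster, the smaller of the two group-count ratios; overall, the minimum across clusters). The function name `cluster_balance_sketch` is hypothetical, and the metric the benchmark actually reports may be defined differently.

```python
import numpy as np


def cluster_balance_sketch(labels, group_a, group_b):
    """Minimum over clusters of min(n_a / n_b, n_b / n_a).

    Assumed definition for illustration; a cluster missing either
    group gets a balance of 0.
    """
    labels = np.asarray(labels)
    group_a = np.asarray(group_a, dtype=bool)
    group_b = np.asarray(group_b, dtype=bool)
    balances = []
    for cluster in np.unique(labels):
        in_cluster = labels == cluster
        n_a = int(np.sum(in_cluster & group_a))
        n_b = int(np.sum(in_cluster & group_b))
        if n_a == 0 or n_b == 0:
            balances.append(0.0)
        else:
            balances.append(min(n_a / n_b, n_b / n_a))
    return min(balances)


# Toy assignment: cluster 0 holds 2 + 2 members, cluster 1 holds 1 + 2.
labels = [0, 0, 0, 0, 1, 1, 1]
group_a = [1, 1, 0, 0, 1, 0, 0]
group_b = [0, 0, 1, 1, 0, 1, 1]
balance = cluster_balance_sketch(labels, group_a, group_b)
```

Under this definition a value of 1 means both groups are equally represented in every cluster, and lower values mean at least one cluster is skewed toward one group.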


### Step 4: Submission

In [20]:
task.submit()

Opening the link in your browser:
