<a href="https://colab.research.google.com/github/mmfara/Disparate-Impact-Remover-multiclass-repair/blob/main/Disparate_Impact_Remover_Modified_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from aif360.algorithms import Transformer

class DisparateImpactRemover(Transformer):
    """Disparate impact remover for multi-valued protected attributes.

    Rows are partitioned by the value found in the sensitive-attribute column
    and a separate BlackBoxAuditing ``Repairer`` is built for, and applied to,
    each partition. The sensitive-attribute column itself is copied through
    unchanged.

    NOTE(review): inside a partition the sensitive column holds a single
    value, so every per-group ``Repairer`` is handed exactly one stratum.
    Confirm against the BlackBoxAuditing documentation that this yields the
    intended repair.
    """

    def __init__(self, repair_level=1.0, sensitive_attribute=''):
        """Store the repair configuration and load the repairer class.

        Args:
            repair_level (float): Repair amount, within ``[0.0, 1.0]``.
            sensitive_attribute (str or list or tuple): Name of the feature
                to stratify on. When a sequence is given only its first
                element is used; an empty string or empty sequence defers
                the choice to ``fit_transform``, which falls back to the
                dataset's first protected attribute name.

        Raises:
            ValueError: If ``repair_level`` lies outside ``[0.0, 1.0]``.
        """
        super().__init__()

        # Check the cheap argument first so a bad value is reported as a
        # ValueError even when the optional dependency below is missing.
        if not 0.0 <= repair_level <= 1.0:
            raise ValueError("'repair_level' must be between 0.0 and 1.0.")

        # Imported lazily so the module can be imported without
        # BlackBoxAuditing installed.
        from BlackBoxAuditing.repairers.GeneralRepairer import Repairer
        self.Repairer = Repairer

        self.repair_level = repair_level

        # Only one sensitive attribute is supported; an empty sequence is
        # treated as "not specified" instead of raising IndexError.
        if isinstance(sensitive_attribute, (list, tuple)):
            self.sensitive_attribute = (
                sensitive_attribute[0] if sensitive_attribute else ''
            )
        else:
            self.sensitive_attribute = sensitive_attribute

    def fit_transform(self, dataset):
        """Return a copy of ``dataset`` with its features repaired per group.

        Args:
            dataset: Object exposing ``features`` (2-D array indexed as
                ``features[:, i]``), ``feature_names``,
                ``protected_attribute_names`` and ``copy()``.
                Presumably an AIF360 dataset -- TODO confirm.

        Returns:
            The copy of ``dataset`` whose ``features`` hold the repaired
            values as ``float64``; the sensitive-attribute column equals the
            original one.

        Raises:
            ValueError: If the sensitive attribute is not one of
                ``dataset.feature_names``.
        """
        if not self.sensitive_attribute:
            self.sensitive_attribute = dataset.protected_attribute_names[0]

        try:
            index = dataset.feature_names.index(self.sensitive_attribute)
        except ValueError as err:
            raise ValueError(
                f"Sensitive attribute {self.sensitive_attribute!r} is not in "
                "'dataset.feature_names'."
            ) from err

        protected_column = dataset.features[:, index]

        repaired = dataset.copy()
        # Explicit float64 buffer: inheriting an integer dtype from the input
        # would silently truncate the repaired values on assignment.
        new_features = np.zeros(dataset.features.shape, dtype=np.float64)

        for val in np.unique(protected_column):
            group_mask = protected_column == val
            # Convert once; the same list is used to build the repairer and
            # as the data to repair.
            group_rows = dataset.features[group_mask].tolist()

            repairer = self.Repairer(
                group_rows, index, self.repair_level, False
            )
            new_features[group_mask] = np.array(
                repairer.repair(group_rows), dtype=np.float64
            )

        # Overwrite features in the repaired dataset.
        repaired.features = new_features

        # The protected attribute column must remain unchanged.
        repaired.features[:, index] = protected_column

        return repaired


# Instantiate the remover with repair_level=0.8, using the "group" feature as
# the sensitive attribute.
# NOTE(review): the name `dir` shadows the Python builtin of the same name; it
# is kept because code outside this view may refer to it.
dir = DisparateImpactRemover(
    sensitive_attribute="group",
    repair_level=0.8,
)