<a href="https://colab.research.google.com/github/mmfara/Disparate-Impact-Remover-multiclass-repair/blob/main/Disparate_Impact_Remover__Modified_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import logging
from aif360.algorithms import Transformer
from itertools import product


class DisparateImpactRemover(Transformer):
    """
    Disparate Impact Remover that supports intersectional protected attributes.

    Each distinct combination of the selected protected-attribute values that
    occurs in the data forms one intersectional group. All eligible groups are
    repaired *jointly* in a single BlackBoxAuditing ``Repairer`` pass keyed on
    a synthetic group-id column, so each feature's per-group distribution is
    moved toward a common target distribution.

    Note: repairing each group in isolation (after filtering it down to a
    single protected value) cannot work -- the ``Repairer`` aligns
    distributions *across* the values of its keyed column, and a pre-filtered
    group has only one value to align with.
    """

    def __init__(self, repair_level=1.0, sensitive_attribute=None, min_group_size=20, verbose=True):
        """
        Args:
            repair_level (float): Degree of repair (0.0 = none, 1.0 = full).
            sensitive_attribute (str, list or tuple): One or more protected
                attribute names. If None, the dataset's first protected
                attribute is used when ``fit_transform`` is called.
            min_group_size (int): Intersectional groups with fewer rows than
                this are excluded from the repair and keep their original
                feature values.
            verbose (bool): Enables INFO logging and per-group messages if True.

        Raises:
            ValueError: If ``repair_level`` is not within [0.0, 1.0].
            TypeError: If ``sensitive_attribute`` is not a str, list, tuple or None.
        """
        super().__init__()
        # Imported here rather than at module level, so BlackBoxAuditing is
        # only needed once the transformer is instantiated.
        from BlackBoxAuditing.repairers.GeneralRepairer import Repairer
        self.Repairer = Repairer

        if not 0.0 <= repair_level <= 1.0:
            raise ValueError("'repair_level' must be between 0.0 and 1.0.")

        if isinstance(sensitive_attribute, str):
            self.sensitive_attributes = [sensitive_attribute]
        elif isinstance(sensitive_attribute, (list, tuple)):
            # Copy so later mutation of the caller's sequence cannot affect us.
            self.sensitive_attributes = list(sensitive_attribute)
        elif sensitive_attribute is None:
            self.sensitive_attributes = []  # Will default to first from dataset
        else:
            raise TypeError("sensitive_attribute must be str, list, tuple, or None")

        self.repair_level = repair_level
        self.min_group_size = min_group_size
        self.verbose = verbose

        if self.verbose:
            logging.basicConfig(level=logging.INFO)

    def fit_transform(self, dataset):
        """Jointly repair ``dataset`` across its intersectional protected groups.

        Args:
            dataset: An aif360 dataset; this method reads ``features``,
                ``feature_names`` and ``protected_attribute_names`` and
                calls ``copy()``.

        Returns:
            A copy of ``dataset`` whose features are repaired for every row
            belonging to an eligible group (size >= ``min_group_size``). Rows
            in smaller groups, and all selected protected-attribute columns,
            keep their original values. If the repair itself raises, the
            error is logged and the original features are returned.

        Raises:
            ValueError: If no sensitive attribute can be resolved, or one of
                them is not among ``dataset.feature_names``.
        """
        # Default to first protected attribute if none provided
        if not self.sensitive_attributes:
            self.sensitive_attributes = list(dataset.protected_attribute_names[:1])
        if not self.sensitive_attributes:
            raise ValueError(
                "No sensitive attribute given and the dataset declares no protected attributes."
            )

        unknown = [attr for attr in self.sensitive_attributes if attr not in dataset.feature_names]
        if unknown:
            raise ValueError(f"Sensitive attributes not found in dataset features: {unknown}")
        indices = [dataset.feature_names.index(attr) for attr in self.sensitive_attributes]

        repaired = dataset.copy()
        features = np.array(dataset.features, dtype=np.float64)
        if features.shape[0] == 0:
            return repaired
        new_features = features.copy()

        # Label every row with the id of its protected-value combination.
        # Only combinations that actually occur in the data are enumerated.
        group_keys, group_ids, group_sizes = np.unique(
            features[:, indices], axis=0, return_inverse=True, return_counts=True
        )
        # Some NumPy 2.x releases return a 2-D inverse when axis is given.
        group_ids = np.ravel(group_ids)
        eligible = group_sizes >= self.min_group_size

        if self.verbose:
            for key, size in zip(group_keys[~eligible], group_sizes[~eligible]):
                logging.warning(f"Skipping group {tuple(key.tolist())} (size={size})")

        if int(eligible.sum()) < 2:
            # The repair aligns distributions *across* groups; with fewer than
            # two eligible groups there is nothing to align against.
            if self.verbose:
                logging.warning(
                    "Fewer than two groups meet min_group_size; features left unrepaired."
                )
        else:
            rows = eligible[group_ids]
            # Append the group id as an extra column and key the repair on it,
            # so all intersectional groups are repaired against each other.
            key_col = features.shape[1]
            augmented = np.column_stack(
                [features[rows], group_ids[rows].astype(np.float64)]
            )
            data = augmented.tolist()
            try:
                repairer = self.Repairer(data, key_col, self.repair_level, False)
                repaired_rows = np.asarray(repairer.repair(data), dtype=np.float64)
            except Exception:
                # Best-effort: keep the original features but preserve the traceback.
                logging.exception("Intersectional repair failed; returning original features")
            else:
                # Drop the synthetic group-id column before writing back.
                new_features[rows] = repaired_rows[:, :key_col]
                if self.verbose:
                    for key, size in zip(group_keys[eligible], group_sizes[eligible]):
                        logging.info(f"Repaired group {tuple(key.tolist())} (size={size})")

        # Restore protected attribute columns to ensure they remain unchanged
        new_features[:, indices] = features[:, indices]

        repaired.features = new_features
        return repaired

In [None]:
# Repair the combined gender x race groups with repair_level 0.8; any
# intersectional group with fewer than 30 rows is left as-is.
# NOTE(review): `dataset` is not defined in the visible cells -- presumably an
# aif360 dataset built in an earlier cell; confirm before running.
remover = DisparateImpactRemover(
    sensitive_attribute=['gender', 'race'],
    repair_level=0.8,
    verbose=True,
    min_group_size=30,
)

repaired_data = remover.fit_transform(dataset)