<a href="https://colab.research.google.com/github/mmfara/Disparate-Impact-Remover-Enhanced/blob/main/Disparate_Impact_Remover__Modified_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import logging
from itertools import product
from aif360.algorithms import Transformer


class DisparateImpactRemover(Transformer):
    """Rank-preserving feature repair toward a pooled, cross-group distribution.

    Rows are partitioned into groups by the values of one or more sensitive
    attributes (several attributes yield intersectional groups). For every
    non-protected feature column, each value in a group is replaced by a blend
    of its original value and the value found at the *same quantile* of the
    distribution pooled over all eligible groups:

        repaired = (1 - repair_level) * original + repair_level * pooled_value

    Because the mapping is monotone in the within-group rank, the ordering of
    rows inside a group is preserved. Groups with fewer than
    ``min_group_size`` rows are neither pooled nor repaired; their rows are
    returned unchanged.

    Ranks are computed with NumPy (double ``argsort``). The BlackBoxAuditing
    ``Repairer`` class is still exposed as ``self.Repairer`` when that package
    is importable, but it is not used by ``fit_transform``.
    """

    def __init__(self, repair_level=1.0, sensitive_attribute=None, min_group_size=20, verbose=True):
        """Configure the remover.

        Args:
            repair_level: Blend weight in [0.0, 1.0]; 0 keeps the original
                values, 1 fully replaces them with the pooled-quantile values.
            sensitive_attribute: Feature name (str), list of feature names, or
                None. With None, ``fit_transform`` falls back to the first
                entry of ``dataset.protected_attribute_names``.
            min_group_size: Groups with fewer rows than this are skipped.
            verbose: When true, configures root logging at INFO level and
                emits per-group log messages.

        Raises:
            ValueError: If ``repair_level`` is outside [0.0, 1.0].
            TypeError: If ``sensitive_attribute`` is not str, list, or None.
        """
        super().__init__()

        # Validate first so bad arguments fail before any other work is done.
        if not 0.0 <= repair_level <= 1.0:
            raise ValueError("'repair_level' must be between 0.0 and 1.0.")

        if sensitive_attribute is None:
            self.sensitive_attributes = []
        elif isinstance(sensitive_attribute, str):
            self.sensitive_attributes = [sensitive_attribute]
        elif isinstance(sensitive_attribute, list):
            # Copy so later mutation of the caller's list cannot affect us.
            self.sensitive_attributes = list(sensitive_attribute)
        else:
            raise TypeError("sensitive_attribute must be str, list, or None")

        # The Repairer is not needed by fit_transform; keep the attribute for
        # backward compatibility but do not make the package a hard requirement.
        try:
            from BlackBoxAuditing.repairers.GeneralRepairer import Repairer
        except ImportError:
            Repairer = None
        self.Repairer = Repairer

        self.repair_level = repair_level
        self.min_group_size = min_group_size
        self.verbose = verbose

        if self.verbose:
            logging.basicConfig(level=logging.INFO)

    @staticmethod
    def _pooled_rows_for_ranks(ranks, group_size, pooled_size):
        """Map within-group ranks to row positions in the pooled sorted array.

        A rank ``r`` in a group of ``group_size`` rows corresponds to quantile
        ``r / (group_size - 1)``; that quantile is converted to the nearest row
        of a sorted array with ``pooled_size`` rows. A single-row group is
        mapped to the middle of the pooled distribution.
        """
        if group_size > 1:
            quantiles = ranks / (group_size - 1)
        else:
            quantiles = np.full(ranks.shape, 0.5)
        return np.rint(quantiles * (pooled_size - 1)).astype(int)

    def fit_transform(self, dataset):
        """Return a copy of ``dataset`` with non-protected features repaired.

        Args:
            dataset: Object exposing ``features`` (2-D array),
                ``feature_names``, ``protected_attribute_names`` and
                ``copy()``.

        Returns:
            The object returned by ``dataset.copy()`` with its ``features``
            replaced by the repaired array. Protected columns and rows that
            belong to skipped groups are left unchanged.

        Raises:
            ValueError: If no sensitive attribute can be determined, if a
                sensitive attribute is not in ``dataset.feature_names``, or if
                no group reaches ``min_group_size``.
        """
        # Resolve into a local: the instance must stay reusable across datasets.
        sensitive = self.sensitive_attributes or list(dataset.protected_attribute_names[:1])
        if not sensitive:
            raise ValueError(
                "No sensitive attribute was given and the dataset declares no "
                "protected attributes."
            )

        missing = [attr for attr in sensitive if attr not in dataset.feature_names]
        if missing:
            raise ValueError(
                f"Sensitive attribute(s) not found in dataset.feature_names: {missing}"
            )

        protected_cols = [dataset.feature_names.index(attr) for attr in sensitive]
        protected_values = [np.unique(dataset.features[:, col]) for col in protected_cols]

        repaired = dataset.copy()
        features = repaired.features.copy()

        protected_set = set(protected_cols)
        # Explicit int dtype keeps fancy indexing valid even when the list is empty.
        feature_cols = np.array(
            [col for col in range(features.shape[1]) if col not in protected_set],
            dtype=int,
        )

        # Step 1: collect the rows of every sufficiently large group.
        pooled_values = []
        group_rows = {}
        for group_vals in product(*protected_values):
            mask = np.logical_and.reduce([
                features[:, col] == val for col, val in zip(protected_cols, group_vals)
            ])
            size = int(np.count_nonzero(mask))

            if size < self.min_group_size:
                if self.verbose:
                    logging.warning(f"Skipping group {group_vals} (size={size})")
                continue

            rows = np.where(mask)[0]
            pooled_values.append(features[rows][:, feature_cols])
            group_rows[group_vals] = rows

        if not pooled_values:
            raise ValueError("No eligible groups met the minimum size requirement.")

        # Column-wise sorted pooled values: the shared target distribution.
        global_sorted = np.sort(np.vstack(pooled_values), axis=0)
        pooled_size = global_sorted.shape[0]

        # Step 2: move each group toward the pooled distribution, quantile by quantile.
        for group_vals, rows in group_rows.items():
            group_data = features[rows][:, feature_cols]
            n = len(rows)

            # Double argsort yields each value's 0-based rank within its
            # column; the stable sort makes tie-breaking reproducible.
            order = np.argsort(group_data, axis=0, kind="stable")
            ranks = np.argsort(order, axis=0, kind="stable")

            # Ranks run over 0..n-1 but the pooled array has pooled_size rows,
            # so they must be rescaled before being used as row indices.
            target_rows = self._pooled_rows_for_ranks(ranks, n, pooled_size)
            aligned = np.take_along_axis(global_sorted, target_rows, axis=0)

            blended = (1 - self.repair_level) * group_data + self.repair_level * aligned

            # Only non-protected columns are written, so the protected
            # attribute columns keep their original values.
            features[rows[:, None], feature_cols] = blended

            if self.verbose:
                logging.info(f"Repaired group {group_vals} (size={n}) toward global target.")

        repaired.features = features
        return repaired


In [None]:
# Build the remover for the intersectional groups formed by race and sex.
# repair_level=0.4 blends 40% of the target values into the originals, and any
# group with fewer than 502 rows is skipped when fit_transform runs.
# NOTE: the variable name `dir` shadows Python's built-in dir() in this notebook.
dir = DisparateImpactRemover(
    sensitive_attribute=['race', 'sex'],
    repair_level=0.4,
    min_group_size=502,
    verbose=True,
)