<a href="https://colab.research.google.com/github/mmfara/Disparate-Impact-Remover-Enhanced/blob/main/DIR%2B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import logging
from typing import List, Union, Set
from aif360.algorithms import Transformer


class DisparateImpactRemover(Transformer):
    """
    Enhanced Disparate Impact Remover with:
    - Intersectionality support
    - min_group_size filtering
    - groups_to_repair selective control
    - Global repair (single Repairer call)
    - Robust verbose logging

    Groups are identified by labels of the form ``"attr=value"``. For several
    attributes the parts are joined with ``"|"`` in the order given by
    ``sensitive_attribute`` (e.g. ``"race=1|gender=0"``). Integral values are
    written without a trailing ``.0``. Entries of ``groups_to_repair`` are
    normalised the same way, so ``"race=1.0"`` and ``"race=1"`` name the same
    group.

    Args:
        repair_level: Strength of the repair in [0.0, 1.0]. It is passed
            straight to BlackBoxAuditing's ``Repairer``.
        sensitive_attribute: Feature name (or list/tuple of names) that defines
            the groups. ``None`` means the dataset's first protected attribute
            is used at fit time.
        min_group_size: Groups with fewer rows are left unrepaired. Negative
            values are treated as 0.
        groups_to_repair: If given and non-empty, only groups with these labels
            are repaired.
        verbose: If True, group statistics and per-group decisions are logged
            at INFO level.

    Raises:
        ValueError: If ``repair_level`` is outside [0.0, 1.0].
        TypeError: If ``sensitive_attribute`` has an unsupported type.
    """

    def __init__(self,
                 repair_level: float = 1.0,
                 sensitive_attribute: Union[str, List[str], None] = None,
                 min_group_size: int = 0,
                 groups_to_repair: Union[List[str], Set[str], None] = None,
                 verbose: bool = True):
        super().__init__()
        # Imported here so BlackBoxAuditing is only needed when the transformer
        # is actually instantiated.
        from BlackBoxAuditing.repairers.GeneralRepairer import Repairer
        self.Repairer = Repairer

        if not 0.0 <= repair_level <= 1.0:
            raise ValueError("'repair_level' must be between 0.0 and 1.0.")
        self.repair_level = repair_level
        self.min_group_size = max(0, min_group_size)
        # Normalise user-supplied labels to the canonical form that
        # _get_group_labels produces ("race=1", not "race=1.0").
        self.groups_to_repair = (
            {self._normalize_group_label(g) for g in groups_to_repair}
            if groups_to_repair else None
        )
        self.verbose = verbose

        # Handle sensitive attributes (copy the caller's sequence to avoid aliasing)
        if sensitive_attribute is None:
            self.sensitive_attributes = []
        elif isinstance(sensitive_attribute, str):
            self.sensitive_attributes = [sensitive_attribute]
        elif isinstance(sensitive_attribute, (list, tuple)):
            self.sensitive_attributes = list(sensitive_attribute)
        else:
            raise TypeError("sensitive_attribute must be str, list, or None")

        # The named logger is shared process-wide. Its level is set here, but
        # whether this instance logs INFO is decided by _log_info (self.verbose),
        # so a verbose instance cannot make a non-verbose one chatty.
        self.logger = logging.getLogger("DisparateImpactRemover")
        if self.verbose:
            self.logger.setLevel(logging.INFO)
            if not self.logger.hasHandlers():
                handler = logging.StreamHandler()
                formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
                handler.setFormatter(formatter)
                self.logger.addHandler(handler)

    def fit_transform(self, dataset):
        """Repair the non-sensitive features of the selected groups.

        Args:
            dataset: AIF360-style dataset exposing ``features``,
                ``feature_names``, ``protected_attribute_names`` and ``copy()``.

        Returns:
            A copy of ``dataset``. Rows belonging to the selected groups get
            repaired features. Rows of other groups and all sensitive-attribute
            columns keep their original values.

        Raises:
            ValueError: If no sensitive attribute can be determined, if a
                sensitive attribute is not one of ``dataset.feature_names``, or
                if no rows remain after filtering.
        """
        if not self.sensitive_attributes:
            # Fall back to the dataset's first protected attribute. The choice
            # is stored on the instance, which later calls then reuse.
            self.sensitive_attributes = list(dataset.protected_attribute_names[:1])
        if not self.sensitive_attributes:
            raise ValueError(
                "No sensitive attribute given and the dataset declares no "
                "protected attributes.")

        features = dataset.features  # only read; repaired values go into a copy
        repaired = dataset.copy()
        protected_indices = self._resolve_indices(dataset)

        # Readable group labels such as "race=1" or "race=1|gender=0"
        combined_groups = self._get_group_labels(features, dataset)
        unique, counts = np.unique(combined_groups, return_counts=True)

        if self.verbose:
            self._log_group_stats(unique, counts)

        # Decide which groups take part in the repair
        valid_groups = []
        for grp, cnt in zip(unique, counts):
            reasons = []
            if cnt < self.min_group_size:
                reasons.append(f"size {cnt} < {self.min_group_size}")
            if self.groups_to_repair is not None and grp not in self.groups_to_repair:
                reasons.append("not in groups_to_repair")

            if reasons:
                self._log_info(f"Skipped group: {grp} — {'; '.join(reasons)}")
            else:
                valid_groups.append(grp)
                self._log_info(f"Included group: {grp} (size={cnt})")

        group_mask = np.isin(combined_groups, valid_groups)
        if not np.any(group_mask):
            raise ValueError("No valid rows to repair after filtering.")
        if len(valid_groups) == 1:
            # The repair aligns groups with each other; a single group has
            # nothing to be aligned with. This is most likely a misconfiguration.
            self.logger.warning(
                "Only one group selected for repair (%s); there are no other "
                "groups to align its feature distributions with.", valid_groups[0])

        filtered_features = features[group_mask]
        filtered_labels = combined_groups[group_mask]

        # Encode the selected group labels as numeric codes and append them as
        # the column the Repairer treats as the grouping attribute.
        label_to_code = {label: i for i, label in enumerate(sorted(set(filtered_labels)))}
        group_codes = np.array([label_to_code[label] for label in filtered_labels])
        repair_input = np.hstack([filtered_features, group_codes.reshape(-1, 1)])
        repair_index = repair_input.shape[1] - 1

        # Separate list copies are passed in case Repairer mutates its inputs.
        repairer = self.Repairer(repair_input.tolist(), repair_index, self.repair_level, False)
        repaired_filtered = np.array(repairer.repair(repair_input.tolist()), dtype=np.float64)

        # Write repaired rows back, dropping the appended group-code column
        repaired_features = features.copy()
        repaired_features[group_mask] = repaired_filtered[:, :-1]

        # Sensitive attributes themselves must never be altered
        for idx in protected_indices:
            repaired_features[:, idx] = features[:, idx]

        repaired.features = repaired_features
        return repaired

    def _log_info(self, message: str) -> None:
        """Log ``message`` at INFO level, but only if this instance is verbose."""
        if self.verbose:
            self.logger.info(message)

    def _resolve_indices(self, dataset) -> List[int]:
        """Return the column indices of the sensitive attributes in ``dataset``.

        Raises:
            ValueError: If any sensitive attribute is not a feature name.
        """
        names = list(dataset.feature_names)
        missing = [attr for attr in self.sensitive_attributes if attr not in names]
        if missing:
            raise ValueError(
                f"Sensitive attribute(s) {missing} not found in dataset.feature_names.")
        return [names.index(attr) for attr in self.sensitive_attributes]

    @staticmethod
    def _format_value(value) -> str:
        """Render one attribute value for a group label (1.0 -> '1', 0.5 -> '0.5')."""
        try:
            as_float = float(value)
        except (TypeError, ValueError):
            return str(value)
        if as_float.is_integer():
            return str(int(as_float))
        return str(as_float)

    @classmethod
    def _normalize_group_label(cls, label) -> str:
        """Canonicalise a user label such as 'race=1.0|gender=0' to 'race=1|gender=0'."""
        parts = []
        for part in str(label).split('|'):
            name, sep, value = part.partition('=')
            parts.append(f"{name}={cls._format_value(value)}" if sep else part)
        return '|'.join(parts)

    def _get_group_labels(self, features: np.ndarray, dataset) -> np.ndarray:
        """Create human-readable group labels from protected attributes.

        Each row gets ``"name=value"`` parts joined by ``"|"``, in the order of
        ``self.sensitive_attributes``. Integral values are written without ``.0``.
        """
        indices = self._resolve_indices(dataset)
        names = [dataset.feature_names[idx] for idx in indices]
        return np.array([
            '|'.join(f"{name}={self._format_value(row[idx])}"
                     for name, idx in zip(names, indices))
            for row in features
        ])

    def _log_group_stats(self, unique, counts):
        """Log a summary table of all groups: status, label and size.

        The status is ``PROCESS`` for groups that will be repaired. Otherwise it
        is ``TOO SMALL`` and/or ``NOT SELECTED``. Rows are sorted by descending
        size.
        """
        self._log_info("\n=== Group Analysis ===")
        self._log_info(f"Protected attributes: {self.sensitive_attributes}")
        self._log_info(f"Minimum group size: {self.min_group_size}")
        if self.groups_to_repair:
            self._log_info(f"Specific groups to repair: {self.groups_to_repair}")
        self._log_info(f"\n{'Status'.ljust(12)} {'Group'.ljust(40)} Size")
        for grp, cnt in sorted(zip(unique, counts), key=lambda x: -x[1]):
            status = []
            if cnt < self.min_group_size:
                status.append("TOO SMALL")
            if self.groups_to_repair and grp not in self.groups_to_repair:
                status.append("NOT SELECTED")
            status_str = "|".join(status) if status else "PROCESS"
            self._log_info(f"{status_str.ljust(12)} {grp.ljust(40)} {cnt}")


In [None]:
# DisparateImpactRemover is the class defined in the previous cell. If you move
# it into your own module, import it from there instead, e.g.
#     from my_fairness_lib import DisparateImpactRemover

# Instantiate the enhanced DIR
dir_plus = DisparateImpactRemover(
    repair_level=1.0,                        # strength of repair
    sensitive_attribute=["race", "gender"],  # intersectional attributes
    min_group_size=30,                       # optional: skip small groups
    # optional: selective repair. Labels follow the sensitive_attribute order.
    groups_to_repair={"race=1|gender=0", "race=1|gender=1"},
)

# Apply it to a BinaryLabelDataset.
# NOTE: `dataset_orig_train` must already exist (e.g. a dataset prepared in an
# earlier cell), and "race" and "gender" must be among its feature_names.
repaired_dataset = dir_plus.fit_transform(dataset_orig_train)
