<a href="https://colab.research.google.com/github/mmfara/Adversarial-Debiasing-Enhanced/blob/main/PMF_Imputer_Function.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd

def fill_na_groupwise_PMF(df, columns, groupby_cols, missing_values=['?', np.nan],
                          random_state=None):
    """
    Fill missing values in one or more columns using group-wise PMF-based sampling.

    For every group defined by ``groupby_cols``, the empirical probability mass
    function (relative frequency of each observed value) of a column is computed,
    and every missing entry of that column inside the group is replaced by a value
    drawn at random from that distribution.

    Parameters:
        df (pd.DataFrame): The input dataframe. It is never modified.
        columns (str or list): Column(s) to impute.
        groupby_cols (str or list): Column(s) to group by (protected attributes).
        missing_values (list): Values to consider as missing (e.g., ['?', np.nan]).
            The list is only read, never mutated.
        random_state (None, int, np.random.Generator or np.random.RandomState):
            Source of randomness. ``None`` (default) draws from NumPy's global
            random state, so ``np.random.seed(...)`` keeps working as before.
            An int seeds a fresh ``np.random.default_rng``; an existing
            ``Generator`` or ``RandomState`` is used as given.

    Returns:
        pd.DataFrame: A copy of the dataframe with imputed values.

    Notes:
        * All ``missing_values`` markers in ``columns`` are converted to NaN,
          including in rows that end up not being imputed.
        * A group with no observed value in a column has no distribution to
          sample from; its entries stay NaN.
        * Rows whose group key is NaN are excluded by pandas' ``groupby``
          (default ``dropna=True``) and therefore stay NaN as well.
    """
    df = df.copy()

    if isinstance(columns, str):
        columns = [columns]
    if isinstance(groupby_cols, str):
        groupby_cols = [groupby_cols]

    # Pick the random source. The module-level `np.random` is kept for the
    # default case so results seeded through `np.random.seed` do not change.
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.default_rng(random_state)

    # Work on a positional index so label-based assignment below is unambiguous
    # even when the caller's index has duplicate labels; restored before return.
    original_index = df.index
    df.index = pd.RangeIndex(len(df))

    for col in columns:
        df[col] = df[col].replace(missing_values, np.nan)

        # Iterating the grouped column yields each group's slice directly,
        # instead of rebuilding a full-length boolean mask for every group.
        for _, group in df.groupby(groupby_cols)[col]:
            is_missing = group.isna().to_numpy()
            n_missing = int(is_missing.sum())
            if n_missing == 0:
                continue

            pmf = group.dropna().value_counts(normalize=True)
            if pmf.empty:
                # Nothing observed in this group: no distribution to sample from.
                continue

            sampled = rng.choice(pmf.index.to_numpy(), size=n_missing, p=pmf.to_numpy())
            df.loc[group.index[is_missing], col] = sampled

    df.index = original_index
    return df


#### 1. For a single protected attribute (e.g., race):

In [None]:
# Impute 'occupation' by sampling from the values observed within each 'race' group.
df_filled = fill_na_groupwise_PMF(
    df,
    columns=['occupation'],
    groupby_cols='race',
)

#### 2. For intersectional protected attributes (e.g., race + sex):

In [None]:
# Impute both columns, sampling within each (race, sex) combination.
df_filled = fill_na_groupwise_PMF(
    df,
    columns=['occupation', 'workclass'],
    groupby_cols=['race', 'sex'],
)

#### 3. With custom missing values:

In [None]:
# Treat '?', 'Unknown' and NaN as missing when imputing 'native.country' per 'race' group.
df_filled = fill_na_groupwise_PMF(
    df,
    columns='native.country',
    groupby_cols='race',
    missing_values=['?', 'Unknown', np.nan],
)