<a href="https://colab.research.google.com/github/mmfara/Adversarial-Debiasing-Enhanced/blob/main/PMF_Imputer_Function.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 🔍 What is PMF (Probability Mass Function)?

**PMF** stands for **Probability Mass Function**. It is a statistical function that gives the **probability of each possible value** of a **discrete** random variable, such as a **categorical** column.

### 📦 In the context of data imputation:

Here the PMF is estimated from the **relative frequencies** of the non-missing values in a column, and we **sample from it to fill in missing values** at random.
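In symbols (the notation is introduced here for illustration), the empirical PMF gives each value $v$ its share of the non-missing rows. The group-wise version used by the function below restricts the counts to a single group $g$:

$$
\hat{p}(v) = \frac{n_v}{\sum_{u} n_u}, \qquad \hat{p}_g(v) = \frac{n_{g,v}}{\sum_{u} n_{g,u}}
$$

Here $n_v$ is the number of non-missing rows with value $v$, and $n_{g,v}$ is the number of such rows within group $g$.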

---

### 🧠 Example

Let's say we have a column `"occupation"` whose non-missing values have the following counts:

| Occupation       | Count |
|------------------|-------|
| Tech-support     | 30    |
| Sales            | 50    |
| Exec-managerial  | 20    |

We compute the PMF as:

| Occupation       | Probability |
|------------------|-------------|
| Tech-support     | 30 / 100 = 0.30 |
| Sales            | 50 / 100 = 0.50 |
| Exec-managerial  | 20 / 100 = 0.20 |
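A minimal pandas sketch of the same computation, assuming a hypothetical `occupation` Series built to match the counts above, plus a few missing entries to show that they are excluded:

```python
import pandas as pd

# Hypothetical "occupation" column matching the counts in the table,
# plus 5 missing entries that should not affect the PMF.
occupation = pd.Series(
    ["Tech-support"] * 30 + ["Sales"] * 50 + ["Exec-managerial"] * 20 + [None] * 5
)

# value_counts() drops missing values by default (dropna=True);
# normalize=True divides each count by the number of non-missing values.
pmf = occupation.value_counts(normalize=True)
print(pmf)
```

Because `value_counts()` ignores missing values by default, the denominator is the 100 observed rows, not all 105.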

So for each missing value, we draw a replacement at random using these probabilities:
- 30% chance → `Tech-support`
- 50% chance → `Sales`
- 20% chance → `Exec-managerial`
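The sampling step can be sketched with NumPy's `Generator.choice`, which takes the per-category probabilities through its `p` argument. The seed and the number of missing cells below are arbitrary choices for illustration:

```python
import numpy as np

# PMF from the table above.
categories = np.array(["Tech-support", "Sales", "Exec-managerial"])
probs = np.array([0.30, 0.50, 0.20])

rng = np.random.default_rng(seed=0)  # seeded so the draws are reproducible

# Draw one replacement for each of 5 (hypothetical) missing cells.
filled = rng.choice(categories, size=5, p=probs)
print(filled)

# Over many draws, the empirical frequencies approach the PMF.
many = rng.choice(categories, size=100_000, p=probs)
values, counts = np.unique(many, return_counts=True)
freq = {str(v): c / counts.sum() for v, c in zip(values, counts)}
print(freq)
```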

---

### ✅ Why Use PMF for Imputation?

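- 🔁 **Preserves the distribution** of the observed values: in expectation, imputed values follow the same category proportions as the real ones.
- 🧑🏽‍🤝‍🧑🏿 **Group-aware**: When the PMF is computed separately within each group defined by protected attributes (like `race`, `gender`), each group's missing values are filled from that **group-specific** distribution instead of one pooled distribution.
- 📊 **Better than `mode` or `mean`**: mode imputation over-represents the most frequent category, and mean imputation does not apply to categorical data at all. This matters especially when fairness is a concern.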
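To make the distribution point concrete, here is a small sketch on hypothetical data (the same 30/50/20 counts plus 50 missing cells) comparing mode imputation with PMF sampling:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)  # seeded for reproducibility

# Hypothetical column: 100 observed values (30/50/20) plus 50 missing cells.
observed = ["Tech-support"] * 30 + ["Sales"] * 50 + ["Exec-managerial"] * 20
s = pd.Series(observed + [np.nan] * 50, dtype=object)

# Mode imputation: every missing cell becomes the most frequent category.
mode_filled = s.fillna(s.mode()[0])

# PMF imputation: each missing cell is drawn from the observed distribution.
pmf = s.value_counts(normalize=True)
na = s.isna()
pmf_filled = s.copy()
pmf_filled.loc[na] = rng.choice(pmf.index.to_numpy(), size=int(na.sum()), p=pmf.to_numpy())

print(mode_filled.value_counts(normalize=True).round(3))
print(pmf_filled.value_counts(normalize=True).round(3))
```

Mode imputation pushes `Sales` from 50% of the observed values to about 67% of the filled column (100 of 150 rows), while PMF sampling keeps each share close to its observed value, up to sampling noise.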

---

### 💡 Bonus: PMF vs Other Imputation

| Method              | Handles Categorical | Group-Aware | Keeps Distribution |
|---------------------|---------------------|-------------|--------------------|
| Mode                | ✅                  | ❌          | ❌                 |
| Mean (numeric only) | ❌                  | ❌          | ❌                 |
| PMF Sampling        | ✅                  | ✅          | ✅                 |
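The "Group-Aware" column can be illustrated with hypothetical toy data in which two groups have opposite occupation mixes. A single pooled PMF blurs them together, while a per-group PMF (computed here with `groupby(...).value_counts(normalize=True)`) keeps them apart:

```python
import pandas as pd

# Hypothetical toy data: two groups with opposite occupation distributions.
df_toy = pd.DataFrame({
    "group": ["A"] * 10 + ["B"] * 10,
    "occupation": ["Sales"] * 8 + ["Tech-support"] * 2
                + ["Sales"] * 2 + ["Tech-support"] * 8,
})

# Pooled PMF: one distribution for everyone (Sales 0.5, Tech-support 0.5).
pooled = df_toy["occupation"].value_counts(normalize=True)

# Group-wise PMF: one distribution per group
# (A: Sales 0.8, Tech-support 0.2; B: Sales 0.2, Tech-support 0.8).
by_group = df_toy.groupby("group")["occupation"].value_counts(normalize=True)

print(pooled)
print(by_group)
```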

---

Use PMF-based sampling when your goal is **statistically sound** and **fair** imputation of missing categorical data!


In [None]:
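import numpy as np
import pandas as pd

def fill_na_groupwise_PMF(df, columns, groupby_cols, missing_values=None, random_state=None):
    """
    Fill missing values in one or more columns using group-wise PMF-based sampling.

    Parameters:
        df (pd.DataFrame): The input dataframe.
        columns (str or list): Column(s) to impute.
        groupby_cols (str or list): Column(s) to group by (protected attributes).
        missing_values (list, optional): Values to treat as missing.
            Defaults to ['?', np.nan].
        random_state (int or np.random.Generator, optional): Seed or generator
            for reproducible sampling. If None, results differ between runs.

    Returns:
        pd.DataFrame: A copy of the dataframe with imputed values.
    """
    # Use None as the default to avoid a shared mutable default argument.
    if missing_values is None:
        missing_values = ['?', np.nan]
    rng = np.random.default_rng(random_state)
    df = df.copy()

    if isinstance(columns, str):
        columns = [columns]
    if isinstance(groupby_cols, str):
        groupby_cols = [groupby_cols]

    for col in columns:
        # Map every missing-value marker to NaN so that isna() detects it.
        df[col] = df[col].replace(missing_values, np.nan)
        col_pos = df.columns.get_loc(col)

        # .indices maps each group key to the positional row indices of that group.
        # Rows whose protected attribute is NaN are dropped by groupby (dropna=True)
        # and are therefore left unimputed.
        for positions in df.groupby(groupby_cols).indices.values():
            values = df[col].iloc[positions]
            is_na = values.isna().to_numpy()
            if not is_na.any():
                continue  # nothing to fill in this group

            # Empirical PMF of the observed (non-missing) values within this group.
            pmf = values[~is_na].value_counts(normalize=True)
            if pmf.empty:
                continue  # no observed values to sample from; leave NaNs as-is

            # Draw one replacement per missing cell according to the group's PMF.
            sampled = rng.choice(pmf.index.to_numpy(), size=int(is_na.sum()), p=pmf.to_numpy())
            df.iloc[positions[is_na], col_pos] = sampled

    return df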


#### 1. For a single protected attribute (e.g., race):

In [None]:
df_filled = fill_na_groupwise_PMF(df, columns=['occupation'], groupby_cols='race')

#### 2. For intersectional protected attributes (e.g., race + sex):

In [None]:
df_filled = fill_na_groupwise_PMF(df, columns=['occupation', 'workclass'], groupby_cols=['race', 'sex'])

#### 3. With custom missing values:

In [None]:
df_filled = fill_na_groupwise_PMF(df, columns='native.country', groupby_cols='race', missing_values=['?', 'Unknown', np.nan])
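
For a quick, self-contained check of the group-wise behavior without the Adult data, the same idea can be written more compactly with `groupby(...).transform`. This is an illustrative variant, not the notebook's function; the helper `_pmf_fill` and the toy frame are hypothetical:

```python
import numpy as np
import pandas as pd

def _pmf_fill(series, rng):
    """Hypothetical helper: fill NaNs in one group's Series by sampling that group's PMF."""
    out = series.copy()
    na = out.isna()
    pmf = out[~na].value_counts(normalize=True)
    if na.any() and not pmf.empty:
        out.loc[na] = rng.choice(pmf.index.to_numpy(), size=int(na.sum()), p=pmf.to_numpy())
    return out

rng = np.random.default_rng(0)

# Toy frame: group A only ever reports "Sales", group B only "Tech-support".
toy = pd.DataFrame({
    "race": ["A", "A", "A", "B", "B", "B"],
    "occupation": ["Sales", "?", "Sales", "Tech-support", "?", "Tech-support"],
})
toy["occupation"] = toy["occupation"].replace("?", np.nan)

# Each group's missing cells are filled only from that group's own PMF.
toy["occupation"] = toy.groupby("race")["occupation"].transform(lambda s: _pmf_fill(s, rng))
print(toy)
```

Because each group here has a single observed category, the result is deterministic: the missing cell in group A becomes `Sales` and the one in group B becomes `Tech-support`.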