# Folktables: Real‐World Bias Detection with MSD

## Configuration & Imports

In [1]:
import numpy as np
import pandas as pd

from folktables import ACSDataSource, ACSIncome
from humancompatible.detect import detect_bias, detect_bias_two_samples

#### Main parameters

In [2]:
# Two-letter postal abbreviations of the two states used throughout the
# notebook (both are requested together when the ACS data is loaded).
state1, state2 = "HI", "ME"

#### Extra parameters

In [3]:
# --- ACS data source settings (passed to folktables' ACSDataSource) ---
survey_year = "2018"
horizon = "1-Year"
data_root = "../data/folktables"

# --- Feature selection ---
# Columns kept from the survey data. A wider candidate list, for reference:
# ['AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX', 'RAC1P']
selected_columns = ["AGEP", "MAR", "POBP", "SEX", "RAC1P"]
# Every selected column is also treated as a protected attribute; this is a
# separate list object so the two can be edited independently.
protected_attrs = list(selected_columns)
continuous_feats = list()
feature_map = dict()

# --- Detection settings ---
seed = 42
n_samples = 1_000
method = "MSD"
method_kwargs = dict(time_limit=120)  # 2 min per solve

## Load & Prepare Data via Folktables

In [4]:
def load_state_manual():
    """
    Load the ACS person survey for ``state1`` and ``state2``.

    First asks folktables to download the data automatically. If that raises
    any exception, prints instructions for fetching the two state archives by
    hand and then retries with ``download=False`` so folktables reads from
    ``data_root`` on disk instead.

    Reads the module-level settings ``survey_year``, ``horizon``,
    ``data_root``, ``state1`` and ``state2``.

    Returns:
        Whatever ``ACSDataSource.get_data`` returns for the two states.

    Raises:
        Any exception raised by the second (offline) ``get_data`` call, e.g.
        when the files have not been placed on disk yet.
    """
    states = (state1, state2)
    source = ACSDataSource(
        survey_year=survey_year,
        horizon=horizon,
        survey="person",
        root_dir=data_root,
    )

    try:
        # First attempt: let folktables fetch the files itself.
        return source.get_data(states=list(states), download=True)
    except Exception as err:
        local_dir = f"{data_root}/{survey_year}/{horizon}"
        print("\n⚠️  Automatic download failed:")
        print(f"    {err!r}\n")
        print("→ Please manually download these two files and unzip them under:")
        for abbrev in states:
            print(f"    {local_dir}/csv_p{abbrev.lower()}.zip")
        print("\nYou can get them from:")
        print(f"https://www2.census.gov/programs-surveys/acs/data/pums/{survey_year}/{horizon}/\n")
        # Second attempt stays inside the handler so that a failure here is
        # reported together with the original download error.
        return source.get_data(states=list(states), download=False)

In [5]:
# Postal abbreviation -> ACS PUMS place-of-birth (POBP) recode.
# For U.S. states, DC and Puerto Rico the POBP value equals the FIPS state
# code, which is why some numbers (3, 7, 14, 43, 52) are skipped.
_POBP_STATE_CODE = {
    "AL":  1,  # Alabama
    "AK":  2,  # Alaska
    "AZ":  4,  # Arizona
    "AR":  5,  # Arkansas
    "CA":  6,  # California
    "CO":  8,  # Colorado
    "CT":  9,  # Connecticut
    "DE": 10,  # Delaware
    "DC": 11,  # District of Columbia
    "FL": 12,  # Florida
    "GA": 13,  # Georgia
    "HI": 15,  # Hawaii
    "ID": 16,  # Idaho
    "IL": 17,  # Illinois
    "IN": 18,  # Indiana
    "IA": 19,  # Iowa
    "KS": 20,  # Kansas
    "KY": 21,  # Kentucky
    "LA": 22,  # Louisiana
    "ME": 23,  # Maine
    "MD": 24,  # Maryland
    "MA": 25,  # Massachusetts
    "MI": 26,  # Michigan
    "MN": 27,  # Minnesota
    "MS": 28,  # Mississippi
    "MO": 29,  # Missouri
    "MT": 30,  # Montana
    "NE": 31,  # Nebraska
    "NV": 32,  # Nevada
    "NH": 33,  # New Hampshire
    "NJ": 34,  # New Jersey
    "NM": 35,  # New Mexico
    "NY": 36,  # New York
    "NC": 37,  # North Carolina
    "ND": 38,  # North Dakota
    "OH": 39,  # Ohio
    "OK": 40,  # Oklahoma
    "OR": 41,  # Oregon
    "PA": 42,  # Pennsylvania
    "RI": 44,  # Rhode Island
    "SC": 45,  # South Carolina
    "SD": 46,  # South Dakota
    "TN": 47,  # Tennessee
    "TX": 48,  # Texas
    "UT": 49,  # Utah
    "VT": 50,  # Vermont
    "VA": 51,  # Virginia
    "WA": 53,  # Washington
    "WV": 54,  # West Virginia
    "WI": 55,  # Wisconsin
    "WY": 56,  # Wyoming
    "PR": 72,  # Puerto Rico
}


def state_to_pobp_code(abbrev: str) -> int:
    """
    Turn a two-letter state code (e.g. 'CA') into the ACS POBP recode.

    The lookup is case-insensitive and ignores surrounding whitespace, so
    ' ca ' and 'CA' both resolve to 6.

    Args:
        abbrev: Postal abbreviation of a U.S. state, DC or Puerto Rico.

    Returns:
        The integer POBP code for that state.

    Raises:
        KeyError: If the abbreviation isn't in the map; the message lists
            every valid abbreviation.
    """
    st = abbrev.strip().upper()
    try:
        return _POBP_STATE_CODE[st]
    except KeyError:
        # `from None` hides the internal dict-lookup KeyError so the caller
        # only sees the descriptive error below.
        raise KeyError(
            f"Unknown state abbreviation '{abbrev}'. Valid codes are: "
            + ", ".join(sorted(_POBP_STATE_CODE))
        ) from None