# In [None]:
"""
Analysis using cached slope dicts + NANDA for weights/level vars.
- Loads nandaslopes_past.pkl.gz / nandaslopes_future.pkl.gz
- Reads NANDA TSVs for TOTPOP10 + level vars
- Population-weighted aggregation (numeric-safe) incl. VAR2
- OLS/Logit with formulas; saves term-level details + counts as CSV
- IAT controls:
    * IAT_FILTER: restrict to one IAT or None for both
    * ADJUST_FOR_IAT: include C(iat) in models
    * IAT_INTERACTIONS: interact RHS with C(iat)
"""

from pathlib import Path
from collections import Counter, defaultdict
import gzip, pickle
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
from pathlib import Path
import re

# ========= CONFIG =========
# All input locations are resolved relative to DATA_DIR.
DATA_DIR = Path(".")
# Base directory for outputs; plots are written under RESULTS_DIR/plots/<dv>.
RESULTS_DIR = ''

# slope dicts (your cached artifacts)
PAST_PKL   = DATA_DIR / "data/nandaslopes_past.pkl.gz"
FUTURE_PKL = DATA_DIR / "data/nandaslopes_future.pkl.gz"

# NANDA TSVs (used for TOTPOP10 + level vars only)
NANDA_1990_2010 = DATA_DIR / "data/38528-0001-Data.tsv"  # TOTPOP10, PPOV10, PUNEMP10, PPUBAS10
NANDA_2010_2015 = DATA_DIR / "data/38528-0002-Data.tsv"  # PPOV13_17, PUNEMP13_17, PPUBAS13_17
NANDA_2020      = DATA_DIR / "data/38528-0006-Data.tsv"  # PPOV, PUNEMP, PPUBAS (2020)

# Optional redlining file (contains 'HRI2010')
HRI_XLSX        = DATA_DIR / "data/Historic Redlining Indicator 2010.xlsx"  # must have 'GEOID10' or 'CensusTract' and 'HRI2010'

# Personomics / DDM
IAT1 = "changepreserve"
IAT2 = "futurepresent"
PERSONOMICS_1 = DATA_DIR / f"data/{IAT1}_pnomics_exp_new.csv"
PERSONOMICS_2 = DATA_DIR / f"data/{IAT2}_pnomics_exp.csv"
# Both DDM paths must be f-strings so the IAT name is substituted into the file name.
DDM_1 = DATA_DIR / f"ddm_output/DMCfs_sesiat_output_{IAT1}.csv"
DDM_2 = DATA_DIR / f"ddm_output/DMCfs_sesiat_output_{IAT2}.csv"
# Significance threshold (default for collect_effects, CI level for bar plots).
ALPHA = 0.05

# ========= Utils =========
def zscore(s: pd.Series):
    """Standardize ``s`` to mean 0 and population SD 1.

    Values are coerced to numeric first (unparseable entries become NaN).
    When the population standard deviation is zero or undefined, an all-NaN
    Series on the same index is returned instead of dividing by it.
    """
    values = pd.to_numeric(s, errors='coerce')
    spread = values.std(ddof=0)
    if spread is None or not (spread > 0):
        # Constant or all-missing input: no meaningful z-score exists.
        return pd.Series(np.nan, index=values.index)
    return (values - values.mean()) / spread

def mode_scalar(s: pd.Series):
    """Return a single modal value of ``s`` (NaNs ignored), or NaN if none exists.

    When several values tie, the first entry of ``Series.mode`` is used.
    """
    modes = s.mode(dropna=True)
    if modes.empty:
        return np.nan
    return modes.iloc[0]

def as_tract_str(s: pd.Series):
    """Normalize tract identifiers to zero-padded 11-character strings.

    Each value is stringified, anything after the first '.' is dropped
    (so a float like 1001020100.0 loses its '.0'), and the result is
    left-padded with zeros to width 11.
    """
    as_text = s.astype(str)
    before_dot = as_text.str.split('.').str[0]
    return before_dot.str.zfill(11)

def is_binary_series(s: pd.Series):
    """Return True when every non-missing numeric value of ``s`` is 0 or 1.

    Non-numeric entries are coerced to NaN and ignored, so a Series with no
    usable values also yields True (the empty set is a subset of {0, 1}).
    """
    numeric = pd.to_numeric(s, errors='coerce').dropna()
    observed = set(pd.unique(numeric))
    return observed <= {0, 1}

def load_pickle_dict(path: Path) -> dict:
    """Read a gzip-compressed pickle from ``path`` and return the unpickled object.

    NOTE(review): unpickling executes arbitrary code embedded in the file;
    only point this at cache files you produced yourself.
    """
    with gzip.open(path, mode="rb") as stream:
        payload = pickle.load(stream)
    return payload

def dict_to_wide(d: dict, suffix: str) -> pd.DataFrame:
    """Pivot a flat ``{"<tract>_<metric>": value}`` dict into one row per tract.

    Keys are split once on the first underscore: the left part becomes the
    normalized ``CensusTract`` id, the right part the metric name. Every
    metric column in the result gets ``suffix`` appended; ``CensusTract``
    is left as is.
    """
    long = pd.Series(d, name='value').rename_axis('key').reset_index()
    # "12345678901_PPOV_slope_1" -> ("12345678901", "PPOV_slope_1")
    key_parts = long['key'].astype(str).str.split('_', n=1, expand=True)
    long[['CensusTract', 'metric']] = key_parts
    long['CensusTract'] = as_tract_str(long['CensusTract'])

    table = long.pivot(index='CensusTract', columns='metric', values='value').reset_index()
    table.columns.name = None
    suffixed = {col: f'{col}{suffix}' for col in table.columns if col != 'CensusTract'}
    return table.rename(columns=suffixed)

def pop_weighted_mean_factory(weights_series: pd.Series):
    """Build an aggregator that computes a weighted mean using ``weights_series``.

    The returned function takes a Series of values, looks up the weight for
    each of its index labels, and returns the weighted mean over rows where
    the value is present and the weight is present and strictly positive.
    It returns NaN when no such row exists. Intended for ``groupby.agg``.
    """
    all_weights = pd.to_numeric(weights_series, errors='coerce')

    def _pwm(m: pd.Series):
        values = pd.to_numeric(m, errors='coerce')
        # Align weights to this group's rows by index label.
        weights = all_weights.loc[values.index]
        usable = values.notna() & weights.notna() & (weights > 0)
        if not usable.any():
            return np.nan
        vals = values[usable].astype('float64')
        wts = weights[usable].astype('float64')
        total = wts.sum()
        if total == 0:
            return np.nan
        return float((vals * wts).sum() / total)

    return _pwm

def collect_effects(fit, meta, effect_details, sig_counts, alpha=ALPHA):
    """Record per-term estimates from a fitted model and count significant terms.

    For every term except the intercept, appends a dict (``meta`` plus
    coef/se/p/CI bounds and one fit statistic) to ``effect_details[term]``,
    and increments ``sig_counts[term]`` when the term's p-value is <= ``alpha``.
    The fit statistic is ``r2`` if the result has ``rsquared``, else
    ``prsquared`` if available, else ``aic``. Both containers are mutated
    in place; nothing is returned.
    """
    coefs = fit.params
    p_values = fit.pvalues
    std_errs = fit.bse
    intervals = fit.conf_int()

    if hasattr(fit, "rsquared"):
        stat_name, stat_value = "r2", float(fit.rsquared)
    elif hasattr(fit, "prsquared"):
        stat_name, stat_value = "prsquared", float(fit.prsquared)
    else:
        stat_name, stat_value = "aic", float(fit.aic)

    for term in coefs.index:
        if term.lower() in ("const", "intercept"):
            continue
        record = dict(meta)
        record.update({
            "term": term,
            "coef": float(coefs[term]),
            "se": float(std_errs[term]),
            "p": float(p_values[term]),
            "ci_low": float(intervals.loc[term, 0]),
            "ci_high": float(intervals.loc[term, 1]),
        })
        record[stat_name] = stat_value
        effect_details[term].append(record)
        if p_values[term] <= alpha:
            sig_counts[term] += 1

def recode_minority_type(race, eth):
    """Collapse race/ethnicity into 'White' (non-Hispanic White) or 'BIPOC'.

    Any other combination maps to 'BIPOC'.
    NOTE(review): missing race or ethnicity also lands in 'BIPOC' — confirm
    that is intended rather than treating it as missing.
    """
    non_hispanic_white = (race == 'White') and (eth == 'Not Hispanic or Latino')
    return 'White' if non_hispanic_white else 'BIPOC'

# Columns listed for "any" aggregation (see bin_any) instead of a mean.
BINARY_ANY_COLUMNS = ["LILATracts_Vehicle"]
def bin_any(m: pd.Series):
    """Return 1.0 if any numeric value in ``m`` is >= 0.5, else 0.0.

    Non-numeric and missing entries are ignored; NaN is returned when
    nothing usable remains. For 0/1 flags this is a logical "any".
    """
    numeric = pd.to_numeric(m, errors='coerce')
    present = numeric[numeric.notna()]
    if len(present) == 0:
        return np.nan
    return 1.0 if bool((present >= 0.5).any()) else 0.0

# ======== Plotting Utils ========
# Lookup tables that turn internal column names into display text for plots.

# Exact-name overrides: pretty_label() checks this table first, after
# stripping any trailing "_z" and "_past"/"_future" suffix from the name.
PRETTY_LABELS = {
    # IAT / DDM DVs
    "d": "IAT D-score",
    "peak_amplitude": "Peak automatic bias",
    "characteristic_time": "Time to peak bias",
    "alpha_int": "Baseline response caution",
    "alpha_dif": "Δ Caution in Incompatible",
    "tau": "Non-decision time",
    "mu_c_int": "Controlled drift rate",

    # Predictors / moderators
    "minority_type": "Minority status",
    "minority_type_num": "Minority status",
    "LILATracts_Vehicle": "Food insecurity (LILA)",
    "HRI2010": "Historic redlining indicator (2010)",

    # Level vars (edit as you prefer)
    "PPOV10": "Neighborhood Poverty Rate (2010)",
    "PUNEMP10": "Neighborhood Unemployment Rate (2010)",
    "PPUBAS10": "% Households w/ Public Assisted Income (2010)",
}

# Display names for the <BASE> part of slope-term names such as
# "PPOV_linear_1"; used by pretty_label() when no exact override exists.
PRETTY_BASE_VARS = {
    "PPOV":   "Poverty",
    "PUNEMP": "Unemployment",
    "PPUBAS": "Public Assistance",
}

# Text for the trailing degree (1 or 2) of a slope-term name.
FIT_TEXT = {
    1: "linear fit",
    2: "quadratic fit",
}

# Text for the component part (constant/linear/quadratic) of a slope-term name.
COMPONENT_TEXT = {
    "constant":  "Intercept",
    "linear":    "Linear trend",
    "quadratic": "Quadratic trend",
}

# Per-variable labels for the 0/1 levels of binary variables; read by pretty_level().
BINARY_VALUE_LABELS = {
    "LILATracts_Vehicle": {0: "Food secure", 1: "Food insecure"},
    "minority_type": {0: "White", 1: "BIPOC"},
    "minority_type_num": {0: "White", 1: "BIPOC"},
}

# Year ranges shown for the "_past" / "_future" column suffixes.
EPOCH_LABELS = {"past": "1990–2010", "future": "2010–2015"}

def _strip_suffixes(name: str):
    """Split ``name`` into (root, epoch).

    One trailing "_z" is dropped first; then a trailing "_past" or "_future"
    is removed and reported as the epoch ("past"/"future"). The epoch is None
    when neither suffix is present.
    """
    stem = name[:-2] if name.endswith("_z") else name
    for tag in ("past", "future"):
        marker = f"_{tag}"
        if stem.endswith(marker):
            return stem[:-len(marker)], tag
    return stem, None

def pretty_label(name: str) -> str:
    """
    Turn an internal column name into display text.

    Understands names like:
      PPOV_constant_1_past, PPOV_linear_1_future, PPOV_quadratic_2_past
    → 'Poverty — Intercept (linear fit; 1990–2010)' etc.

    Resolution order, after stripping "_z" and the epoch suffix:
      1. exact match in PRETTY_LABELS;
      2. the <BASE>_<component>_<degree> slope-term pattern;
      3. PRETTY_BASE_VARS, else the name title-cased with underscores as spaces.
    """
    root, epoch = _strip_suffixes(name)

    def _with_epoch(text):
        # Append the epoch's year range in parentheses, when there is one.
        return f"{text} ({EPOCH_LABELS.get(epoch, epoch)})" if epoch else text

    # 1) Exact override (e.g., PPOV10, HRI2010)
    if root in PRETTY_LABELS:
        return _with_epoch(PRETTY_LABELS[root])

    # 2) Slope-term names: <BASE>_(constant|linear|quadratic)_(1|2)
    parsed = re.match(r"^(?P<base>[A-Za-z0-9]+)_(?P<comp>constant|linear|quadratic)_(?P<deg>[12])$",
                      root, flags=re.IGNORECASE)
    if parsed is not None:
        base = parsed.group("base")
        base_label = PRETTY_LABELS.get(base) or PRETTY_BASE_VARS.get(base) or base.replace("_", " ").title()
        comp_text = COMPONENT_TEXT[parsed.group("comp").lower()]
        fit_text = FIT_TEXT[int(parsed.group("deg"))]
        tail = f"{fit_text}; {EPOCH_LABELS.get(epoch, epoch)}" if epoch else fit_text
        return f"{base_label} — {comp_text} ({tail})"

    # 3) Fallback (root is known not to be in PRETTY_LABELS at this point)
    return _with_epoch(PRETTY_BASE_VARS.get(root) or root.replace("_", " ").title())

def pretty_level(varname: str, val) -> str:
    """Return display text for one level of ``varname``.

    Non-missing values are converted to int when possible, then looked up in
    BINARY_VALUE_LABELS[varname]; anything without a mapping is returned as
    ``str`` of the (possibly int-converted) value.
    """
    key = val
    try:
        if pd.notna(val):
            key = int(val)
    except Exception:
        # Not int-convertible (or not a scalar): keep the raw value.
        key = val
    labels = BINARY_VALUE_LABELS.get(varname)
    if not labels or key not in labels:
        return str(key)
    return labels[key]

def _ensure_plot_dir(dv):
    """Create (if needed) and return the plot directory for outcome ``dv``.

    The directory is ``RESULTS_DIR/plots/<dv>``. ``RESULTS_DIR`` may be a
    plain string (it is configured as ``''``) or a ``Path``; it is wrapped
    in ``Path`` here because ``str / str`` raises ``TypeError``.
    """
    outdir = Path(RESULTS_DIR) / "plots" / dv
    outdir.mkdir(parents=True, exist_ok=True)
    return outdir

def _var2_levels(g, var2_term):
    """Choose the moderator values at which panels are drawn.

    Returns ``(is_binary, levels)``:
      * ``(True, [0.0, 1.0])`` when all non-missing values are 0/1;
      * ``(False, [q25, q75])`` when the median coincides with a quartile;
      * ``(False, [q25, q50, q75])`` otherwise.
    """
    values = pd.to_numeric(g[var2_term], errors='coerce')
    distinct = set(pd.unique(values.dropna()))
    if distinct <= {0.0, 1.0}:
        return True, [0.0, 1.0]

    lower, median, upper = values.quantile([0.25, 0.50, 0.75])
    median_is_distinct = not (np.isclose(lower, median) or np.isclose(median, upper))
    if median_is_distinct:
        return False, [float(lower), float(median), float(upper)]
    # Median coincides with a quartile: fall back to a two-panel low/high split.
    return False, [float(lower), float(upper)]

def _panel_mask(g, var2_term, is_binary, level, levels):
    """Boolean row mask of ``g`` selecting the observations shown in one panel.

    Binary moderator: rows whose rounded value equals ``level``.
    Three levels (q25, q50, q75): ``<= q25``, strictly between q25 and q75,
    or ``>= q75`` depending on which level is requested.
    Two levels (low, high): ``<= low`` or ``>= high``.
    """
    values = pd.to_numeric(g[var2_term], errors='coerce')
    if is_binary:
        return values.round().eq(level)

    if len(levels) != 3:
        lo, hi = levels
        if np.isclose(level, lo):
            return values <= lo
        return values >= hi

    lo, mid, hi = levels
    if np.isclose(level, lo):
        return values <= lo
    if np.isclose(level, mid):
        # middle band is the open interval (q25, q75)
        return (values > lo) & (values < hi)
    return values >= hi

def _group_label_and_color(pred_name, val):
    """Return ``(legend label, colour)`` for group code ``val`` of ``pred_name``.

    For 'minority_type', 0 maps to 'White' and anything else to 'BIPOC'.
    For any other predictor the label is ``str(val)``; code 1 gets the
    BIPOC colour and every other code the White colour.
    """
    code = int(val)
    if pred_name == 'minority_type':
        return ('White', COLOR_WHITE) if code == 0 else ('BIPOC', COLOR_BIPOC)
    # generic binary fallback
    colour = COLOR_BIPOC if code == 1 else COLOR_WHITE
    return (str(val), colour)

def _prediction_grid(g, xvar, pred, var2_term, iat_in_model=False, covar_z_cols=None):
    """Build the data frame on which model predictions are evaluated for plots.

    The grid crosses 50 evenly spaced ``xvar`` values in [-2, 2], both groups
    of ``pred`` (0/1), and each moderator level from ``_var2_levels``. When
    ``iat_in_model`` is set and ``g`` has an 'iat' column, 'iat' is fixed at
    its mode. Columns in ``covar_z_cols`` not already in the grid are added
    as 0.0. Adds 'panel', 'group_label' and 'group_color' columns.

    Returns ``(grid, is_binary, levels)``.
    """
    x_vals = np.linspace(-2.0, 2.0, 50)  # standardized space
    is_bin, levels = _var2_levels(g, var2_term)

    # Hold iat at its most common value so lines are comparable across panels.
    iat_val = None
    if iat_in_model and 'iat' in g.columns:
        iat_modes = g['iat'].mode(dropna=True)
        if not iat_modes.empty:
            iat_val = iat_modes.iloc[0]
        else:
            non_missing = g['iat'].dropna()
            iat_val = non_missing.iloc[0] if non_missing.size else None
    use_iat = iat_in_model and iat_val is not None

    records = []
    for v2 in levels:
        for grp in [0, 1]:
            for x in x_vals:
                record = {xvar: float(x), pred: int(grp), var2_term: float(v2)}
                if use_iat:
                    record['iat'] = iat_val
                records.append(record)
    grid = pd.DataFrame(records)

    # Covariates are z-scored, so 0.0 is their mean.
    for cz in (covar_z_cols or []):
        if cz not in grid.columns:
            grid[cz] = 0.0

    var2_title = pretty_label(var2_term)
    if is_bin:
        def _panel_name(v):
            return f"{var2_title} = {pretty_level(var2_term, v)}"
    elif len(levels) == 3:
        q25, q50, _q75 = levels

        def _panel_name(v):
            if np.isclose(v, q25):
                band = "low_25p"
            elif np.isclose(v, q50):
                band = "median_50p"
            else:
                band = "high_75p"
            return f"{var2_title}: {band}"
    else:
        low, _high = levels

        def _panel_name(v):
            return f"{var2_title}: {'low_25p' if np.isclose(v, low) else 'high_75p'}"
    grid['panel'] = [_panel_name(v) for v in grid[var2_term]]

    # group labels/colors
    label_color_pairs = [_group_label_and_color(pred, grp) for grp in grid[pred].tolist()]
    grid['group_label'] = [pair[0] for pair in label_color_pairs]
    grid['group_color'] = [pair[1] for pair in label_color_pairs]
    return grid, is_bin, levels

def _scatter_points(ax, data, xvar, yvar, pred):
    """Scatter the observed (xvar, yvar) points on ``ax``, one series per group.

    Rows are split by the rounded value of ``pred`` (0 then 1); a group with
    no rows is skipped. Label and colour come from ``_group_label_and_color``.
    """
    group_codes = pd.to_numeric(data[pred], errors='coerce').round()
    for grp in (0, 1):
        members = group_codes.eq(grp)
        if not members.any():
            continue
        lab, col = _group_label_and_color(pred, grp)
        xs = pd.to_numeric(data.loc[members, xvar], errors='coerce')
        ys = pd.to_numeric(data.loc[members, yvar], errors='coerce')
        ax.scatter(xs, ys, s=14, alpha=0.6, edgecolor='none', label=lab, c=col)

def _line_preds(ax, grid_subset, xvar):
    """Draw one predicted line ('yhat' against ``xvar``) per 'group_label'.

    Each line uses the group's 'group_color' and is kept out of the legend,
    since the scatter series already provide the legend entries.
    """
    for _, chunk in grid_subset.groupby('group_label'):
        line_colour = chunk['group_color'].iloc[0]
        ordered = chunk.sort_values(xvar)
        ax.plot(ordered[xvar].values, ordered['yhat'].values,
                linewidth=2.0, c=line_colour, label='_nolegend_')

def plot_significant_ols_dv_on_comm(fit, g, dv, comm_z, pred, var2_term, iat_in_model=False, covar_z_cols=None):
    """Plot an OLS model of the z-scored DV on a z-scored community variable.

    One PNG is written per moderator level (see ``_var2_levels``): observed
    points coloured by ``pred`` group plus the model's predicted line for each
    group. When ``ALSO_PLOT_COLLAPSED_PRED`` is truthy, a second PNG per level
    shows all points in grey with a single line that averages the two group
    predictions, weighted by the share of ``pred == 1`` in that panel.

    Files go to the directory returned by ``_ensure_plot_dir(dv)``.
    """
    outdir = _ensure_plot_dir(dv)
    grid, is_bin, levels = _prediction_grid(g, comm_z, pred, var2_term, iat_in_model, covar_z_cols)
    grid['yhat'] = fit.predict(grid)

    # x is the standardized community variable, y the standardized DV
    xvar, yvar = comm_z, f"{dv}_z"
    base_comm = comm_z[:-2] if comm_z.endswith('_z') else comm_z
    grid_var2 = pd.to_numeric(grid[var2_term], errors='coerce')

    def _panel_tag(level):
        # Tag used in output file names for this moderator level.
        if is_bin:
            return f"{pretty_label(var2_term)} = {pretty_level(var2_term, level)}"
        if len(levels) == 3:
            if np.isclose(level, levels[0]):
                return "low_25p"
            if np.isclose(level, levels[1]):
                return "median_50p"
            return "high_75p"
        low, _high = levels
        return "low_25p" if np.isclose(level, low) else "high_75p"

    for level in levels:
        panel_label = _panel_tag(level)
        g_panel = g.loc[_panel_mask(g, var2_term, is_bin, level, levels)].copy()
        at_level = grid_var2.round().eq(level) if is_bin else np.isclose(grid_var2, level)
        grid_subset = grid[at_level]

        fig, ax = plt.subplots(figsize=(6,4))
        _scatter_points(ax, g_panel, xvar=xvar, yvar=yvar, pred=pred)
        _line_preds(ax, grid_subset, xvar=xvar)
        ax.set_xlabel(f"{pretty_label(base_comm)} (z)")
        ax.set_ylabel(f"{pretty_label(dv)} (z)")
        place_legend_outside(ax, where='right')
        fig.tight_layout()
        fig.savefig(outdir / f"OLS_dvOnComm_{dv}_{comm_z}_{pred}_{panel_label}.png", dpi=150)
        plt.close(fig)

        if not ALSO_PLOT_COLLAPSED_PRED:
            continue

        # weights: proportion pred==1 in this panel (fallback .5 if NA)
        w1 = pd.to_numeric(g_panel[pred], errors='coerce').mean()
        if not np.isfinite(w1):
            w1 = 0.5
        w0 = 1.0 - w1

        # Predict the same grid rows with pred forced to 0 and to 1, then average.
        g0 = grid_subset.copy()
        g0[pred] = 0
        g1 = grid_subset.copy()
        g1[pred] = 1
        y0 = fit.predict(g0)
        y1 = fit.predict(g1)
        grid_coll = g0[[xvar]].copy()
        grid_coll['yhat'] = w0 * y0 + w1 * y1

        # collapsed figure (points: all groups grey; one line)
        fig2, ax2 = plt.subplots(figsize=(6,4))
        ax2.scatter(pd.to_numeric(g_panel[xvar], errors='coerce'),
                    pd.to_numeric(g_panel[yvar], errors='coerce'),
                    s=12, edgecolor='none', c='#454545')
        ordered = grid_coll.sort_values(xvar)
        ax2.plot(ordered[xvar].values, ordered['yhat'].values, linewidth=2.5, linestyle='-', c='#454545')
        ax2.set_xlabel(f"{pretty_label(xvar)}")
        ax2.set_ylabel(f"{pretty_label(dv)} (z)")

        fig2.tight_layout()
        fig2.savefig(outdir / f"OLS_dvOnComm_{dv}_{comm_z}_{pred}_{panel_label}_collapsedPred.png", dpi=150, bbox_inches="tight")
        plt.close(fig2)

def plot_significant_ols_comm_on_dv(fit, g, dv, comm, pred, var2_term, iat_in_model=False, covar_z_cols=None):
    """Plot an OLS model of a z-scored community variable on the z-scored DV.

    One PNG is written per moderator level (see ``_var2_levels``): observed
    points coloured by ``pred`` group plus the model's predicted line for each
    group. When ``ALSO_PLOT_COLLAPSED_PRED`` is truthy, a second PNG per level
    shows all points in grey with a single line that averages the two group
    predictions, weighted by the share of ``pred == 1`` in that panel.

    Side effect: if ``g`` has no ``"<comm>_z"`` column it is added in place.
    Files go to the directory returned by ``_ensure_plot_dir(dv)``.
    """
    outdir = _ensure_plot_dir(dv)
    xvar = f"{dv}_z"
    yvar = f"{comm}_z"
    if yvar not in g.columns:
        # ensure the standardized outcome exists (same formula as zscore())
        g[yvar] = zscore(g[comm])

    grid, is_bin, levels = _prediction_grid(g, xvar, pred, var2_term, iat_in_model, covar_z_cols)
    grid['yhat'] = fit.predict(grid)
    grid_var2 = pd.to_numeric(grid[var2_term], errors='coerce')

    def _panel_tag(level):
        # Tag used in output file names for this moderator level.
        if is_bin:
            return f"{pretty_label(var2_term)} = {pretty_level(var2_term, level)}"
        if len(levels) == 3:
            if np.isclose(level, levels[0]):
                return "low_25p"
            if np.isclose(level, levels[1]):
                return "median_50p"
            return "high_75p"
        low, _high = levels
        return "low_25p" if np.isclose(level, low) else "high_75p"

    for level in levels:
        panel_label = _panel_tag(level)
        g_panel = g.loc[_panel_mask(g, var2_term, is_bin, level, levels)].copy()
        at_level = grid_var2.round().eq(level) if is_bin else np.isclose(grid_var2, level)
        grid_subset = grid[at_level]

        fig, ax = plt.subplots(figsize=(6,4))
        _scatter_points(ax, g_panel, xvar=xvar, yvar=yvar, pred=pred)
        _line_preds(ax, grid_subset, xvar=xvar)
        ax.set_xlabel(f"{pretty_label(dv)} (z)")
        # The outcome here is the z-scored community variable, not a
        # probability; the old "P(...=1) / observed (jittered)" label
        # belonged to the logit plot.
        ax.set_ylabel(f"{pretty_label(comm)} (z)")
        place_legend_outside(ax, where='right')
        fig.tight_layout()
        fig.savefig(outdir / f"OLS_commOnDv_{comm}_{dv}_{pred}_{panel_label}.png", dpi=150)
        plt.close(fig)

        if not ALSO_PLOT_COLLAPSED_PRED:
            continue

        # weights: proportion pred==1 in this panel (fallback .5 if NA)
        w1 = pd.to_numeric(g_panel[pred], errors='coerce').mean()
        if not np.isfinite(w1):
            w1 = 0.5
        w0 = 1.0 - w1

        # Predict the same grid rows with pred forced to 0 and to 1, then average.
        g0 = grid_subset.copy()
        g0[pred] = 0
        g1 = grid_subset.copy()
        g1[pred] = 1
        y0 = fit.predict(g0)
        y1 = fit.predict(g1)
        grid_coll = g0[[xvar]].copy()
        grid_coll['yhat'] = w0 * y0 + w1 * y1

        # collapsed figure (points: all groups grey; one line)
        fig2, ax2 = plt.subplots(figsize=(6,4))
        ax2.scatter(pd.to_numeric(g_panel[xvar], errors='coerce'),
                    pd.to_numeric(g_panel[yvar], errors='coerce'),
                    s=12, edgecolor='none', c="#454545")
        ordered = grid_coll.sort_values(xvar)
        ax2.plot(ordered[xvar].values, ordered['yhat'].values, linewidth=2.5, linestyle='-', c='#454545')
        ax2.set_xlabel(f"{pretty_label(dv)} (z)")
        ax2.set_ylabel(f"{pretty_label(comm)} (z)")

        fig2.tight_layout()
        fig2.savefig(outdir / f"OLS_commOnDv_{comm}_{dv}_{pred}_{panel_label}_collapsedPred.png", dpi=150, bbox_inches="tight")
        plt.close(fig2)


def plot_significant_logit_comm_on_dv(fit, g, dv, comm, pred, var2_term, iat_in_model=False):
    """Plot a logit model of a binary community variable on the z-scored DV.

    One PNG is written per moderator level (see ``_var2_levels``): observed
    0/1 outcomes (vertically jittered, coloured by ``pred`` group) and the
    predicted probability line for each group.

    File-name tags are ``"<var2_term>=<0|1>"`` for a binary moderator and
    ``low`` / ``median`` / ``high`` otherwise (``median`` only when three
    quantile levels are in use).
    """
    outdir = _ensure_plot_dir(dv)
    xvar = f"{dv}_z"
    yvar = comm  # binary 0/1

    grid, is_bin, levels = _prediction_grid(g, xvar, pred, var2_term, iat_in_model)
    grid['yhat'] = fit.predict(grid)
    grid_var2 = pd.to_numeric(grid[var2_term], errors='coerce')

    rng = np.random.default_rng(42)  # fixed seed keeps the jitter reproducible
    for level in levels:
        if is_bin:
            panel_label = f"{var2_term}={int(level)}"
        elif np.isclose(level, levels[0]):
            panel_label = 'low'
        elif len(levels) == 3 and np.isclose(level, levels[1]):
            # Give the median its own tag so its file is not overwritten by
            # the 'high' panel.
            panel_label = 'median'
        else:
            panel_label = 'high'
        mask_panel = _panel_mask(g, var2_term, is_bin, level, levels)
        g_panel = g.loc[mask_panel].copy()

        fig, ax = plt.subplots(figsize=(6,4))
        # scatter observed with a tiny vertical jitter to separate 0/1
        for grp in [0, 1]:
            mask = pd.to_numeric(g_panel[pred], errors='coerce').round().eq(grp)
            if not mask.any():
                continue
            xv = pd.to_numeric(g_panel.loc[mask, xvar], errors='coerce')
            yv = pd.to_numeric(g_panel.loc[mask, yvar], errors='coerce')
            yv = yv + rng.uniform(-0.03, 0.03, size=len(yv))
            lab, col = _group_label_and_color(pred, grp)
            ax.scatter(xv, yv, s=14, alpha=0.6, edgecolor='none', label=lab, c=col)

        # lines (predicted probabilities): select grid rows by moderator value.
        # Matching on grid['panel'] never succeeds because that column holds
        # the display labels built in _prediction_grid, not these file tags.
        if is_bin:
            at_level = grid_var2.round().eq(level)
        else:
            at_level = np.isclose(grid_var2, level)
        _line_preds(ax, grid[at_level], xvar=xvar)

        ax.set_xlabel(f"{dv} (z)")
        ax.set_ylabel(f"P({comm}=1) / observed (jittered)")
        ax.set_ylim(-0.15, 1.15)
        place_legend_outside(ax, where='right')
        fig.tight_layout()
        fig.savefig(outdir / f"LOGIT_commOnDv_{comm}_{dv}_{pred}_{panel_label}.png", dpi=150)
        plt.close(fig)

def plot_baseline_2way_bars(fit, g, dv, pred, var2_term, covar_z_cols, iat_in_model=False):
    """
    Bar plot of the 2x2 interaction (var2 x pred) for the baseline Predict-IAT model.
    X-axis: var2 (0 / 1) with pretty labels.
    Bars: pred groups (0 / 1). Shows model-estimated means with CIs at level
    1 - ALPHA, and the observed cell size "n=" next to each bar.

    Does nothing unless ``var2_term`` is a column of ``g`` holding only 0/1
    values. ``covar_z_cols`` may be None or empty (no covariates); listed
    covariates are held at 0.0. When ``iat_in_model`` is set and ``g`` has an
    'iat' column, 'iat' is fixed at its mode.
    Writes ``BAR_<dv>_<var2_term>_by_<pred>.png`` into ``_ensure_plot_dir(dv)``.
    """
    # Only for binary moderator
    if var2_term not in g.columns or not is_binary_series(g[var2_term]):
        return

    outdir = _ensure_plot_dir(dv)
    # Treat None like "no covariates", as _prediction_grid does.
    covar_z_cols = list(covar_z_cols) if covar_z_cols else []

    # Build a clean 2x2 grid (covars at mean=0; iat at mode if included)
    iat_val = None
    if iat_in_model and 'iat' in g.columns:
        m = g['iat'].mode(dropna=True)
        iat_val = m.iloc[0] if not m.empty else (g['iat'].dropna().iloc[0] if g['iat'].dropna().size else None)
    rows = []
    for v2 in [0, 1]:
        for grp in [0, 1]:
            row = {pred: grp, var2_term: v2}
            row.update({cz: 0.0 for cz in covar_z_cols})
            if iat_val is not None:
                row['iat'] = iat_val
            rows.append(row)

    grid = pd.DataFrame(rows)
    sf = fit.get_prediction(grid).summary_frame(alpha=ALPHA)
    grid['yhat'] = sf['mean'].astype(float)
    grid['cil']  = sf['mean_ci_lower'].astype(float)
    grid['cih']  = sf['mean_ci_upper'].astype(float)

    # Ns per cell
    n_table = g.groupby([var2_term, pred], dropna=False).size().rename('N').reset_index()
    grid = grid.merge(n_table, on=[var2_term, pred], how='left')

    x_labels = [pretty_level(var2_term, 0), pretty_level(var2_term, 1)]

    width = 0.38
    x_pos = np.array([0, 1])
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)

    ci_extents = []
    label_extents = []
    # Gap between a CI end and its "n=" label; scales with the plotted range
    # and is the same for every bar, so compute it once.
    pad = 0.02 * max(1.0, np.nanmax(np.abs(grid[['cil','cih','yhat']].to_numpy())))

    # NOTE(review): error bars use zorder=1 and bars zorder=2, so the part of
    # each error bar inside its bar is drawn underneath it — confirm intended.
    for grp, offset in zip([0, 1], [-width/2, width/2]):
        sub = grid[grid[pred] == grp].copy().sort_values(var2_term)
        centers = x_pos + offset
        y = sub['yhat'].to_numpy()
        cil = sub['cil'].to_numpy()
        cih = sub['cih'].to_numpy()
        ci_extents.extend(cil.tolist())
        ci_extents.extend(cih.tolist())
        yerr = np.vstack([y - cil, cih - y])  # shape (2, n)

        lab, col = _group_label_and_color(pred, grp)
        ax.bar(centers, y, width=width, label=lab, color=col, alpha=0.95,
               edgecolor='none', zorder=2)

        # Error bars (drawn separately so text can sit on top cleanly)
        ax.errorbar(centers, y, yerr=yerr, fmt='none', capsize=4, linewidth=1.4,
                    color='black', zorder=1)

        # n labels outside the CI with a light box for contrast
        for xi, yi, lo, hi, Ni in zip(centers, y, cil, cih, sub['N'].fillna(0).astype(int).tolist()):
            if np.isnan(yi) or np.isnan(lo) or np.isnan(hi):
                continue
            # place above upper CI if positive; below lower CI if negative
            y_text = (hi + pad) if yi >= 0 else (lo - pad)
            label_extents.append(y_text)
            ax.text(xi, y_text, f"n={Ni}", ha='center', va='bottom' if yi>=0 else 'top',
                    fontsize=8, color='black', zorder=3, clip_on=False,
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1.5))

    # Ensure ylim includes labels
    vals = np.array(ci_extents + label_extents, dtype=float)
    vals = vals[np.isfinite(vals)]
    if vals.size:
        span = vals.max() - vals.min()
        margin = 0.05 * (span if span > 0 else 1.0)
        ax.set_ylim(vals.min() - margin, vals.max() + margin)
    else:
        ax.margins(y=0.15)

    ax.set_xticks(x_pos)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel(f"{pretty_label(dv)} (z)")
    place_legend_outside(ax, where='right')
    fig.savefig(outdir / f"BAR_{dv}_{var2_term}_by_{pred}.png", dpi=150)
    plt.close(fig)

def place_legend_outside(ax, where='top', ncol=2):
    """
    Attach a frameless legend outside the axes and return it.

    where: 'right' anchors the legend just past the right edge, vertically
           centred; any other value places it centred above the axes in
           ``ncol`` columns.
    """
    handles, labels = ax.get_legend_handles_labels()
    if where != 'right':
        # default: top center, with multiple columns
        return ax.legend(handles, labels,
                         loc='upper center', bbox_to_anchor=(0.5, 1.14),
                         ncol=ncol, frameon=False)
    return ax.legend(handles, labels,
                     loc='center left', bbox_to_anchor=(1.02, 0.5),
                     frameon=False, borderaxespad=0.)

# ======== Load sources ========
def load_slope_wides():
    """Load both cached slope dicts and return them as wide tables.

    Returns ``(past_wide, future_wide)``; metric columns carry the suffix
    "_past" and "_future" respectively.
    """
    raw_past, raw_future = (load_pickle_dict(p) for p in (PAST_PKL, FUTURE_PKL))
    past_wide = dict_to_wide(raw_past, "_past")
    future_wide = dict_to_wide(raw_future, "_future")
    return past_wide, future_wide

def read_nanda_table(path: Path) -> pd.DataFrame:
    """Read a tab-separated NANDA data file into a DataFrame."""
    table = pd.read_csv(path, sep='\t')
    return table

def load_nanda_min():
    """Load the three NANDA tables, trimmed to the columns this analysis uses.

    Each file must contain 'TRACT_FIPS10', which is normalized into a
    'CensusTract' key. Only the wanted columns that actually exist are kept,
    and every kept column except 'CensusTract' is coerced to numeric
    (unparseable values become NaN).

    Returns:
        (nd10, nd1015, nd20) for the 1990-2010, 2010-2015 and 2020 files.

    Raises:
        ValueError: if a file lacks 'TRACT_FIPS10' (the message names the file).
    """
    specs = (
        (NANDA_1990_2010, ['CensusTract','TOTPOP10','PPOV10','PUNEMP10','PPUBAS10']),
        (NANDA_2010_2015, ['CensusTract','PPOV13_17','PUNEMP13_17','PPUBAS13_17']),
        (NANDA_2020,      ['CensusTract','PPOV','PUNEMP','PPUBAS']),
    )

    trimmed = []
    for path, wanted in specs:
        raw = read_nanda_table(path)
        if 'TRACT_FIPS10' not in raw.columns:
            raise ValueError(f"'TRACT_FIPS10' missing in NANDA file: {path}")
        raw['CensusTract'] = as_tract_str(raw['TRACT_FIPS10'])

        # .copy() makes the subset an independent frame, so the column
        # assignments below never write into a view of `raw`.
        subset = raw[[c for c in wanted if c in raw.columns]].copy()
        for c in subset.columns:
            if c != 'CensusTract':
                subset[c] = pd.to_numeric(subset[c], errors='coerce')
        trimmed.append(subset)

    nd10, nd1015, nd20 = trimmed
    return nd10, nd1015, nd20

def maybe_load_hri():
    """Load the optional redlining workbook, or return None if it is absent.

    Column names are stripped, have spaces removed and are upper-cased. The
    tract id column ('GEOID10', else 'CENSUSTRACT') is renamed 'CensusTract',
    moved to the front and normalized with ``as_tract_str``. Every other
    column is coerced to numeric (unparseable values become NaN).

    Raises:
        ValueError: if neither tract id column is present.
    """
    if not HRI_XLSX.exists():
        return None

    hri = pd.read_excel(HRI_XLSX)
    # normalize column names
    hri.columns = [str(col).strip().replace(" ", "").upper() for col in hri.columns]

    # GEOID10 / CensusTract harmonization
    if 'GEOID10' in hri.columns:
        hri = hri.rename(columns={'GEOID10': 'CENSUSTRACT'})
    elif 'CENSUSTRACT' not in hri.columns:
        raise ValueError("HRI file must contain GEOID10 or CensusTract.")
    # back to this file's naming; all other columns (e.g. HRI2010) keep their names
    hri = hri.rename(columns={'CENSUSTRACT': 'CensusTract'})

    # tract id first, then everything else in its existing order
    value_cols = [col for col in hri.columns if col not in ['CensusTract']]
    hri = hri[['CensusTract'] + value_cols].copy()
    hri['CensusTract'] = as_tract_str(hri['CensusTract'])

    # coerce numerics for downstream
    for col in hri.columns:
        if col != 'CensusTract':
            hri[col] = pd.to_numeric(hri[col], errors='coerce')
    return hri

# ======== Prepare frame ========
def prepare_analysis_frame(past_wide, future_wide, nd10, nd1015, nd20, person, ddm):
    """Merge participant, DDM, NANDA, slope and (optional) HRI data into one frame.

    Steps, in order:
      1. keep only rows with ``iat == IAT_FILTER`` when a filter is set;
      2. left-merge the HRI table onto ``person`` by 'CensusTract' (if the file exists);
      3. inner-merge ``ddm`` on 'PPT';
      4. left-merge ``nd10`` (required), then ``nd1015`` and ``nd20`` when non-empty;
      5. left-merge the past and future slope tables;
      6. add 'minority_type' ('White'/'BIPOC') and 'minority_type_num' (0/1).

    Raises:
        ValueError: if ``nd10`` lacks 'CensusTract' or 'TOTPOP10'.
    """
    hri = maybe_load_hri()

    # Optionally restrict to one IAT up front
    if IAT_FILTER is not None:
        person = person[person['iat'] == IAT_FILTER].copy()

    # HRI goes onto PERSON, which carries CensusTract
    if hri is not None:
        person = person.merge(hri, on='CensusTract', how='left')

    df = person.merge(ddm, on='PPT', how='inner')

    # NANDA (weights + levels)
    if not {'CensusTract','TOTPOP10'}.issubset(nd10.columns):
        raise ValueError("NANDA 1990-2010 table must include CensusTract and TOTPOP10.")
    df = df.merge(nd10, on='CensusTract', how='left')
    for optional_table in (nd1015, nd20):
        if not optional_table.empty:
            df = df.merge(optional_table, on='CensusTract', how='left')

    # Slopes
    for slope_table in (past_wide, future_wide):
        df = df.merge(slope_table, on='CensusTract', how='left')

    # Demographic recodes; a missing column is treated as all-NaN
    all_missing = pd.Series([np.nan]*len(df))
    races = df.get('race', all_missing)
    ethnicities = df.get('ethnicity', all_missing)
    df['minority_type'] = [recode_minority_type(r, e) for r, e in zip(races, ethnicities)]
    df['minority_type_num'] = df['minority_type'].map({'White': 0, 'BIPOC': 1})

    return df

# ======== Modeling ========
def run_models(df, dv_list, var2, pred, results_dir,
               require_hri_when_var2_not_hri,
               include_slopes_in_predict_iat,
               include_levels_in_predict_iat):
    """Fit the per-DV models on (PPT, zip_code) aggregates and write CSV summaries.

    For every DV in ``dv_list`` the frame is aggregated per participant and
    zip code, the DV and covariates are z-scored, and two model families run:

    * "Predict IAT": OLS (HC3) of the DV on each selected community column
      crossed with ``pred`` and ``var2``; if no community column is selected,
      a single baseline model with ``pred * var2`` is fit instead.
    * "Predict Community": for every ``*_future`` column, a logit when the
      aggregated column is binary, otherwise OLS (HC3), on ``DV * pred``.

    Term-level records and counts are gathered through ``collect_effects``
    and written to ``results_dir`` per DV and once more across all DVs.

    Args:
        df: merged analysis frame (see ``prepare_analysis_frame``).
        dv_list: names of the dependent-variable columns to model.
        var2: moderator column name.
        pred: predictor column name; ``'minority_type'`` is replaced by the
            0/1 ``minority_type_num`` values, anything else is z-scored.
        results_dir: directory that receives the CSV outputs (str or Path).
        require_hri_when_var2_not_hri: when True and ``HRI2010`` is a column,
            drop aggregated rows whose ``HRI2010`` is missing.
        include_slopes_in_predict_iat: use the ``*_past`` columns as
            community predictors in the "Predict IAT" models.
        include_levels_in_predict_iat: use the level columns (PPOV10,
            PUNEMP10, PPUBAS10, excluding ``var2``) as community predictors.

    Raises:
        ValueError: if ``zip_code``, ``var2`` or ``TOTPOP10`` is not in ``df``.

    Also reads module-level settings: COVAR_VARS, INCLUDE_DEMOG_COVARS,
    BINARY_ANY_COLUMNS, IAT_FILTER, ADJUST_FOR_IAT, IAT_INTERACTIONS,
    PLOT_SIGNIFICANT and ALPHA.
    """
    # Honour the caller-supplied output directory (accepts str or Path).
    out_dir = Path(results_dir)

    # community features from slope dicts (wide) + selected level vars
    slope_cols = [c for c in df.columns if c.endswith('_past') or c.endswith('_future')]
    level_cols = [c for c in ['PPOV10','PUNEMP10','PPUBAS10'] if c in df.columns]
    # Separate the community feature sets we’ll use for each block
    past_slope_cols   = [c for c in slope_cols if c.endswith('_past')]
    future_slope_cols = [c for c in slope_cols if c.endswith('_future')]
    predict_iat_cols = []
    if include_slopes_in_predict_iat:
        predict_iat_cols.extend(past_slope_cols)
    if include_levels_in_predict_iat:
        predict_iat_cols.extend([c for c in level_cols if c != var2])
    predict_comm_cols = future_slope_cols
    comm_cols_selected = predict_iat_cols + predict_comm_cols
    covar_cols = [c for c in COVAR_VARS if INCLUDE_DEMOG_COVARS and c in df.columns]

    # Extra moderator(s) to carry/require beyond var2 (e.g., HRI2010)
    carry_along_cols = []
    require_nonmissing_cols = []

    if 'HRI2010' in df.columns:
        # carry HRI through groupby so we can filter on the aggregated value;
        # compare against the var2 argument, not the module-level setting
        if var2 != 'HRI2010':
            carry_along_cols.append('HRI2010')
        if require_hri_when_var2_not_hri:
            require_nonmissing_cols.append('HRI2010')

    # Guards
    if 'zip_code' not in df.columns:
        raise ValueError("Expected 'zip_code' in personomics for grouping.")
    if var2 not in df.columns:
        raise ValueError(f"Expected '{var2}' in the merged frame. Add it to HRI file or change VAR2.")
    if 'TOTPOP10' not in df.columns:
        raise ValueError("TOTPOP10 missing after merge; needed for weighting.")

    all_sig_counts, all_effect_details = Counter(), defaultdict(list)

    for dv in dv_list:
        sig_counts, effect_details = Counter(), defaultdict(list)

        # ---- Build subset & coerce numerics BEFORE grouping ----
        keep = ['PPT','zip_code','iat', dv, pred, 'minority_type_num', var2, 'TOTPOP10'] \
            + comm_cols_selected + carry_along_cols + covar_cols
        sub = df[keep].copy()

        for c in (['TOTPOP10', var2] + comm_cols_selected + carry_along_cols):
            if c in sub.columns:
                sub[c] = pd.to_numeric(sub[c], errors='coerce')

        # Weighted mean (uses the closure to access sub['TOTPOP10'])
        pwm = pop_weighted_mean_factory(sub['TOTPOP10'])

        def agg_spec(vlist):
            spec = {
                dv: (dv, lambda s: np.nanmean(pd.to_numeric(s, errors='coerce'))),
                pred: (pred, lambda s: mode_scalar(s) if s.dtype == 'O'
                                else np.nanmean(pd.to_numeric(s, errors='coerce'))),
                'minority_type_num': ('minority_type_num', lambda s: mode_scalar(s)),
                'iat': ('iat', mode_scalar),
                var2: (var2, bin_any if (var2 in BINARY_ANY_COLUMNS and is_binary_series(sub[var2])) else pwm),
                'TOTPOP10': ('TOTPOP10', lambda s: np.nanmean(pd.to_numeric(s, errors='coerce')))
            }
            for v in vlist:
                spec[v] = (v, pwm)
            # weight HRI2010 (and any other carried moderators)
            for extra in carry_along_cols:
                spec[extra] = (extra, pwm)
            return spec

        g = sub.groupby(['PPT','zip_code']).agg(**agg_spec(comm_cols_selected + covar_cols)).reset_index()
        # Optional: enforce presence of HRI2010 even when var2 != HRI2010
        for col in require_nonmissing_cols:
            g = g[~g[col].isna()].copy()

        # z-score DV
        g[f'{dv}_z'] = zscore(g[dv])

        # z-score covariates
        covar_z = []
        for cvar in covar_cols:
            cz = f"{cvar}_z"
            g[cz] = zscore(g[cvar])
            covar_z.append(cz)

        # Build additive covariate RHS string
        covar_rhs = " + ".join(covar_z) if covar_z else ""

        # var2 term: binary moderators enter raw, everything else z-scored
        var2_term = var2 if is_binary_series(g[var2]) else f'{var2}_z'
        if var2_term.endswith('_z'):
            g[var2_term] = zscore(g[var2])

        # pred term
        if pred == 'minority_type':
            g['minority_type'] = g['minority_type_num']  # 0/1
        else:
            g[pred] = zscore(g[pred])

        # iat adjustment: wrap the RHS according to the IAT settings
        if ADJUST_FOR_IAT:
            if IAT_INTERACTIONS:
                # interact the whole RHS with C(iat)
                rhs_wrap = lambda rhs: f"({rhs})*C(iat)"
            else:
                rhs_wrap = lambda rhs: f"{rhs} + C(iat)"
        else:
            rhs_wrap = lambda rhs: rhs

        # -------- DV on community (OLS) Predicting IAT from past/present SES --------
        if predict_iat_cols:
            # Normal path: iterate over selected community columns
            for comm in predict_iat_cols:
                if '_constant' not in comm:
                    comm_z = f'{comm}_z'
                    g[comm_z] = zscore(g[comm])

                    rhs = f"{comm_z}*{pred} + {var2_term}*{pred} + {comm_z}*{var2_term}*{pred}"
                    if covar_rhs:
                        rhs = f"{rhs} + {covar_rhs}"
                    formula = f"{dv}_z ~ {rhs_wrap(rhs)}"

                    try:
                        fit = smf.ols(formula, data=g).fit(cov_type="HC3")
                        print(fit.summary())
                        meta = {
                            "IAT Parameter": dv, "Model": "Predict IAT",
                            "Community Variable": comm, "SES Moderator": var2,
                            "n": int(len(g)), "iat_filter": IAT_FILTER,
                            "adjust_for_iat": ADJUST_FOR_IAT, "iat_interactions": IAT_INTERACTIONS,
                            "include_slopes_in_predict_iat": include_slopes_in_predict_iat,
                            "include_levels_in_predict_iat": include_levels_in_predict_iat,
                        }
                        collect_effects(fit, meta, effect_details, sig_counts)

                        # Plot only when some non-intercept term reaches ALPHA.
                        if PLOT_SIGNIFICANT and (fit.pvalues.drop(labels=[lab for lab in fit.pvalues.index if lab.lower() in ('const','intercept')]) <= ALPHA).any():
                            try:
                                plot_significant_ols_dv_on_comm(
                                    fit=fit, g=g, dv=dv, comm_z=comm_z, pred=pred,
                                    var2_term=var2_term, iat_in_model=ADJUST_FOR_IAT,
                                    covar_z_cols=covar_z
                                )
                            except Exception as e:
                                print(f"[plot dv_on_comm] {dv} ~ {comm}: {e}")
                    except Exception as e:
                        # Best-effort: report the failed fit and continue with the next column.
                        print(f"[OLS dv_on_comm] {dv} ~ {comm}: {e}")

        else:
            # Baseline path: no community columns selected — still run a model with ONLY pred and var2
            # Use their interaction (A*B expands to A + B + A:B); include covariates/IAT as configured.
            rhs = f"{pred}*{var2_term}"
            if covar_rhs:
                rhs = f"{rhs} + {covar_rhs}"
            formula = f"{dv}_z ~ {rhs_wrap(rhs)}"

            try:
                fit = smf.ols(formula, data=g).fit(cov_type="HC3")
                print(fit.summary())
                meta = {
                    "IAT Parameter": dv, "Model": "Predict IAT",
                    "Community Variable": "(none)", "SES Moderator": var2,
                    "n": int(len(g)), "iat_filter": IAT_FILTER,
                    "adjust_for_iat": ADJUST_FOR_IAT, "iat_interactions": IAT_INTERACTIONS,
                    "include_slopes_in_predict_iat": include_slopes_in_predict_iat,
                    "include_levels_in_predict_iat": include_levels_in_predict_iat,
                }
                collect_effects(fit, meta, effect_details, sig_counts)
                plot_baseline_2way_bars(
                    fit=fit,
                    g=g,
                    dv=dv,
                    pred=pred,
                    var2_term=var2_term,                # this will be the binary var (e.g., LILATracts_Vehicle)
                    covar_z_cols=[c for c in g.columns if c.endswith('_z') and c in covar_z],
                    iat_in_model=ADJUST_FOR_IAT
                )
            except Exception as e:
                print(f"[OLS dv_on_comm baseline] {dv}: {e}")


        # -------- Community on DV (OLS/LOGIT) Predicting future SES from IAT --------
        for comm in predict_comm_cols:
            if '_future' in comm and '_constant' not in comm:
                comm_is_binary = is_binary_series(g[comm])
                lhs = comm if comm_is_binary else f'{comm}_z'
                if not comm_is_binary:
                    g[lhs] = zscore(g[comm])
                # rhs = f"{dv}_z*{pred} + {var2_term}*{pred} + {dv}_z*{var2_term}*{pred}"
                rhs = f"{dv}_z*{pred}"
                if covar_rhs:
                    rhs = f"{rhs} + {covar_rhs}"
                formula = f"{lhs} ~ {rhs_wrap(rhs)}"
                try:
                    if comm_is_binary:
                        fit = smf.logit(formula, data=g).fit(disp=0)
                        print(fit.summary())
                    else:
                        fit = smf.ols(formula, data=g).fit(cov_type="HC3")
                        print(fit.summary())
                    meta = {"IAT Parameter": dv, "Model": "Predict Community", "Community Variable": comm, "SES Moderator": var2,
                            "n": int(len(g)), "iat_filter": IAT_FILTER, "adjust_for_iat": ADJUST_FOR_IAT,
                            "iat_interactions": IAT_INTERACTIONS}
                    collect_effects(fit, meta, effect_details, sig_counts)
                    if PLOT_SIGNIFICANT and (fit.pvalues.drop(labels=[lab for lab in fit.pvalues.index if lab.lower() in ('const','intercept')]) <= ALPHA).any():
                        try:
                            if is_binary_series(g[comm]):
                                plot_significant_logit_comm_on_dv(
                                    fit=fit, g=g, dv=dv, comm=comm, pred=pred, var2_term=var2_term, iat_in_model=ADJUST_FOR_IAT
                                )
                            else:
                                plot_significant_ols_comm_on_dv(
                                    fit=fit, g=g, dv=dv, comm=comm, pred=pred, var2_term=var2_term, iat_in_model=ADJUST_FOR_IAT, covar_z_cols=covar_z
                                )
                        except Exception as e:
                            print(f"[plot comm_on_dv] {comm} ~ {dv}: {e}")
                except Exception as e:
                    print(f"[{'LOGIT' if comm_is_binary else 'OLS'} comm_on_dv] {comm} ~ {dv}: {e}")

        # ---- Save per-DV outputs (CSV) ----
        rows = [rec for term, items in effect_details.items() for rec in items]
        if rows:
            pd.DataFrame(rows).to_csv(out_dir / f"effect_details_{dv}.csv", index=False)
        pd.DataFrame(sig_counts.most_common(), columns=["term","count"])\
          .to_csv(out_dir / f"effect_counts_{dv}.csv", index=False)

        all_sig_counts.update(sig_counts)
        for k, v in effect_details.items():
            all_effect_details[k].extend(v)

        print(f"\n=== {dv} ===")
        print("Top terms:", sig_counts.most_common(10))

    # ---- Save global summaries (CSV) ----
    pd.DataFrame(all_sig_counts.most_common(), columns=["term","count"]).to_csv(out_dir / "effect_counts_ALL.csv", index=False)
    all_rows = [rec for term, items in all_effect_details.items() for rec in items]
    if all_rows:
        pd.DataFrame(all_rows).to_csv(out_dir / "effect_details_ALL.csv", index=False)


In [5]:
# Root folder; every input path below is resolved relative to it.
DATA_DIR = Path(".")

# Cached slope dictionaries (gzip-compressed pickles)
PAST_PKL = DATA_DIR / "data" / "nandaslopes_past.pkl.gz"
FUTURE_PKL = DATA_DIR / "data" / "nandaslopes_future.pkl.gz"

# NANDA TSVs — only TOTPOP10 and the level variables are taken from these
NANDA_1990_2010 = DATA_DIR / "data" / "38528-0001-Data.tsv"  # TOTPOP10, PPOV10, PUNEMP10, PPUBAS10
NANDA_2010_2015 = DATA_DIR / "data" / "38528-0002-Data.tsv"  # PPOV13_17, PUNEMP13_17, PPUBAS13_17
NANDA_2020 = DATA_DIR / "data" / "38528-0006-Data.tsv"  # PPOV, PUNEMP, PPUBAS (2020)

def load_pickle_dict(path: Path) -> dict:
    """Return the object stored in the gzip-compressed pickle at ``path``."""
    with gzip.open(path, "rb") as stream:
        payload = pickle.load(stream)
    return payload

def dict_to_wide(d: dict, suffix: str) -> pd.DataFrame:
    """Pivot a flat ``{"<tract>_<metric>": value}`` dict into one row per tract.

    Each key is split at its first underscore into a tract id and a metric
    name. The result has a ``CensusTract`` column plus one column per metric,
    each metric name extended with ``suffix``.
    """
    long_form = pd.Series(d, name='value').rename_axis('key').reset_index()

    # e.g. "12345678901_PPOV_slope_1" -> ("12345678901", "PPOV_slope_1")
    key_parts = long_form['key'].astype(str).str.split('_', n=1, expand=True)
    long_form[['CensusTract', 'metric']] = key_parts
    long_form['CensusTract'] = as_tract_str(long_form['CensusTract'])

    table = long_form.pivot(index='CensusTract', columns='metric', values='value')
    table = table.reset_index()
    table.columns.name = None

    def with_suffix(col):
        # The join key keeps its name; every metric column gets the suffix.
        if col == 'CensusTract':
            return col
        return f'{col}{suffix}'

    return table.rename(columns=with_suffix)

def load_slope_wides():
    """Load both cached slope dicts and return them as ``(past, future)`` wide frames.

    Columns of the first frame carry the ``_past`` suffix, those of the
    second the ``_future`` suffix.
    """
    past_slopes = load_pickle_dict(PAST_PKL)
    future_slopes = load_pickle_dict(FUTURE_PKL)
    past_table = dict_to_wide(past_slopes, "_past")
    future_table = dict_to_wide(future_slopes, "_future")
    return past_table, future_table

def read_nanda_table(path: Path) -> pd.DataFrame:
    """Read one tab-separated NANDA data file into a DataFrame.

    ``low_memory=False`` makes pandas infer each column's dtype from the
    whole file rather than chunk by chunk, which avoids the mixed-type
    ``DtypeWarning`` on columns whose values are not uniformly typed.
    """
    return pd.read_csv(path, sep='\t', low_memory=False)

def load_nanda_min():
    """Load the three NANDA tables trimmed to the columns the analysis uses.

    Each table gets a ``CensusTract`` key derived from ``TRACT_FIPS10`` via
    ``as_tract_str``; only the wanted columns that are actually present are
    kept, and every non-key column is coerced to numeric (unparseable values
    become NaN).

    Returns:
        Tuple ``(nd10, nd1015, nd20)`` of DataFrames.

    Raises:
        ValueError: if a file has no ``TRACT_FIPS10`` column.
    """
    sources = (NANDA_1990_2010, NANDA_2010_2015, NANDA_2020)
    # Read full, then select available columns
    nd10, nd1015, nd20 = [read_nanda_table(path) for path in sources]

    # Ensure key; name the offending file so the failure is actionable
    for path, table in zip(sources, (nd10, nd1015, nd20)):
        if 'TRACT_FIPS10' not in table.columns:
            raise ValueError(f"'TRACT_FIPS10' missing in NANDA file: {path}")
        table['CensusTract'] = as_tract_str(table['TRACT_FIPS10'])

    # Keep only needed columns if present
    keep10   = ['CensusTract','TOTPOP10','PPOV10','PUNEMP10','PPUBAS10']
    keep1015 = ['CensusTract','PPOV13_17','PUNEMP13_17','PPUBAS13_17']
    keep20   = ['CensusTract','PPOV','PUNEMP','PPUBAS']

    # .copy() gives each trimmed table its own data, so the column
    # assignments below are unambiguous and raise no chained-assignment warning.
    nd10   = nd10[[c for c in keep10   if c in nd10.columns]].copy()
    nd1015 = nd1015[[c for c in keep1015 if c in nd1015.columns]].copy()
    nd20   = nd20[[c for c in keep20   if c in nd20.columns]].copy()

    # Numeric coercion for level vars and weights
    for table in (nd10, nd1015, nd20):
        for col in table.columns:
            if col != 'CensusTract':
                table[col] = pd.to_numeric(table[col], errors='coerce')

    return nd10, nd1015, nd20

def load_personomics_and_ddm():
    """Load and stack the personomics and DDM tables of the three IATs.

    Participant ids restart in every file, so each file after the first has
    its ids shifted by the largest (already shifted) id of the previous
    personomics file; the matching DDM file receives the same shift. The DDM
    ``id`` column is renamed to ``PPT``.

    Returns:
        Tuple ``(person, ddm)``; ``person`` carries an ``iat`` label column
        and a normalised ``CensusTract`` column.
    """
    person_sources = (
        (PERSONOMICS_1, IAT1),
        (PERSONOMICS_2, IAT2),
        (PERSONOMICS_3, IAT3),
    )

    person_frames = []
    shifts = []          # id shift per file; None means "leave ids as they are"
    running_max = None
    for csv_path, iat_label in person_sources:
        frame = pd.read_csv(csv_path)
        frame['iat'] = iat_label
        shifts.append(running_max)
        if running_max is not None:
            frame['PPT'] = frame['PPT'] + running_max
        running_max = frame['PPT'].max()
        person_frames.append(frame)
    person = pd.concat(person_frames, ignore_index=True)

    ddm_frames = []
    for csv_path, shift in zip((DDM_1, DDM_2, DDM_3), shifts):
        frame = pd.read_csv(csv_path)
        if shift is not None:
            frame['id'] = frame['id'] + shift
        ddm_frames.append(frame.rename(columns={"id": "PPT"}))
    ddm = pd.concat(ddm_frames, ignore_index=True)

    # Ensure padded tract ids in person
    person['CensusTract'] = as_tract_str(person['CensusTract'])
    return person, ddm

def as_tract_str(s: pd.Series):
    """Render tract ids as strings left-padded with zeros to 11 characters.

    Anything from the first ``.`` onward is discarded first, so an id that
    was parsed as a float (``1001020100.0``) becomes ``"01001020100"``.
    """
    as_text = s.astype(str)
    before_dot = as_text.str.split('.').str.get(0)
    return before_dot.str.zfill(11)

# Load the slope tables and the trimmed NANDA tables once, when this cell runs;
# both calls read from disk (PAST_PKL / FUTURE_PKL and the three NANDA TSVs).
past_wide, future_wide = load_slope_wides()
nd10, nd1015, nd20 = load_nanda_min()

  return pd.read_csv(path, delimiter='\t')
  return pd.read_csv(path, delimiter='\t')


In [8]:
# IAT names: used as the value of the `iat` label column and to build the
# per-IAT personomics / DDM file names below.
IAT1 = "changepreserve"
IAT2 = "progressrestore"
IAT3 = "futurepresent"
# Personomics CSVs, one per IAT
PERSONOMICS_1 = DATA_DIR / f"data/{IAT1}_pnomics_exp.csv"
PERSONOMICS_2 = DATA_DIR / f"data/{IAT2}_pnomics_exp.csv"
PERSONOMICS_3 = DATA_DIR / f"data/{IAT3}_pnomics_exp.csv"
# DDM output CSVs, one per IAT
DDM_1 = DATA_DIR / f"ddm_output/DMCfs_sesiat_output_{IAT1}.csv"
DDM_2 = DATA_DIR / f"ddm_output/DMCfs_sesiat_output_{IAT2}.csv"
DDM_3 = DATA_DIR / f"ddm_output/DMCfs_sesiat_output_{IAT3}.csv"

# Reads all six files; must run after the constants above are defined.
person, ddm = load_personomics_and_ddm()

# Moderator (present in HRI_XLSX or NANDA-merged data)
VAR2 = "HRI2010"               # moderator column, e.g. "HRI2010" or "LILATracts_Vehicle"
REQUIRE_HRI_WHEN_VAR2_NOT_HRI = True   # when True (and HRI2010 is a column), run_models drops aggregated rows with missing HRI2010
PRED = "minority_type"        # will be 0/1 after recode

# Demographic covariates: COVAR_VARS are only used when INCLUDE_DEMOG_COVARS is True
INCLUDE_DEMOG_COVARS = False
COVAR_VARS = ['education', 'soc_income']

INCLUDE_SLOPES_IN_PREDICT_IAT = True   # past *_past columns (from slope dicts)
INCLUDE_LEVELS_IN_PREDICT_IAT = False   # level columns (e.g., PPOV10, PUNEMP10, PPUBAS10)

# IAT controls
IAT_FILTER = None               # set to one IAT name (IAT1 / IAT2 / IAT3) to restrict; None = keep every loaded IAT
ADJUST_FOR_IAT = True           # include + C(iat) in formulas
IAT_INTERACTIONS = True        # if True, interact RHS with C(iat); only consulted when ADJUST_FOR_IAT is True

# Dependent-variable columns modelled one at a time by run_models
DV_LIST = ['d','peak_amplitude','alpha_int','alpha_dif','tau','mu_c_int','characteristic_time']
ALPHA = 0.05  # p-value threshold used by run_models when deciding whether to plot a fit
# Settings encoded into the results folder name so different runs do not overwrite each other.
# NOTE(review): the tag reads 'both' when IAT_FILTER is None although three IATs are loaded.
RESULTS_TAGS = [
    f"var2={VAR2}",
    f"iat={IAT_FILTER if IAT_FILTER is not None else 'both'}",
    f"hriReq={int(REQUIRE_HRI_WHEN_VAR2_NOT_HRI)}",
    f"covars={int(INCLUDE_DEMOG_COVARS)}",
]
RESULTS_FOLDER = "model_results__" + "__".join(RESULTS_TAGS)
RESULTS_DIR = DATA_DIR / RESULTS_FOLDER
RESULTS_DIR.mkdir(parents=True, exist_ok=True)  # created at cell execution time

# Plotting controls
PLOT_SIGNIFICANT = True
COLOR_WHITE = "#888888"  # grey
COLOR_BIPOC = "#000000"  # black
PLOT_DIR = RESULTS_DIR / "plots"
PLOT_DIR.mkdir(parents=True, exist_ok=True)

ALSO_PLOT_COLLAPSED_PRED = False

# ========= RUN =========
if __name__ == "__main__":
    # Build the merged participant-level frame from the tables loaded above,
    # then fit every model for each DV and write the CSV outputs.
    df = prepare_analysis_frame(past_wide, future_wide, nd10, nd1015, nd20, person, ddm)
    run_models(
        df,
        DV_LIST,
        var2=VAR2,
        pred=PRED,
        results_dir=RESULTS_DIR,
        require_hri_when_var2_not_hri=REQUIRE_HRI_WHEN_VAR2_NOT_HRI,
        include_slopes_in_predict_iat=INCLUDE_SLOPES_IN_PREDICT_IAT,
        include_levels_in_predict_iat=INCLUDE_LEVELS_IN_PREDICT_IAT,
    )

                            OLS Regression Results                            
Dep. Variable:                    d_z   R-squared:                       0.157
Model:                            OLS   Adj. R-squared:                  0.101
Method:                 Least Squares   F-statistic:                     3.342
Date:                Wed, 26 Nov 2025   Prob (F-statistic):           7.18e-07
Time:                        22:14:42   Log-Likelihood:                -493.37
No. Observations:                 370   AIC:                             1035.
Df Residuals:                     346   BIC:                             1129.
Df Model:                          23                                         
Covariance Type:                  HC3                                         
                                                                             coef    std err          z      P>|z|      [0.025      0.975]
-----------------------------------------------------------------------