In [1]:
import os
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import fastparquet
from glob import glob
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt

In [2]:
# -----------------------------
# Utility functions
# -----------------------------

def safe_read_parquet(path, columns=None):
    """Load a parquet file with the fastparquet engine without ever raising.

    Parameters
    ----------
    path : path of the parquet file to read.
    columns : optional list of column names forwarded to ``pd.read_parquet``;
        ``None`` leaves the column selection to pandas.

    Returns
    -------
    The loaded DataFrame, or an empty DataFrame when reading raised any
    ``Exception``. In that case the file's base name and the error text are
    printed to stdout.
    """
    try:
        frame = pd.read_parquet(path, columns=columns, engine="fastparquet")
    except Exception as err:
        # Best effort by design: report the failure and hand back an empty frame
        # so callers can filter it out with ``df.empty``.
        print(f"Error reading {os.path.basename(path)}: {err}")
        frame = pd.DataFrame()
    return frame

def quantiles(x, qs=[0.025, 0.975]):
    """Row-wise quantiles of a 2-D array-like.

    Every level in ``qs`` is evaluated along axis 1 of ``x`` with
    ``np.quantile``. The result has one leading entry per requested level, so
    with the default two levels it can be unpacked as ``lower, upper = ...``.
    """
    per_row = np.quantile(x, q=qs, axis=1)
    return per_row

In [5]:

# Collect the per-block MSM draw files in name order so the concatenated
# column order is the same on every run.
msm_files = glob("/Users/meibinchen/Documents/GitHub/EEE/msm_draws/msm_blk_*.parquet")
msm_files.sort()
print(f"{len(msm_files)} MSM files found")

# Blocks are joined side by side (axis=1), aligned on the frames' index.
# Presumably rows are counties and columns are simulation draws — TODO confirm.
df_full = pd.concat(
    [pd.read_parquet(block_path, engine="fastparquet") for block_path in msm_files],
    axis=1,
)

# Per-row summary across all concatenated columns.
q025, q975 = quantiles(df_full.values)
df_summary = pd.DataFrame(
    {"mean": df_full.mean(axis=1), "median": df_full.median(axis=1)}
).assign(**{"2.5%": q025, "97.5%": q975})

# df_summary.to_csv('msm_county_summary.csv', index=False)

20 MSM files found


In [6]:
# -----------------------------
# Male draws (county-level)
# -----------------------------

# Male population draw files, in name order (same ordering rule as the MSM files).
male_files = glob("/Users/meibinchen/Documents/GitHub/EEE/male_draws/male_blk_*.parquet")
male_files.sort()
print(f"{len(male_files)} Male files found")

df_males = pd.concat(
    [pd.read_parquet(block_path, engine="fastparquet") for block_path in male_files],
    axis=1,
)

# MSM rate = MSM / Males.
# The division is done on the raw arrays, i.e. purely by position: there is no
# index or column alignment, so both frames must have matching layouts.
df_rate = np.true_divide(df_full.values, df_males.values)
q025, q975 = quantiles(df_rate)
df_rate_summary = pd.DataFrame(
    {"mean": np.mean(df_rate, axis=1), "median": np.median(df_rate, axis=1)}
).assign(**{"2.5%": q025, "97.5%": q975})
# df_rate_summary.to_csv('msm_rate_county_summary.csv', index=False)

20 Male files found


In [3]:
# -----------------------------
# GEOID mapping
# -----------------------------

# County lookup table; FIPS is forced to str so pandas does not parse it as a number.
nchscodes = pd.read_csv(
    "/Users/meibinchen/Library/CloudStorage/OneDrive-JohnsHopkins/EEE HIV Stigma/EEE Data/Other/GSS and NCHS Data for Marginal Dist/GEOID.csv",
    dtype={'FIPS': str},
)
# 1-based index derived from the row labels of the table as read.
nchscodes['county_index'] = nchscodes.index + 1

# Per-row state abbreviation, and the sorted distinct abbreviations.
states = nchscodes['ST_ABBREV'].values
unique_states = np.unique(states)

# state abbreviation -> array of the county_index values belonging to that state
# (keys appear in order of first occurrence in the table).
state_geoid = dict(
    (abbrev, nchscodes['county_index'][nchscodes['ST_ABBREV'] == abbrev].values)
    for abbrev in nchscodes['ST_ABBREV'].unique()
)

In [None]:
# Raw draw matrices. Rows are selected below with a mask built from `states`,
# so they are assumed to be in the same order as the GEOID table — TODO confirm.
sim_vals = df_full.values
sim_male_vals = df_males.values

# One row per state, one column per draw column; stored as float32.
state_sim_matrix = np.zeros((unique_states.size, sim_vals.shape[1]), dtype=np.float32)
state_sim_male_matrix = np.zeros_like(state_sim_matrix)

# Column-wise totals over each state's rows; NaNs are ignored by nansum.
for i, state in enumerate(unique_states):
    mask = states == state
    state_sim_matrix[i] = np.nansum(sim_vals[mask], axis=0)
    state_sim_male_matrix[i] = np.nansum(sim_male_vals[mask], axis=0)

state_sim_rate_matrix = state_sim_matrix / state_sim_male_matrix

# Summary: MSM counts by state
q025, q975 = quantiles(state_sim_matrix)
state_sim_summary = pd.DataFrame({
    'mean': np.mean(state_sim_matrix, axis=1),
    'median': np.median(state_sim_matrix, axis=1),
})
state_sim_summary['2.5%'] = q025
state_sim_summary['97.5%'] = q975
state_sim_summary['state'] = unique_states

# Summary: MSM rate by state
q025, q975 = quantiles(state_sim_rate_matrix)
state_sim_rate_summary = pd.DataFrame({
    'mean': np.mean(state_sim_rate_matrix, axis=1),
    'median': np.median(state_sim_rate_matrix, axis=1),
})
state_sim_rate_summary['2.5%'] = q025
state_sim_rate_summary['97.5%'] = q975
state_sim_rate_summary['state'] = unique_states
# state_sim_summary.to_csv('msm_state_summary.csv', index=False)
# state_sim_rate_summary.to_csv('msm_rate_state_summary.csv', index=False)

In [None]:
# -----------------------------
# MSM by age group × county
# -----------------------------

chunk_paths = sorted(glob("/Users/meibinchen/Documents/GitHub/EEE/msm_draws/adj_msm_age_blk_*.parquet"))
print(f"{len(chunk_paths)} age-county chunk files found")
n_counties = 3144
chunk_size = 100
results = []

# Read every file exactly once, before the county loop. The set of files and
# columns does not depend on the county range, so reading inside the loop
# repeated the full I/O for every chunk of counties. All frames were already
# held in memory at once per iteration, so the peak footprint is essentially
# unchanged. Files that fail to load come back empty from safe_read_parquet
# and are dropped here.
with ThreadPoolExecutor(max_workers=8) as executor:
    age_county_frames = [
        frame
        for frame in executor.map(
            lambda p: safe_read_parquet(p, ["age_group", "county_index", "msm_count"]),
            chunk_paths,
        )
        if not frame.empty
    ]
if not age_county_frames:
    # Fail here with a clear message instead of an opaque
    # "No objects to concatenate" at the final pd.concat.
    raise ValueError("No age-county parquet files could be read; nothing to summarise")

# Summarise in slices of `chunk_size` counties to bound the size of the
# concatenated frame that the groupby has to work on.
for start in range(1, n_counties + 1, chunk_size):
    end = min(start + chunk_size - 1, n_counties)
    county_range_set = set(range(start, end + 1))
    print(f"Processing counties {start}–{end}")

    df_chunk = pd.concat(
        [frame[frame["county_index"].isin(county_range_set)] for frame in age_county_frames],
        ignore_index=True,
    )

    # observed=True: if age_group is categorical, only combinations present in
    # the data are summarised (consistent with the state × age cell).
    summary = (
        df_chunk.groupby(["age_group", "county_index"], observed=True)["msm_count"]
        .agg(mean="mean", median="median",
             q025=lambda x: x.quantile(0.025),
             q975=lambda x: x.quantile(0.975))
        .reset_index()
    )
    results.append(summary)

# The raw frames are no longer needed; release them before later cells run.
del age_county_frames

final_summary = pd.concat(results, ignore_index=True)
# final_summary.to_csv('adj_msm_age_county_summary.csv', index=False)

In [None]:
# -----------------------------
# MSM by state × age group
# -----------------------------

msm_paths = sorted(glob("/Users/meibinchen/Documents/GitHub/EEE/msm_draws/adj_msm_age_blk_*.parquet"))
male_paths = sorted(glob("/Users/meibinchen/Documents/GitHub/EEE/male_draws/adj_male_age_blk*.parquet"))

N_SIMS = 100_000   # simulations expected per county × age group
N_AGE_GROUPS = 5   # age groups expected per county


def _read_age_draws(paths, value_col):
    """Read age_group, county_index and ``value_col`` from every file, once.

    Files that fail to load come back empty from ``safe_read_parquet`` and are
    dropped. ``sim_index`` is not read because it is rebuilt from row position
    in ``_state_age_sim_sums``. Raises ``ValueError`` if nothing could be read.
    """
    columns = ["age_group", "county_index", value_col]
    with ThreadPoolExecutor(max_workers=8) as executor:
        frames = [
            frame
            for frame in executor.map(lambda p: safe_read_parquet(p, columns), paths)
            if not frame.empty
        ]
    if not frames:
        raise ValueError(
            f"No readable parquet files for '{value_col}' ({len(paths)} path(s) matched)"
        )
    return frames


def _state_age_sim_sums(frames, value_col, county_ids):
    """Sum ``value_col`` over one state's counties per (age_group, sim_index).

    Returns a Series indexed by (age_group, sim_index).
    Raises ``ValueError`` when the number of rows selected for the state is not
    ``N_SIMS * len(county_ids) * N_AGE_GROUPS``.
    """
    chunk = pd.concat(
        [frame[frame["county_index"].isin(county_ids)] for frame in frames],
        ignore_index=True,
    )
    # Negative draws are floored at zero.
    chunk[value_col] = chunk[value_col].clip(lower=0)

    rows_per_sim = len(county_ids) * N_AGE_GROUPS
    if len(chunk) != N_SIMS * rows_per_sim:
        raise ValueError(
            f"Expected {N_SIMS * rows_per_sim} rows of '{value_col}' "
            f"({N_SIMS} sims × {len(county_ids)} counties × {N_AGE_GROUPS} age groups), "
            f"got {len(chunk)}"
        )
    # Simulation ids are assigned by row position: the first `rows_per_sim` rows
    # are sim 1, the next `rows_per_sim` are sim 2, and so on. np.repeat yields
    # the same values as sorted(np.tile(...)) without a Python-level sort.
    # NOTE(review): this assumes the concatenated rows are ordered sim-major — confirm.
    chunk["sim_index"] = np.repeat(np.arange(1, N_SIMS + 1), rows_per_sim)

    return chunk.groupby(["age_group", "sim_index"], observed=True)[value_col].sum()


results_quantiles, results_sums, male_sums = [], [], []

# Pass 1: MSM counts. Files are read once for all states, not once per state.
msm_frames = _read_age_draws(msm_paths, "msm_count")
for state in unique_states:
    print(f"Processing state {state}")
    msm_by_age_sim = _state_age_sim_sums(msm_frames, "msm_count", state_geoid[state])

    # Summarise across simulations within each age group.
    summary_quantiles = (
        msm_by_age_sim.reset_index()
        .groupby("age_group", observed=True)["msm_count"]
        .agg(
            mean="mean",
            median="median",
            q025=lambda s: s.quantile(0.025),
            q975=lambda s: s.quantile(0.975),
        )
        .reset_index()
    )
    summary_quantiles["state"] = state
    results_quantiles.append(summary_quantiles)

    # Wide layout for the rate step: rows = sim_index, columns = age_group.
    results_sums.append(msm_by_age_sim.unstack("age_group"))
del msm_frames  # release before loading the male draws to limit peak memory

# Pass 2: male population, same layout as the MSM sums.
male_frames = _read_age_draws(male_paths, "male_pop")
for state in unique_states:
    print(f"Processing state {state} (male population)")
    male_sums.append(
        _state_age_sim_sums(male_frames, "male_pop", state_geoid[state]).unstack("age_group")
    )
del male_frames

# Combine summaries: one row per state × age group, with columns
# age_group, mean, median, q025, q975, state. The statistics are already
# columns, so no per-row statistic label is added.
msm_age_states_summary = pd.concat(results_quantiles, ignore_index=True)
# msm_age_states_summary.to_csv('msm_state_age_summary.csv', index=False)

# MSM rates by state × age.
# Each per-state frame is keyed by its state, so rows are identified by a unique
# (state, sim_index) pair and the division aligns on labels, not on row position.
state_keys = list(unique_states)
state_sim_msm_age = pd.concat(results_sums, keys=state_keys, names=["state", "sim_index"])
state_sim_male_age = pd.concat(male_sums, keys=state_keys, names=["state", "sim_index"])

state_sim_rate_age = state_sim_msm_age / state_sim_male_age
# Use plain column labels so a 'state' column can be added even when age_group
# was read as a categorical.
state_sim_rate_age.columns = pd.Index(state_sim_rate_age.columns.tolist(), name="age_group")
state_sim_rate_age = (
    state_sim_rate_age
    .reset_index(level="state")
    .melt(id_vars="state", var_name="age_group", value_name="value")
)

# Select the 'value' column before aggregating: keyword-style named aggregation
# with plain function names only works on a SeriesGroupBy.
state_sim_rate_age_summary = (
    state_sim_rate_age.groupby(["state", "age_group"], observed=True)["value"]
    .agg(mean="mean", median="median",
         q025=lambda x: x.quantile(0.025),
         q975=lambda x: x.quantile(0.975))
    .reset_index()
)
# state_sim_rate_age_summary.to_csv('msm_rate_state_age_summary.csv', index=False)

Processing state AK
Processing state AL
Processing state AR
