
# ARGOX — End‑to‑End (FINAL)  
**Date:** 2025-11-12

A streamlined notebook that runs **end‑to‑end** using **existing local cache files** (no external pulls).  
It combines the robust loaders/plotting from **Fixed15** with early‑pipeline sanity checks from **Fixed10**.



## 0) Configuration
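Leave entries as `None` to auto-detect; set a path to override (`MOBILITY_VALUE_COL` takes a column name rather than a path). `HUMIDITY_DAILY_CSV` is the exception: it is never auto-detected, so its preview is skipped unless a path is given.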


In [None]:

# Path to the weekly mobility CSV. None -> searched for under ./cache, ./outputs and ./ in section 3.
MOBILITY_CSV        = None
# Path to the weekly R(t) CSV. None -> searched for in the same folders, then
# rebuilt from a Fig2_weekly_scatter_inputs.csv if one can be found.
RT_CSV              = None
# Name of the mobility value column. None -> chosen by _pick_mob_col heuristics.
MOBILITY_VALUE_COL  = None

# Path to the cached ILI CSV. None -> searched for under ./cache and its ili_cache* subfolders.
ILI_CACHE_FILE      = None
# Directory holding GT cache CSVs. None -> first existing of ./gt_cache_states, ./gt_cache, ./cache/gt_cache_states.
GT_CACHE_DIR        = None
# Path to a daily humidity CSV. Not auto-detected: None simply skips the preview.
HUMIDITY_DAILY_CSV  = None

# Folder that receives every PNG written by the figure cells.
OUT_DIR = './outputs/quicklooks'



## 1) Imports & Utilities


In [None]:

import os, glob, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Create the output folder up front (no error if it already exists) so the
# figures saved under OUT_DIR later have a directory to land in.
os.makedirs(OUT_DIR, exist_ok=True)

def _find_first_existing(candidates, search_dirs=(".",)):
    """Return the path of the first candidate file found under ``search_dirs``.

    Directories are tried in the order given and, inside each directory,
    candidates are tried in the order given. For every (directory, candidate)
    pair a direct child is preferred; otherwise the directory tree is
    searched recursively.

    Args:
        candidates: File names to look for, most preferred first.
        search_dirs: Directories to search, most preferred first.

    Returns:
        The matching path as a string, or ``None`` when nothing matches.
    """
    for sd in search_dirs:
        for c in candidates:
            cand = os.path.join(sd, c)
            if os.path.exists(cand):
                return cand
            # glob yields matches in arbitrary (filesystem) order; sort so the
            # same file is picked on every run when several copies exist.
            hits = sorted(glob.glob(os.path.join(sd, '**', c), recursive=True))
            if hits:
                return hits[0]
    return None


def _case_insensitive_glob(pattern):
    """Glob for ``pattern`` while ignoring the case of its letters.

    Each alphabetic character is expanded into a two-character class
    (``a`` -> ``[aA]``) before the pattern is handed to a recursive
    ``glob.glob``; every other character is passed through unchanged.

    Returns:
        The list of matching paths, in the order ``glob.glob`` produced them.
    """
    def _either_case(ch):
        # Separators, digits and wildcards are kept verbatim.
        return f'[{ch.lower()}{ch.upper()}]' if ch.isalpha() else ch

    relaxed = ''.join(map(_either_case, pattern))
    return glob.glob(relaxed, recursive=True)


def preview(df, name, n=5):
    """Print the shape of ``df`` and display its first ``n`` rows.

    When ``df`` is None, a one-line skip notice is printed instead.

    Args:
        df: Object exposing ``head`` (and usually ``shape``), or None.
        name: Label used in the printed messages.
        n: Number of leading rows to display.
    """
    if df is not None:
        shape = getattr(df, "shape", None)
        print(f'[preview] {name}: shape={shape}')
        # NOTE(review): `display` is not defined in this file; presumably the
        # IPython/Jupyter built-in - confirm before running as a plain script.
        display(df.head(n))
    else:
        print(f'[skip] {name}: not provided/not found.')


def week_to_season(d):
    """Return the season label (e.g. ``"2020-21"``) that contains date ``d``.

    A season starts in July: months January-June belong to the season that
    began the previous calendar year.

    Args:
        d: Any object with ``year`` and ``month`` attributes.
    """
    season_start = d.year if d.month >= 7 else d.year - 1
    end_suffix = str(season_start + 1)[-2:]
    return "{}-{}".format(season_start, end_suffix)



## 2) ILI/GT/Humidity sanity (optional)


In [None]:

# ILI: optional preview of the cached ILI table.
if ILI_CACHE_FILE is None:
    # No override given: look for the usual cache file names.
    ILI_CACHE_FILE = _find_first_existing(
        [
            'ili_state_all.csv',
            'ili_allstates.csv',
            'ili_state_weekly.csv',
            'ili_states_weekly.csv',
        ],
        search_dirs=('./cache', './cache/ili_cache_states', './cache/ili_cache'),
    )

ili_df = None
if not (ILI_CACHE_FILE and os.path.exists(ILI_CACHE_FILE)):
    print('[info] No ILI cache file found (OK).')
else:
    # Best effort: a broken cache file only produces a warning.
    try:
        ili_df = pd.read_csv(ILI_CACHE_FILE)
        preview(ili_df, 'ILI cache')
    except Exception as e:
        print('[warn] ILI load failed:', e)

# GT cache: optional peek at up to three cached CSV files.
if GT_CACHE_DIR is None:
    # No override given: use the first of the usual cache folders that exists.
    for cand in ['./gt_cache_states','./gt_cache','./cache/gt_cache_states']:
        if os.path.isdir(cand):
            GT_CACHE_DIR = cand; break
if GT_CACHE_DIR and os.path.isdir(GT_CACHE_DIR):
    # Prefer CSVs directly inside the folder; only recurse when there are none.
    hits = glob.glob(os.path.join(GT_CACHE_DIR, '*.csv')) or glob.glob(os.path.join(GT_CACHE_DIR, '**','*.csv'), recursive=True)
    # glob order is arbitrary; sort so the same three files are shown on every run.
    hits = sorted(hits)[:3]
    print(f'[info] GT cache dir: {GT_CACHE_DIR} (showing up to 3)')
    for h in hits:
        # Best effort: an unreadable file is reported and the loop moves on.
        try:
            dfh = pd.read_csv(h, nrows=5)
            print('  -', os.path.relpath(h))
            display(dfh.head(3))
        except Exception as e:
            print('  - (read error)', h, e)
else:
    print('[info] No GT cache dir found (OK).')

# Humidity: optional preview; this input is never auto-detected.
if not HUMIDITY_DAILY_CSV:
    print('[info] No humidity daily CSV specified (preview skipped).')
elif not os.path.exists(HUMIDITY_DAILY_CSV):
    # A path was given but does not exist: say so rather than claiming that
    # nothing was specified.
    print('[warn] Humidity daily CSV not found (preview skipped):', HUMIDITY_DAILY_CSV)
else:
    # Best effort: a broken file only produces a warning.
    try:
        hum = pd.read_csv(HUMIDITY_DAILY_CSV)
        preview(hum, 'Humidity (daily)')
    except Exception as e:
        print('[warn] Humidity load failed:', e)



## 3) Mobility + R(t) loaders (auto-detect with fallbacks)


In [None]:

# Mobility: resolve the input path (explicit override wins, otherwise search).
if MOBILITY_CSV is None:
    MOBILITY_CSV = _find_first_existing(
        [
            'mobility_state_weekly.csv',
            'mobility_state_weekly_pct.csv',
            'mobility_pct_weekly.csv',
        ],
        search_dirs=('./cache', './outputs', './'),
    )
print('Mobility path:', MOBILITY_CSV)
# Mobility is a required input: stop here if nothing was found.
assert MOBILITY_CSV is not None, 'Mobility CSV not found. Set MOBILITY_CSV.'

# R(t): resolve the input path (explicit override wins, otherwise search).
if RT_CSV is None:
    RT_CSV = _find_first_existing(['rt_state_weekly.csv','rt_weekly_allstates.csv','rt_allstates_weekly.csv','Rt_weekly_allstates.csv'],
                                  search_dirs=('./cache','./outputs','./'))
if RT_CSV is None:
    # Fallback: no Rt file on disk, so try to extract state/date/rt columns
    # from a Fig2_weekly_scatter_inputs.csv and cache them as an Rt-only file.
    cand=None
    for pat in ['./outputs/fig2/Fig2_weekly_scatter_inputs.csv','./**/Fig2_weekly_scatter_inputs.csv']:
        # glob order is arbitrary; sort so the same source file is used on every run.
        hits=sorted(_case_insensitive_glob(pat))
        if hits: cand=hits[0]; break
    if cand is not None:
        print('[fallback] Found Fig2_weekly_scatter_inputs at:', cand)
        df_src=pd.read_csv(cand)
        df_src.columns=[c.strip().lower() for c in df_src.columns]

        def _pick_state_col(cols):
            """Return the first known state-column name present in ``cols``, else None."""
            for n in ['state','state_abbr','st','state_code','stateid','state_id']:
                if n in cols: return n
            return None

        def _pick_date_col(df):
            """Return the name of a column that parses as dates, else None.

            Columns whose name hints at a date are tried first; after that any
            column where more than 80% of the values parse is accepted.
            """
            # Numeric columns are skipped: pd.to_datetime reads plain numbers
            # as offsets from the epoch, so they always "parse" without being
            # dates and would otherwise be mistaken for the date column.
            text_cols=[c for c in df.columns if not pd.api.types.is_numeric_dtype(df[c])]
            hints=[c for c in text_cols if any(k in c for k in ['date','week','ending'])]
            for c in hints:
                try:
                    pd.to_datetime(df[c]); return c
                except Exception: pass  # not parseable as dates; try the next hint
            for c in text_cols:
                s=pd.to_datetime(df[c], errors='coerce')
                if s.notna().mean()>0.8: return c
            return None

        def _pick_rt_col(df):
            """Return the R(t) column: a known name, else the best-scoring numeric column."""
            for n in ['rt','r_t','r','re','r_eff','r_effective','rt_weekly','rt_median']:
                if n in df.columns: return n
            num_cols=[c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
            best=None; best_score=-1
            for c in num_cols:
                s=pd.to_numeric(df[c], errors='coerce'); s=s[np.isfinite(s)]
                if s.empty: continue
                med=float(np.nanmedian(s)); std=float(np.nanstd(s))
                # Score how much the column looks like a reproduction number:
                # median between 0.5 and 2, modest spread.
                score=0
                if 0.5<=med<=2.0: score+=2
                if 0.01<=std<=0.5: score+=1
                if score>best_score: best_score=score; best=c
            return best

        state_col=_pick_state_col(df_src.columns)
        date_col=_pick_date_col(df_src)
        rt_col=_pick_rt_col(df_src)
        assert state_col and date_col and rt_col, f'Fallback file lacks required columns. Found -> state:{state_col}, date:{date_col}, rt:{rt_col}'
        rt_only=df_src[[state_col,date_col,rt_col]].rename(columns={state_col:'state',date_col:'date',rt_col:'rt'})
        # Persist the extracted table so later runs find it through the normal search.
        os.makedirs('./cache', exist_ok=True)
        RT_CSV='./cache/rt_state_weekly.csv'
        rt_only.to_csv(RT_CSV, index=False)
        print('[fallback] Wrote Rt-only file:', RT_CSV, 'rows:', len(rt_only))

print('Final paths:\n  MOBILITY_CSV:', MOBILITY_CSV, '\n  RT_CSV:', RT_CSV)
# R(t) is a required input: stop here if neither the search nor the fallback produced a file.
assert RT_CSV is not None, 'R(t) CSV not found.'

# Load both tables and normalise their headers (trimmed, lower-case) so the
# column lookups that follow can match on plain names.
mob = pd.read_csv(MOBILITY_CSV)
rt = pd.read_csv(RT_CSV)
mob.columns = [name.strip().lower() for name in mob.columns]
rt.columns = [name.strip().lower() for name in rt.columns]

def _coalesce(cols, names):
    """Return the first entry of ``names`` that is present in ``cols``, else None."""
    return next((name for name in names if name in cols), None)

# Find the state and date columns of the mobility table by trying the known
# header spellings in order (headers were lower-cased when the file was loaded).
# Either result may be None; that is checked further down.
state_col_m=_coalesce(mob.columns,['state','state_abbr','st'])
date_col_m=_coalesce(mob.columns,['date','week','week_ending','week_ending_date'])

# robust mobility picker
import pandas as pd

def _pick_mob_col(df, manual=None):
    """Choose the column of ``df`` that holds the mobility values.

    Selection order:
      1. ``manual`` if it names a column (exact match, then ignoring case and
         surrounding whitespace).
      2. The first well-known mobility column name that is present.
      3. The first *numeric* column whose name contains ``mob`` or looks like
         ``pct...base`` / ``percent...base``.
      4. The numeric column whose values look most like a mobility series.

    Args:
        df: Table to inspect.
        manual: Optional column name supplied by the user.

    Returns:
        The chosen column name, or ``None`` when no column qualifies.
    """
    cols=list(df.columns)
    if manual:
        if manual in cols:
            return manual
        # Headers may have been normalised (e.g. lower-cased) after the user
        # wrote the override, so fall back to a case-insensitive comparison.
        wanted=str(manual).strip().lower()
        for c in cols:
            if str(c).strip().lower()==wanted:
                return c
    common=['mob_pct','mobility_pct','pct_mobility','pct_vs_baseline','pct_change','mobility_dow_baseline',
            'percent_change','percent_change_from_baseline','pct','mobility','mob','value']
    for n in common:
        if n in cols: return n
    for c in cols:
        # Require a numeric dtype: a text column such as "mobility_source"
        # matches the name pattern but cannot be used as the value column.
        if re.search(r'(pct|percent).*base|mob', str(c)) and pd.api.types.is_numeric_dtype(df[c]):
            return c
    num=[c for c in cols if pd.api.types.is_numeric_dtype(df[c])]
    best=None; best_score=-1
    for c in num:
        s=pd.to_numeric(df[c], errors='coerce'); s=s[np.isfinite(s)]
        if s.empty: continue
        med=float(np.nanmedian(s)); std=float(np.nanstd(s))
        # Prefer columns centred near zero (percent within +/-200, fraction
        # within +/-3) and, as a tie-breaker, ones that actually vary.
        score=0
        if -200<=med<=200: score+=1
        if -3<=med<=3: score+=1
        if std>0: score+=0.1
        if score>best_score: best_score=score; best=c
    return best

# Resolve the remaining column names.
mob_col = _pick_mob_col(mob, MOBILITY_VALUE_COL)
state_col_r = _coalesce(rt.columns, ['state','state_abbr','st'])
date_col_r = _coalesce(rt.columns, ['date','week','week_ending','week_ending_date'])
rt_col = _coalesce(rt.columns, ['rt','r_t','r','re','r_eff'])

# Fail early, naming every role that could not be mapped to a column.
required = [(state_col_m,'mobility state'),(date_col_m,'mobility date'),(mob_col,'mobility value'),
            (state_col_r,'rt state'),(date_col_r,'rt date'),(rt_col,'rt value')]
missing = [label for found, label in required if found is None]
if missing:
    print('Mobility columns present:', mob.columns.tolist())
    print('Rt columns present:', rt.columns.tolist())
    raise ValueError('Missing required columns: ' + ', '.join(missing))

# Bring both tables onto the shared schema: state / date / value.
mob = mob.rename(columns={state_col_m:'state', date_col_m:'date', mob_col:'mob_pct'})
rt = rt.rename(columns={state_col_r:'state', date_col_r:'date', rt_col:'rt'})

mob['date'] = pd.to_datetime(mob['date'])
rt['date'] = pd.to_datetime(rt['date'])

# Values that never exceed 1.5 in magnitude are treated as fractions and
# rescaled to percent.
if mob['mob_pct'].abs().max() <= 1.5:
    mob['mob_pct'] = mob['mob_pct'] * 100.0

# Keep only the state-weeks present in both tables.
df_weekly = pd.merge(rt, mob, on=['state','date'], how='inner')
df_weekly = df_weekly.sort_values(['state','date']).reset_index(drop=True)
# The season label is derived from the start of the weekly period containing each date.
df_weekly['season'] = df_weekly['date'].dt.to_period('W').apply(lambda p: week_to_season(p.start_time))
print('Rows after merge:', len(df_weekly))
display(df_weekly.head(8))



## 4) EAKF/SIR hooks (optional)


In [None]:

# Optional inputs for this section; leave as None to skip.
# EAKF_STATE_TRAJ: path to a CSV with a 'date' column (parsed as dates on load).
# PARAM_PRIORS:    path to a CSV of parameter priors.
EAKF_STATE_TRAJ = None
PARAM_PRIORS    = None

eakf_df=None
if EAKF_STATE_TRAJ and os.path.exists(EAKF_STATE_TRAJ):
    # Best effort: a broken file only produces a warning.
    try:
        eakf_df=pd.read_csv(EAKF_STATE_TRAJ, parse_dates=['date'])
        preview(eakf_df,'EAKF state trajectory')
    except Exception as e:
        print('[warn] EAKF load failed:', e)
else:
    print('[info] No EAKF trajectory provided (optional).')

priors_df=None
if PARAM_PRIORS and os.path.exists(PARAM_PRIORS):
    try:
        priors_df=pd.read_csv(PARAM_PRIORS)
        preview(priors_df,'Parameter priors (Shaman/Yang/Lipsitch)')
    except Exception as e:
        print('[warn] Priors load failed:', e)
else:
    # Mirror the EAKF branch so a skipped input is always reported.
    print('[info] No parameter priors provided (optional).')



## 5) Figures


In [None]:

def plot_rt_vs_mobility_clean(state, rt_weekly, mob_weekly_pct, savepath, title_extra=""):
    """Save a dual-axis line plot of R(t) and mobility for one state.

    The two series are aligned on their index, rows with a missing value are
    dropped, and both are smoothed with a centred 3-point rolling mean before
    plotting. Nothing is saved when the series do not overlap.

    Args:
        state: Label used in the title and in the skip message.
        rt_weekly: Series of R(t) values.
        mob_weekly_pct: Series of mobility values, on the same kind of index.
        savepath: Destination of the PNG.
        title_extra: Optional text appended to the title.
    """
    paired = pd.concat({'rt': rt_weekly, 'mob': mob_weekly_pct}, axis=1)
    paired = paired.sort_index().dropna()
    if paired.empty:
        print(f'[skip] {state}: no overlap after dropna()')
        return

    smooth = paired.rolling(3, min_periods=1, center=True).mean()
    weeks = smooth.index
    suffix = (" — " + title_extra) if title_extra else ""

    fig, ax1 = plt.subplots(figsize=(8,4))

    # Left axis: R(t), with a reference line at the epidemic threshold of 1.
    ax1.plot(weeks, smooth['rt'], color='tab:blue', lw=2.6, label='R(t)')
    ax1.axhline(1.0, color='0.6', ls='--', lw=1)
    ax1.set_ylabel('R(t)', color='tab:blue')
    ax1.tick_params(axis='y', colors='tab:blue')

    # Right axis: mobility, with a reference line at zero change.
    ax2 = ax1.twinx()
    ax2.plot(weeks, smooth['mob'], color='tab:orange', lw=2.0, label='%Δ mobility (baseline DoW)')
    ax2.axhline(0, color='0.6', ls='--', lw=1)
    ax2.set_ylabel('Mobility (%Δ vs baseline)', color='tab:orange')
    ax2.tick_params(axis='y', colors='tab:orange')

    ax1.set_title(f'{state}: R(t) vs Mobility{suffix}')
    ax1.set_xlabel('Week')

    # One combined legend covering the lines of both axes.
    handles_rt, labels_rt = ax1.get_legend_handles_labels()
    handles_mob, labels_mob = ax2.get_legend_handles_labels()
    ax1.legend(handles_rt + handles_mob, labels_rt + labels_mob, loc='upper left', frameon=False)

    fig.tight_layout()
    fig.savefig(savepath, dpi=300, bbox_inches='tight')
    plt.close(fig)


def build_state_summaries(df_weekly, mode='MEDMED', drop_covid=True):
    """Collapse weekly rows into one (mobility, R(t)) summary point per state.

    The chosen statistic is applied twice: first within each state-season,
    then across the seasons of each state.

    Args:
        df_weekly: Table with ``state``, ``season``, ``rt`` and ``mob_pct`` columns.
        mode: ``'MEDMED'`` (median of medians) or ``'P90P90'`` (90th percentile
            of 90th percentiles); case-insensitive.
        drop_covid: When true, rows whose season label contains ``2020-21``
            are excluded.

    Returns:
        DataFrame with columns ``state``, ``rt_val``, ``mob_val`` and
        ``n_seasons``, sorted by state, without rows lacking either value.

    Raises:
        ValueError: If ``mode`` is not one of the two supported names.
    """
    d = df_weekly.copy()
    if drop_covid:
        is_covid = d['season'].astype(str).str.contains('2020-21')
        d = d[~is_covid]

    key = mode.upper()
    if key == 'MEDMED':
        stat = 'median'
    elif key == 'P90P90':
        def stat(values):
            return np.nanpercentile(values, 90)
    else:
        raise ValueError("mode must be 'MEDMED' or 'P90P90'")

    within = d.groupby(['state', 'season']).agg(
        rt_stat=('rt', stat),
        mob_stat=('mob_pct', stat),
    )
    across = within.groupby('state').agg(
        rt_val=('rt_stat', stat),
        mob_val=('mob_stat', stat),
        n_seasons=('rt_stat', 'size'),
    ).reset_index()
    return across.dropna(subset=['rt_val', 'mob_val']).sort_values('state')


def spearman_rho(x,y):
    """Return Spearman's rank correlation of ``x`` and ``y`` as ``(rho, p)``.

    scipy is used when it is available. Otherwise rho is computed as the
    Pearson correlation of the average ranks and ``p`` is NaN. In both paths
    pairs containing a NaN are ignored.

    Args:
        x: Sequence of numbers.
        y: Sequence of numbers, same length as ``x``.

    Returns:
        Tuple ``(rho, p)``; ``rho`` is NaN when fewer than two complete pairs
        remain in the fallback path.
    """
    try:
        from scipy.stats import spearmanr
        rho,p=spearmanr(x,y,nan_policy='omit')
    except Exception:
        # Best-effort fallback (scipy missing or failing). Drop incomplete
        # pairs first so the result matches nan_policy='omit' above instead
        # of turning into NaN as soon as one value is missing.
        pairs=pd.DataFrame({'x': np.asarray(x, dtype=float),
                            'y': np.asarray(y, dtype=float)}).dropna()
        if len(pairs)<2:
            rho=np.nan  # a correlation needs at least two points
        else:
            rho=float(np.corrcoef(pairs['x'].rank(), pairs['y'].rank())[0,1])
        p=np.nan
    return rho,p


def plot_state_scatter_simple(tbl, xlab, ylab, title, savepath, annotate=True, draw_fit=True, alpha=0.9):
    """Save a state-level scatter of ``rt_val`` against ``mob_val``.

    The title carries the Spearman correlation of the two columns. An
    optional least-squares line and per-point state labels can be drawn.

    Args:
        tbl: Table with ``state``, ``mob_val`` and ``rt_val`` columns.
        xlab: X-axis label.
        ylab: Y-axis label.
        title: First line of the plot title.
        savepath: Destination of the PNG.
        annotate: Label every point with its state.
        draw_fit: Draw a straight-line fit when there are at least two
            points and all values are finite.
        alpha: Opacity of the scatter markers.
    """
    if tbl.empty:
        print('[warn] empty table passed to scatter:', savepath)
        return

    x = tbl['mob_val'].values.astype(float)
    y = tbl['rt_val'].values.astype(float)
    rho, p = spearman_rho(x, y)

    fig, ax = plt.subplots(figsize=(7,6))
    ax.scatter(x, y, s=45, alpha=alpha)

    can_fit = draw_fit and len(x) >= 2 and np.isfinite(x).all() and np.isfinite(y).all()
    if can_fit:
        # The trend line is decoration: if the fit fails the plot is still saved.
        try:
            slope, intercept = np.polyfit(x, y, 1)
            xs = np.linspace(np.nanmin(x), np.nanmax(x), 100)
            ax.plot(xs, slope*xs + intercept, lw=1.2, alpha=0.5)
        except Exception:
            pass

    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.set_title(f"{title}\nSpearman ρ={rho:.2f},  p={p:.3g}")
    ax.grid(True, alpha=0.25)

    if annotate:
        for label, mob_val, rt_val in zip(tbl['state'], tbl['mob_val'], tbl['rt_val']):
            ax.annotate(label, (mob_val, rt_val), xytext=(4,2), textcoords='offset points', fontsize=8)

    fig.tight_layout()
    fig.savefig(savepath, dpi=300, bbox_inches='tight')
    plt.close(fig)


In [None]:

# Figure 1: one R(t)-vs-mobility panel per state.
states = sorted(df_weekly['state'].dropna().unique())
count = 0
for st in states:
    d = df_weekly.loc[df_weekly['state'] == st].set_index('date')
    # Skip states that have no week with both values present.
    if d[['rt', 'mob_pct']].dropna().empty:
        print(f'[skip] {st}: no overlapping data')
        continue
    out_png = os.path.join(OUT_DIR, f'fig1_{st}_rt_vs_mob.png')
    plot_rt_vs_mobility_clean(st, d['rt'], d['mob_pct'], out_png)
    count += 1
print(f'[ok] Figure 1 panels saved: {count}  -> {OUT_DIR}')

# Figure 2 (PRIMARY & APPENDIX)
# Four state-level scatters, one per combination of summary statistic
# (median-of-medians, p90-of-p90) and season filter (2020-21 excluded for the
# primary panels, included for the appendix panels).

# Primary, typical level: median within season, then median across seasons.
tbl_med=build_state_summaries(df_weekly,mode='MEDMED',drop_covid=True)
plot_state_scatter_simple(tbl_med,
    xlab='Mobility (%Δ vs baseline) — seasonal MEDIAN across seasons',
    ylab='R(t) — seasonal MEDIAN across seasons',
    title='States: R(t) (typical) vs Mobility (typical) — PRIMARY (excl. 2020–21)',
    savepath=os.path.join(OUT_DIR,'fig2_primary_MEDMED.png'), annotate=True, draw_fit=True)

# Primary, high level: 90th percentile within season, then across seasons.
tbl_p90=build_state_summaries(df_weekly,mode='P90P90',drop_covid=True)
plot_state_scatter_simple(tbl_p90,
    xlab='Mobility (%Δ vs baseline) — seasonal p90 across seasons',
    ylab='R(t) — seasonal p90 across seasons',
    title='States: R(t) (high) vs Mobility (high) — PRIMARY (excl. 2020–21)',
    savepath=os.path.join(OUT_DIR,'fig2_primary_P90P90.png'), annotate=True, draw_fit=True)

# Appendix, typical level: same as the first panel but keeping 2020-21.
tbl_med_all=build_state_summaries(df_weekly,mode='MEDMED',drop_covid=False)
plot_state_scatter_simple(tbl_med_all,
    xlab='Mobility (%Δ vs baseline) — seasonal MEDIAN across seasons',
    ylab='R(t) — seasonal MEDIAN across seasons',
    title='States: R(t) (typical) vs Mobility (typical) — APPENDIX (incl. 2020–21)',
    savepath=os.path.join(OUT_DIR,'figA_all_MEDMED.png'), annotate=True, draw_fit=True)

# Appendix, high level: same as the second panel but keeping 2020-21.
tbl_p90_all=build_state_summaries(df_weekly,mode='P90P90',drop_covid=False)
plot_state_scatter_simple(tbl_p90_all,
    xlab='Mobility (%Δ vs baseline) — seasonal p90 across seasons',
    ylab='R(t) — seasonal p90 across seasons',
    title='States: R(t) (high) vs Mobility (high) — APPENDIX (incl. 2020–21)',
    savepath=os.path.join(OUT_DIR,'figA_all_P90P90.png'), annotate=True, draw_fit=True)

print('[ok] Figure 2 panels saved ->', OUT_DIR)



## 6) Run summary


In [None]:

from pathlib import Path

# List every PNG found directly in OUT_DIR, in sorted order.
pngs = sorted(Path(OUT_DIR).glob('*.png'))
print(f'PNG outputs in {OUT_DIR}:')
for p in pngs:
    print(' -', p.name)
