In [29]:
# =============================================================================
# FluSight — REAL-TIME FINAL DLM RECURSIVE FORECAST (+EHR POS + EHR INPATIENT LAGS)
# SUBMISSION FILE ONLY — SC + US
#
# What this script does:
#   ✅ Loads optimal DLM(+EHR pos+inp) params JSON from your Box folder (SC + US keys)
#   ✅ Loads Prisma+MUSC weekly influenza EHR series (SC)
#   ✅ Uses latest FluSight target-hospital-admissions from GitHub (SC + US)
#   ✅ REAL-TIME forecast:
#        - ORIGIN = last observed target_end_date in each location series
#        - reference_date = ORIGIN + 7 days (h=0 target_end_date)
#        - Fit model on ALL available data (1-step pairs) for each location
#        - Forecast h=0..3 (1..4 weeks ahead) recursively
#        - Quantiles via Normal approx using se_mean, inverse-transform, clip0, ROUND int
#   ✅ Saves:
#        - One combined submission CSV for SC + US
#        - One plot per location (observed full series + 50%/95% forecast bands)
#
# Requirements:
#   pip install pandas numpy matplotlib statsmodels scipy
# =============================================================================

import os
import glob
import json
import warnings
import time
import shutil

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.stats import norm
from statsmodels.tsa.statespace.structural import UnobservedComponents

warnings.filterwarnings("ignore")

# =============================================================================
# CONFIG
# =============================================================================
base_dir = r"C:\Users\mdsakhh\Box\BoxPHI-PHMR Projects\Sakhawat\FluSight_Forecast"

FILE_HOSP = (
    "https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/"
    "main/target-data/target-hospital-admissions.csv"
)

# --- EHR Data paths ---
PRISMA_DIR = r"C:\Users\mdsakhh\Box\BoxPHI-PHMR Projects\Data\Prisma Health\Infectious Disease EHR\Weekly Data\Latest Weekly Data"
PRISMA_BASENAME = "Prisma_Health_Weekly_Influenza_State_dx_Incident"
MUSC_DIR = r"C:\Users\mdsakhh\Box\BoxPHI-PHMR Projects\Data\MUSC\Infectious Disease EHR\Weekly Data\Latest Weekly Data"
MUSC_BASENAME = "MUSC_Weekly_Influenza_State_dx_Incident"

APPLY_SC_EHR_TO_US = True

LOCATIONS = [
    dict(code="45", name="South Carolina"),
    dict(code="US", name="US"),
]

OUTCOME_MEASURE = "wk inc flu hosp"
TARGET_KEY = "hosp"
HORIZONS = [0, 1, 2, 3]

USE_LOG_TRANSFORM = True
MIN_TRAIN_ROWS = 60

SUBMISSION_QUANTILES = [
    0.01, 0.025, 0.05,
    0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50,
    0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90,
    0.95, 0.975, 0.99
]

ROUND_NDIGITS = 0

suffix = "log" if USE_LOG_TRANSFORM else "raw"

DLM_PARAMS_JSON = os.path.join(
    base_dir, "model_eval", "optimal_parameters",
    f"optimal_params_dlm_recursive_hosp_with_pos_inp_prisma_musc_{suffix}.json"
)

# Output
out_root = os.path.join(
    base_dir,
    "final_submission",
    f"dlm_recursive_hosp_with_pos_inp_prisma_musc_{suffix}",
    "realtime_latest_fit_all_data_SC_US"
)
out_plot_dir = os.path.join(out_root, "plots")

for d in [out_root, out_plot_dir]:
    try:
        os.makedirs(d, exist_ok=True)
    except Exception as e:
        print(f"WARNING: Could not create directory {d}: {e}")

# =============================================================================
# ROBUST FILE SAVING
# =============================================================================
def safe_save_csv(df, filepath, max_retries=3, retry_delay=2):
    """Safely save DataFrame to CSV with verification and retries."""
    for attempt in range(max_retries):
        try:
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            temp_path = filepath + f".tmp_{int(time.time())}"
            
            df.to_csv(temp_path, index=False, encoding='utf-8', lineterminator='\n')
            time.sleep(0.5)
            
            if os.path.exists(temp_path):
                file_size = os.path.getsize(temp_path)
                if file_size == 0:
                    print(f"  WARNING: Temp file is empty (attempt {attempt+1})")
                    if os.path.exists(temp_path):
                        os.remove(temp_path)
                    continue
                
                try:
                    test_df = pd.read_csv(temp_path)
                    if len(test_df) != len(df):
                        print(f"  WARNING: Row count mismatch (attempt {attempt+1})")
                        if os.path.exists(temp_path):
                            os.remove(temp_path)
                        continue
                except Exception as read_err:
                    print(f"  WARNING: Could not verify temp file (attempt {attempt+1}): {read_err}")
                    if os.path.exists(temp_path):
                        os.remove(temp_path)
                    continue
                
                if os.path.exists(filepath):
                    try:
                        os.remove(filepath)
                        time.sleep(0.3)
                    except Exception:
                        pass
                
                shutil.move(temp_path, filepath)
                time.sleep(0.5)
                
                if os.path.exists(filepath):
                    final_size = os.path.getsize(filepath)
                    if final_size > 0:
                        try:
                            final_df = pd.read_csv(filepath)
                            if len(final_df) == len(df):
                                print(f"  ✓ File saved and verified: {filepath}")
                                print(f"    Size: {final_size:,} bytes | Rows: {len(df)}")
                                return True
                        except Exception:
                            pass
                        
            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except:
                    pass
                    
        except PermissionError:
            print(f"  Attempt {attempt+1}/{max_retries}: Permission denied - file may be locked by Box sync")
            time.sleep(retry_delay)
        except Exception as e:
            print(f"  Attempt {attempt+1}/{max_retries}: Error saving file: {e}")
            time.sleep(retry_delay)
    
    return False


def save_with_fallback(df, primary_path, description="file"):
    """Try primary path, fall back to Desktop if fails."""
    print(f"\nSaving {description}...")
    print(f"  Target: {primary_path}")
    
    if safe_save_csv(df, primary_path):
        return primary_path
    
    print(f"  Primary save failed. Trying fallback location...")
    fallback_dir = os.path.join(os.path.expanduser("~"), "Desktop", "FluSight_submissions_fallback")
    os.makedirs(fallback_dir, exist_ok=True)
    fallback_path = os.path.join(fallback_dir, os.path.basename(primary_path))
    
    if safe_save_csv(df, fallback_path):
        print(f"  ✓ Saved to fallback: {fallback_path}")
        return fallback_path
    
    print(f"  Fallback also failed. Saving to current directory...")
    last_resort = os.path.basename(primary_path)
    try:
        df.to_csv(last_resort, index=False, encoding='utf-8')
        print(f"  ✓ Saved to current directory: {os.path.abspath(last_resort)}")
        return os.path.abspath(last_resort)
    except Exception as e:
        print(f"  ERROR: All save attempts failed: {e}")
        return None

# =============================================================================
# UTIL
# =============================================================================
def _latest_file_by_mtime(file_list):
    return max(file_list, key=lambda p: os.path.getmtime(p)) if file_list else None

def pick_col(df, candidates):
    for c in candidates:
        if c in df.columns:
            return c
    return None

def fmt_mdY(dt_like):
    dt = pd.to_datetime(dt_like)
    if isinstance(dt, pd.Series):
        s = dt.dt.strftime("%m/%d/%Y")
        s = s.str.replace(r"^0", "", regex=True)
        s = s.str.replace(r"/0", "/", regex=True)
        return s
    s = dt.strftime("%m/%d/%Y")
    mm, dd, yy = s.split("/")
    return f"{int(mm)}/{int(dd)}/{yy}"

def log1p_safe(x):
    return np.log1p(np.maximum(0.0, x))

def inv_log1p(x):
    return np.expm1(x)

def to_scalar(x):
    return float(np.asarray(x).ravel()[0])

def round_forecast_value(x, ndigits=0):
    x = float(max(0.0, x))
    if ndigits == 0:
        return int(round(x))
    return float(round(x, ndigits))

def _q_str(q):
    s = f"{q:.3f}".rstrip("0").rstrip(".")
    return s

# =============================================================================
# HOLIDAY FLAGS
# =============================================================================
def thanksgiving_flag(d):
    d = pd.Timestamp(d)
    if d.month != 11:
        return 0
    first = pd.Timestamp(year=d.year, month=11, day=1)
    offset = (3 - first.dayofweek) % 7
    first_thu = first + pd.Timedelta(days=offset)
    fourth_thu = first_thu + pd.Timedelta(days=21)
    return int(abs((d - fourth_thu).days) <= 3)

def christmas_flag(d):
    d = pd.Timestamp(d)
    return int(d.month == 12 and d.day >= 20)

def newyear_flag(d):
    d = pd.Timestamp(d)
    return int(d.month == 1 and d.day <= 7)

# =============================================================================
# LOAD OPTIMAL PARAMS
# =============================================================================
def load_optimal_params():
    if os.path.exists(DLM_PARAMS_JSON):
        path = DLM_PARAMS_JSON
    else:
        candidates = glob.glob(
            os.path.join(base_dir, "model_eval", "optimal_parameters",
                         f"optimal_params_dlm_recursive_hosp_with_pos_inp_prisma_musc_{suffix}*.json")
        )
        path = _latest_file_by_mtime(candidates)
        if path is None:
            raise FileNotFoundError(f"Could not find params JSON: {DLM_PARAMS_JSON}")
    with open(path, "r") as f:
        obj = json.load(f)
    print(f"✓ Loaded params: {path}")
    return obj, path

# =============================================================================
# LOAD TARGET SERIES
# =============================================================================
def load_target_series_hosp(url, location_code, outcome_measure):
    df = pd.read_csv(url)
    df["date"] = pd.to_datetime(df.get("target_end_date", df.get("date")))
    df["location"] = df["location"].astype(str)
    sub = df[df["location"] == str(location_code)].copy()
    if "outcome_measure" in sub.columns:
        sub = sub[sub["outcome_measure"] == outcome_measure].copy()
    sub["y_original"] = pd.to_numeric(sub["value"], errors="coerce")
    sub = sub[["date", "y_original"]].dropna().sort_values("date").drop_duplicates("date").reset_index(drop=True)
    sub["y"] = log1p_safe(sub["y_original"].values) if USE_LOG_TRANSFORM else sub["y_original"].astype(float)
    return sub

# =============================================================================
# LOAD EHR (Prisma + MUSC)
# =============================================================================
def load_ehr_covariates_sc():
    prisma_files = glob.glob(os.path.join(PRISMA_DIR, f"{PRISMA_BASENAME}*.csv"))
    musc_files = glob.glob(os.path.join(MUSC_DIR, f"{MUSC_BASENAME}*.csv"))

    def _load_one(path, tag):
        df = pd.read_csv(path)
        if "State" in df.columns:
            df = df[df["State"] == "SC"].copy()
        wk_col = pick_col(df, ["Week", "week", "Week_End", "week_end", "WeekEnd", "date", "Date"])
        if wk_col is None:
            raise ValueError(f"{tag}: week/date column not found")
        df[wk_col] = pd.to_datetime(df[wk_col])
        pos_col = pick_col(df, ["Weekly_Positive_Tests", "Weekly_Positive", "Positive_Tests", "Weekly_Positive_Tests_All"])
        inp_col = pick_col(df, ["Weekly_Inpatient_Hospitalizations", "Weekly_Inpatient_Hosp", "Inpatient_Hospitalizations"])
        if pos_col is None:
            df["__pos__"] = 0.0
            pos_col = "__pos__"
        if inp_col is None:
            df["__inp__"] = 0.0
            inp_col = "__inp__"
        out = df[[wk_col, pos_col, inp_col]].copy()
        out.columns = ["date", f"{tag}_pos", f"{tag}_inp"]
        out[f"{tag}_pos"] = pd.to_numeric(out[f"{tag}_pos"], errors="coerce").fillna(0.0)
        out[f"{tag}_inp"] = pd.to_numeric(out[f"{tag}_inp"], errors="coerce").fillna(0.0)
        return out

    pf, mf = _latest_file_by_mtime(prisma_files), _latest_file_by_mtime(musc_files)
    prisma_df = _load_one(pf, "prisma") if pf else None
    musc_df = _load_one(mf, "musc") if mf else None

    if prisma_df is None and musc_df is None:
        print("WARNING: No EHR files found.")
        return None

    if prisma_df is not None and musc_df is not None:
        ehr = pd.merge(prisma_df, musc_df, on="date", how="outer")
    elif prisma_df is not None:
        ehr = prisma_df.copy()
        ehr["musc_pos"], ehr["musc_inp"] = 0.0, 0.0
    else:
        ehr = musc_df.copy()
        ehr["prisma_pos"], ehr["prisma_inp"] = 0.0, 0.0

    ehr = ehr.fillna(0.0).sort_values("date").reset_index(drop=True)
    ehr["pos"] = ehr["prisma_pos"] + ehr["musc_pos"]
    ehr["inp"] = ehr["prisma_inp"] + ehr["musc_inp"]
    if USE_LOG_TRANSFORM:
        ehr["pos"] = log1p_safe(ehr["pos"].values)
        ehr["inp"] = log1p_safe(ehr["inp"].values)
    print(f"Combined EHR: {len(ehr)} weeks | {ehr['date'].min().date()} to {ehr['date'].max().date()}")
    return ehr[["date", "pos", "inp"]].copy()

# =============================================================================
# BUILD 1-STEP DATASET
# =============================================================================
def build_one_step_dataset_with_ehr(series_df, ehr_df, y_lags, pos_lags, inp_lags):
    df = series_df.copy().sort_values("date").reset_index(drop=True)
    if ehr_df is not None:
        df = pd.merge(df, ehr_df, on="date", how="left")
        df["pos"] = df["pos"].fillna(0.0)
        df["inp"] = df["inp"].fillna(0.0)
    else:
        df["pos"], df["inp"] = 0.0, 0.0

    rows = []
    for i in range(len(df) - 1):
        origin_date, target_date = df.loc[i, "date"], df.loc[i + 1, "date"]
        y_t, y_o = float(df.loc[i + 1, "y"]), float(df.loc[i + 1, "y_original"])
        feats = {}
        for L in y_lags:
            j = i - L + 1
            feats[f"y_lag{L}"] = float(df.loc[j, "y"]) if j >= 0 else float(df.loc[0, "y"])
        for L in pos_lags:
            j = i - L + 1
            feats[f"pos_lag{L}"] = float(df.loc[j, "pos"]) if j >= 0 else float(df.loc[0, "pos"])
        for L in inp_lags:
            j = i - L + 1
            feats[f"inp_lag{L}"] = float(df.loc[j, "inp"]) if j >= 0 else float(df.loc[0, "inp"])
        feats["is_thanksgiving"] = thanksgiving_flag(target_date)
        feats["is_christmas"] = christmas_flag(target_date)
        feats["is_newyear"] = newyear_flag(target_date)
        row = {"origin_date": origin_date, "target_end_date": target_date, "y": y_t, "y_original": y_o}
        row.update(feats)
        rows.append(row)

    ds = pd.DataFrame(rows)
    exog_cols = sorted(
        [c for c in ds.columns if c.startswith(("y_lag", "pos_lag", "inp_lag", "is_"))],
        key=lambda x: (0 if x.startswith("y_lag") else 1 if x.startswith("pos_lag") else 2 if x.startswith("inp_lag") else 3, x)
    )
    return ds, exog_cols, df

# =============================================================================
# FIT DLM
# =============================================================================
def fit_dlm(endog, exog, structure):
    model = UnobservedComponents(endog=endog, exog=exog, level=structure["level"],
                                  trend=structure["trend"], seasonal=structure["seasonal"])
    return model.fit(disp=False)

# =============================================================================
# REAL-TIME FORECAST
# =============================================================================
def realtime_recursive_forecast(series_df, ehr_df, cfg, round_ndigits=0):
    y_lags, pos_lags, inp_lags = list(cfg["y_lags"]), list(cfg["pos_lags"]), list(cfg["inp_lags"])
    structure = cfg["structure"]
    series_df = series_df.copy().sort_values("date").reset_index(drop=True)

    origin = pd.to_datetime(series_df["date"].max())
    reference_date = origin + pd.Timedelta(days=7)

    ds1, exog_cols, df_full = build_one_step_dataset_with_ehr(series_df, ehr_df, y_lags, pos_lags, inp_lags)
    if len(ds1) < MIN_TRAIN_ROWS:
        raise ValueError(f"Not enough rows: {len(ds1)} < {MIN_TRAIN_ROWS}")

    res = fit_dlm(ds1["y"].values.astype(float), ds1[exog_cols].values.astype(float), structure)

    dates = pd.to_datetime(df_full["date"])
    origin_idx = int(df_full.index[dates == origin][0])

    y_hist = df_full["y"].values.astype(float)[: origin_idx + 1].copy()
    pos_hist = df_full["pos"].values.astype(float)[: origin_idx + 1].copy() if "pos" in df_full.columns else np.zeros(origin_idx + 1)
    inp_hist = df_full["inp"].values.astype(float)[: origin_idx + 1].copy() if "inp" in df_full.columns else np.zeros(origin_idx + 1)

    def lag_value(hist, idx, L):
        j = idx - L + 1
        return float(hist[j]) if j >= 0 else float(hist[0])

    out = []
    for h in HORIZONS:
        step_k = h + 1
        target_date = origin + pd.Timedelta(days=7 * step_k)
        cur_idx = origin_idx + (step_k - 1)

        ex = {}
        for L in y_lags:
            ex[f"y_lag{L}"] = lag_value(y_hist, cur_idx, L)
        for L in pos_lags:
            ex[f"pos_lag{L}"] = lag_value(pos_hist, cur_idx, L)
        for L in inp_lags:
            ex[f"inp_lag{L}"] = lag_value(inp_hist, cur_idx, L)
        ex["is_thanksgiving"] = thanksgiving_flag(target_date)
        ex["is_christmas"] = christmas_flag(target_date)
        ex["is_newyear"] = newyear_flag(target_date)

        x_row = np.array([[ex.get(c, 0.0) for c in exog_cols]], dtype=float)
        fc = res.get_forecast(steps=1, exog=x_row)
        mu = to_scalar(fc.predicted_mean)
        try:
            se = to_scalar(fc.se_mean)
            if not np.isfinite(se):
                se = 0.0
        except:
            se = 0.0

        y_hist = np.append(y_hist, mu)
        pos_hist = np.append(pos_hist, float(pos_hist[-1]) if len(pos_hist) else 0.0)
        inp_hist = np.append(inp_hist, float(inp_hist[-1]) if len(inp_hist) else 0.0)

        q_to_pred = {}
        for q in SUBMISSION_QUANTILES:
            z = norm.ppf(q)
            xq = mu + z * se
            yq = inv_log1p(xq) if USE_LOG_TRANSFORM else xq
            q_to_pred[q] = round_forecast_value(yq, ndigits=round_ndigits)

        out.append({"horizon": int(h), "target_end_date": target_date, "q_to_pred": q_to_pred})

    return origin, reference_date, out

# =============================================================================
# SUBMISSION ROWS
# =============================================================================
def make_submission_rows(reference_date, target, horizon, target_end_date, location, q_to_value, round_ndigits=0):
    rows = []
    for q in SUBMISSION_QUANTILES:
        rows.append({
            "reference_date": fmt_mdY(reference_date),
            "target": target,
            "horizon": int(horizon),
            "target_end_date": fmt_mdY(target_end_date),
            "location": str(location),
            "output_type": "quantile",
            "output_type_id": _q_str(q),
            "value": round_forecast_value(q_to_value[q], ndigits=round_ndigits),
        })
    return rows

# =============================================================================
# PLOT
# =============================================================================
def save_forecast_plot(loc_name, observed_df, origin_date, reference_date, forecast_df, fig_path):
    obs = observed_df.sort_values("date").copy()
    f = forecast_df.sort_values("target_end_date").copy()

    fig = plt.figure(figsize=(14, 7))
    plt.plot(pd.to_datetime(obs["date"]), obs["y_original"].values, linewidth=2.0, marker="o", markersize=3, label="Observed")
    plt.fill_between(pd.to_datetime(f["target_end_date"]), f["q0.025"].values, f["q0.975"].values, alpha=0.20, label="95% interval")
    plt.fill_between(pd.to_datetime(f["target_end_date"]), f["q0.25"].values, f["q0.75"].values, alpha=0.35, label="50% interval")
    plt.plot(pd.to_datetime(f["target_end_date"]), f["q0.50"].values, linewidth=2.5, marker="s", markersize=5, label="Forecast (median)")
    plt.axvline(pd.to_datetime(origin_date), linestyle="--", linewidth=2.0, alpha=0.6, label="Origin")
    plt.title(f"{loc_name} | {OUTCOME_MEASURE}\nDLM recursive h=0..3 | origin={origin_date.date()} | ref={reference_date.date()}", fontsize=12)
    plt.xlabel("Date")
    plt.ylabel("Hospital admissions")
    plt.grid(True, alpha=0.25)
    plt.legend(loc="best")
    plt.tight_layout()

    try:
        os.makedirs(os.path.dirname(fig_path), exist_ok=True)
        fig.savefig(fig_path, dpi=200, bbox_inches="tight")
        print(f"✓ Saved plot: {fig_path}")
    except Exception as e:
        print(f"WARNING: Could not save plot: {e}")
    plt.close()

# =============================================================================
# MAIN
# =============================================================================
print("\n" + "="*100)
print("REAL-TIME DLM RECURSIVE FORECAST — SUBMISSION FILE ONLY")
print("="*100)

params_obj, params_path = load_optimal_params()
ehr_sc = load_ehr_covariates_sc()

loc_to_results = {}

for loc in LOCATIONS:
    loc_code, loc_name = str(loc["code"]), loc["name"]

    print("\n" + "="*100)
    print(f"LOCATION: {loc_name} ({loc_code})")
    print("="*100)

    series = load_target_series_hosp(FILE_HOSP, loc_code, OUTCOME_MEASURE)
    print(f"Data: {len(series)} weeks | {series['date'].min().date()} to {series['date'].max().date()}")

    ehr_use = ehr_sc if (loc_code == "45" or APPLY_SC_EHR_TO_US) else None
    if loc_code != "45" and ehr_use is not None:
        print("NOTE: Using SC EHR series for US.")

    key = f"{loc_code}_{TARGET_KEY}_dlm_recursive_with_pos_inp_prisma_musc"
    if key not in params_obj:
        raise KeyError(f"Missing params key: {key}")

    cfg0 = params_obj[key]
    best_cfg = {"y_lags": list(cfg0["y_lags"]), "pos_lags": list(cfg0["pos_lags"]),
                "inp_lags": list(cfg0["inp_lags"]), "structure": cfg0["structure"]}

    print(f"Config: y_lags={best_cfg['y_lags']}, pos_lags={best_cfg['pos_lags']}, inp_lags={best_cfg['inp_lags']}")

    origin_date, reference_date, fc_list = realtime_recursive_forecast(series, ehr_use, best_cfg, ROUND_NDIGITS)

    print(f"\nOrigin: {origin_date.date()} | Reference date: {reference_date.date()}")

    plot_rows, loc_submit_rows = [], []
    for rec in fc_list:
        h, ted, q_to_pred = rec["horizon"], rec["target_end_date"], rec["q_to_pred"]
        loc_submit_rows.extend(make_submission_rows(reference_date, OUTCOME_MEASURE, h, ted, loc_code, q_to_pred, ROUND_NDIGITS))
        plot_rows.append({"horizon": h, "target_end_date": ted, "q0.025": q_to_pred[0.025], "q0.25": q_to_pred[0.25],
                          "q0.50": q_to_pred[0.50], "q0.75": q_to_pred[0.75], "q0.975": q_to_pred[0.975]})
        print(f"  h={h} @ {ted.date()} | median={q_to_pred[0.50]} | 50%CI=({q_to_pred[0.25]},{q_to_pred[0.75]}) | 95%CI=({q_to_pred[0.025]},{q_to_pred[0.975]})")

    # Plot
    forecast_plot_df = pd.DataFrame(plot_rows)
    fig_path = os.path.join(out_plot_dir, f"{loc_code}_{TARGET_KEY}_dlm_realtime_ref_{reference_date.date()}.png")
    save_forecast_plot(loc_name, series, origin_date, reference_date, forecast_plot_df, fig_path)

    loc_to_results[loc_code] = {"loc_name": loc_name, "origin_date": origin_date,
                                 "reference_date": reference_date, "submit_rows": loc_submit_rows}

# =============================================================================
# SAVE SUBMISSION CSV
# =============================================================================
print("\n" + "-"*80)
print("SAVING SUBMISSION FILE")
print("-"*80)

ref_dates = [pd.to_datetime(loc_to_results[c]["reference_date"]) for c in loc_to_results]
all_same_ref = all(ref_dates[0] == d for d in ref_dates) if ref_dates else False

if all_same_ref and len(loc_to_results) == 2:
    combined_rows = []
    for loc_code in loc_to_results:
        combined_rows.extend(loc_to_results[loc_code]["submit_rows"])

    submit_df = pd.DataFrame(combined_rows)
    submit_df = submit_df[["reference_date", "target", "horizon", "target_end_date", "location", "output_type", "output_type_id", "value"]]
    submit_df = submit_df.sort_values(["location", "horizon", "target_end_date", "output_type_id"]).reset_index(drop=True)

    ref_date = ref_dates[0].date()
    out_csv = os.path.join(out_root, f"FluSight_submission_DLM_EHR_SC_US_ref_{ref_date}.csv")
    saved_path = save_with_fallback(submit_df, out_csv, "submission CSV (SC+US)")

    print("\n" + "="*100)
    print("✓ SUBMISSION READY (SC+US)")
    print(f"Reference date: {ref_date}")
    print(f"Rows: {len(submit_df)} (2 locations × {len(HORIZONS)} horizons × 23 quantiles)")
    if saved_path:
        print(f"Saved to: {saved_path}")
    print("="*100)
else:
    for loc_code in loc_to_results:
        submit_df = pd.DataFrame(loc_to_results[loc_code]["submit_rows"])
        submit_df = submit_df[["reference_date", "target", "horizon", "target_end_date", "location", "output_type", "output_type_id", "value"]]
        submit_df = submit_df.sort_values(["horizon", "target_end_date", "output_type_id"]).reset_index(drop=True)
        ref_date = pd.to_datetime(loc_to_results[loc_code]["reference_date"]).date()
        out_csv = os.path.join(out_root, f"FluSight_submission_DLM_EHR_{loc_code}_ref_{ref_date}.csv")
        save_with_fallback(submit_df, out_csv, f"submission CSV ({loc_code})")

print("\nDONE")


REAL-TIME DLM RECURSIVE FORECAST — SUBMISSION FILE ONLY
✓ Loaded params: C:\Users\mdsakhh\Box\BoxPHI-PHMR Projects\Sakhawat\FluSight_Forecast\model_eval\optimal_parameters\optimal_params_dlm_recursive_hosp_with_pos_inp_prisma_musc_log.json
Combined EHR: 487 weeks | 2016-10-08 to 2026-01-31

LOCATION: South Carolina (45)
Data: 209 weeks | 2022-02-05 to 2026-01-31
Config: y_lags=[1, 2, 3, 4, 5, 6, 7], pos_lags=[5], inp_lags=[5, 6, 7, 8, 9]

Origin: 2026-01-31 | Reference date: 2026-02-07
  h=0 @ 2026-02-07 | median=185 | 50%CI=(140,245) | 95%CI=(82,417)
  h=1 @ 2026-02-14 | median=113 | 50%CI=(85,149) | 95%CI=(50,255)
  h=2 @ 2026-02-21 | median=66 | 50%CI=(50,88) | 95%CI=(29,150)
  h=3 @ 2026-02-28 | median=47 | 50%CI=(35,63) | 95%CI=(20,107)
✓ Saved plot: C:\Users\mdsakhh\Box\BoxPHI-PHMR Projects\Sakhawat\FluSight_Forecast\final_submission\dlm_recursive_hosp_with_pos_inp_prisma_musc_log\realtime_latest_fit_all_data_SC_US\plots\45_hosp_dlm_realtime_ref_2026-02-07.png

LOCATION: US (US)