<a href="https://colab.research.google.com/github/debashisdotchatterjee/Indian-Dengu-GAMLSS-Approach-1/blob/main/Indian_Dengu_GAMLSS_Approach_Draft%201.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%pip install caas-jupyter-tools

[31mERROR: Could not find a version that satisfies the requirement caas-jupyter-tools (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for caas-jupyter-tools[0m[31m
[0m

In [10]:
# ============================================================
# Dengue India 2019–2025 (NCVBDC wide sheet): Colab-ready
# ============================================================

# ---- Setup (versions chosen for Colab stability) ----
!pip -q install pandas==2.2.2 statsmodels==0.14.2 patsy==0.5.6 openpyxl==3.1.5

import os, re, zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import statsmodels.api as sm
import statsmodels.formula.api as smf

plt.rcParams.update({
    "figure.dpi": 120, "savefig.dpi": 150,
    "axes.grid": True, "grid.alpha": 0.25,
})

# ---- Paths ----
DATA_PATH = "dengue_india_state_2019_2025.xlsx"  # << set this if your path differs
OUT_DIR   = Path("outputs"); OUT_DIR.mkdir(parents=True, exist_ok=True)
PLOTS_DIR = OUT_DIR / "plots"; PLOTS_DIR.mkdir(exist_ok=True)
TABLES_DIR= OUT_DIR / "tables"; TABLES_DIR.mkdir(exist_ok=True)

# ============================================================
# 1) Load NCVBDC sheet (wide or long) and tidy to panel
# ============================================================

def parse_exposure_from_provisional(val):
    """Try to parse month count from 'Provisional_Upto' column, else default 8/12."""
    if pd.isna(val):
        return 8/12
    s = str(val).strip()
    # Examples it might contain: 'Aug', 'August', 'Upto Aug', 'Jan-Aug', '8', '8/12'
    month_map = {'jan':1,'january':1,'feb':2,'february':2,'mar':3,'march':3,'apr':4,'april':4,
                 'may':5,'jun':6,'june':6,'jul':7,'july':7,'aug':8,'august':8,'sep':9,'september':9,
                 'oct':10,'october':10,'nov':11,'november':11,'dec':12,'december':12}
    # numeric month present?
    mnum = re.findall(r'(?<!\d)(\d{1,2})(?!\d)', s)
    if mnum:
        try:
            m = int(mnum[-1])
            if 1 <= m <= 12:
                return m/12
        except:
            pass
    # month name present?
    tokens = re.findall(r'[A-Za-z]+', s.lower())
    for tok in tokens:
        if tok in month_map:
            return month_map[tok]/12
    # range like Jan-Aug
    rng = re.findall(r'([A-Za-z]{3,9})\s*[-–]\s*([A-Za-z]{3,9})', s)
    if rng:
        end = rng[0][1].lower()
        if end in month_map:
            return month_map[end]/12
    return 8/12

def load_ncvbdc_sheet(path):
    df = pd.read_excel(path, sheet_name=0, engine="openpyxl")
    # Normalize column names (keep originals for reference)
    cols_lower = {c: c.lower().strip() for c in df.columns}
    inv = {v:k for k,v in cols_lower.items()}

    # Detect state column
    state_col = None
    for key in ["affected states/uts", "state/ut", "state", "states", "name"]:
        if key in inv:
            state_col = inv[key]; break
    if state_col is None:
        raise ValueError(f"Could not find a State column. Found columns: {list(df.columns)}")

    # Provisional marker column (optional)
    prov_col = None
    for key in ["provisional_upto", "provisional upto", "provisional", "upto"]:
        if key in inv:
            prov_col = inv[key]; break

    # Identify wide year columns like '2019_C','2019_D'
    year_pat = re.compile(r'^(20\d{2})_([CDcd])$')
    wide_cols = []
    for c in df.columns:
        m = year_pat.match(str(c).strip())
        if m:
            wide_cols.append(c)

    if wide_cols:
        # WIDE -> LONG
        # Keep just the needed columns
        keep = [state_col] + ([prov_col] if prov_col else []) + wide_cols
        w = df[keep].copy()
        w[state_col] = w[state_col].astype(str).str.strip()

        # Melt all year_* columns
        long = w.melt(id_vars=[state_col] + ([prov_col] if prov_col else []),
                      value_vars=wide_cols, var_name="year_metric", value_name="value")
        # Split year and metric
        ym = long["year_metric"].astype(str).str.extract(r'^(20\d{2})_([CDcd])$')
        long["Year"] = ym[0].astype(int)
        long["Metric"] = ym[1].str.upper()

        # Spread back to Cases/Deaths
        panel = long.pivot_table(index=[state_col, "Year"] + ([prov_col] if prov_col else []),
                                 columns="Metric", values="value", aggfunc="first").reset_index()
        panel = panel.rename(columns={state_col:"State","C":"Cases","D":"Deaths"})
        # Clean numeric (NR etc.)
        for v in ["Cases","Deaths"]:
            panel[v] = pd.to_numeric(panel[v], errors="coerce")

        # Exposure
        panel["Exposure"] = 1.0
        if prov_col:
            # Only apply to 2025 if a provisional marker exists
            sel = panel["Year"].eq(2025)
            panel.loc[sel, "Exposure"] = panel.loc[sel, prov_col].map(parse_exposure_from_provisional)
            # If parser failed, default to 8/12
            panel.loc[sel & panel["Exposure"].isna(), "Exposure"] = 8/12
        else:
            panel.loc[panel["Year"].eq(2025), "Exposure"] = 8/12

        # Drop helper cols
        if prov_col:
            panel = panel.drop(columns=[prov_col])

        # Remove any national TOTAL rows if present
        panel["State_clean"] = panel["State"].str.upper().str.strip()
        total_mask = panel["State_clean"].isin(["TOTAL","INDIA TOTAL","NATIONAL TOTAL","NATIONAL"])
        panel = panel.loc[~total_mask].drop(columns=["State_clean"]).reset_index(drop=True)

        # Sort, lag
        panel = panel.sort_values(["State","Year"]).reset_index(drop=True)
        panel["Cases_lag"] = panel.groupby("State")["Cases"].shift(1)
        panel["log_cases_lag1p"] = np.log1p(panel["Cases_lag"])
        panel["log_cases_1p"] = np.log1p(panel["Cases"])

        return panel

    else:
        # Already LONG – try to detect columns
        possible = {c.lower().strip(): c for c in df.columns}
        def pick(keys):
            for k in keys:
                if k in possible: return possible[k]
            return None

        ycol = pick(["year","yr"])
        ccol = pick(["cases","case"])
        dcol = pick(["deaths","death"])
        if not all([ycol, ccol, dcol]):
            raise ValueError(f"Could not auto-detect columns for long format. Found: {df.columns.tolist()}")

        panel = df.rename(columns={state_col:"State", ycol:"Year", ccol:"Cases", dcol:"Deaths"}).copy()
        for v in ["Cases","Deaths"]:
            panel[v] = pd.to_numeric(panel[v], errors="coerce")
        panel["Year"] = panel["Year"].astype(int)
        panel["Exposure"] = np.where(panel["Year"].eq(2025), 8/12, 1.0)
        panel = panel.sort_values(["State","Year"]).reset_index(drop=True)
        panel["Cases_lag"] = panel.groupby("State")["Cases"].shift(1)
        panel["log_cases_lag1p"] = np.log1p(panel["Cases_lag"])
        panel["log_cases_1p"] = np.log1p(panel["Cases"])
        return panel

panel = load_ncvbdc_sheet(DATA_PATH)
print("Loaded tidy panel:", panel.shape)
display(panel.head(10))

# Save cleaned panel
panel.to_csv(TABLES_DIR / "panel_clean.csv", index=False)

# ============================================================
# 2) Models: Poisson for cases; Binomial for CFR (deaths|cases)
# ============================================================

# ---- Cases GLM (Poisson with offset log Exposure) ----
cases_fit_df = panel.dropna(subset=["Cases","log_cases_lag1p","Exposure"]).copy()
formula_cases = "Cases ~ C(State) + bs(Year, df=4) + log_cases_lag1p"

cases_model = smf.glm(
    formula=formula_cases,
    data=cases_fit_df,
    family=sm.families.Poisson(),
    offset=np.log(cases_fit_df["Exposure"])
).fit(cov_type="HC0")

print("\n[Cases GLM] Summary (truncated):")
print(cases_model.summary().as_text()[:1500])

# Predict μ_hat for all rows
panel["mu_cases_hat"] = cases_model.predict(
    panel.assign(offset=np.log(panel["Exposure"]))
)

# ---- CFR GLM (Binomial on deaths proportion with weights=Cases) ----
cfr_df = panel.copy()
cfr_df["cfr_prop"] = np.where(cfr_df["Cases"]>0, cfr_df["Deaths"]/cfr_df["Cases"], 0.0)
cfr_df["wts"] = cfr_df["Cases"].clip(lower=0)

# Use same structure; points with Cases=0 get weight 0 and don't influence fit
formula_cfr = "cfr_prop ~ C(State) + bs(Year, df=4) + log_cases_1p"

cfr_model = smf.glm(
    formula=formula_cfr,
    data=cfr_df,
    family=sm.families.Binomial(),
    freq_weights=cfr_df["wts"]
).fit()

print("\n[CFR GLM] Summary (truncated):")
print(cfr_model.summary().as_text()[:1500])

# Predict π_hat everywhere
panel["pi_hat"] = cfr_model.predict(panel)
panel["mu_deaths_hat"] = panel["Cases"] * panel["pi_hat"]

# ============================================================
# 3) Plots
# ============================================================
def savefig(path):
    plt.tight_layout()
    plt.savefig(path, bbox_inches="tight")
    plt.show()

# (a) Observed vs Predicted cases
plt.figure(figsize=(6,4))
m = panel.dropna(subset=["Cases","mu_cases_hat"])
plt.scatter(m["Cases"], m["mu_cases_hat"], s=16, alpha=0.6)
lim = [0, max(m["Cases"].max(), m["mu_cases_hat"].max())*1.05]
plt.plot(lim, lim, ls="--", lw=1, color="black")
plt.xlabel("Observed cases"); plt.ylabel("Predicted cases (Poisson GLM)")
plt.title("Observed vs Predicted Dengue Cases")
savefig(PLOTS_DIR / "plot_cases_obs_vs_pred.png")

# (b) National totals: observed vs fitted
nat = panel.groupby("Year", as_index=False).agg(
    obs_cases=("Cases","sum"),
    fit_cases=("mu_cases_hat","sum")
).sort_values("Year")

plt.figure(figsize=(7,4))
plt.plot(nat["Year"], nat["obs_cases"], marker="o", label="Observed")
plt.plot(nat["Year"], nat["fit_cases"], marker="s", ls="--", label="Fitted")
plt.title("National Dengue Totals by Year (Observed vs Fitted)")
plt.xlabel("Year"); plt.ylabel("Total cases"); plt.legend()
savefig(PLOTS_DIR / "plot_national_cases_time.png")

# (c) CFR observed vs predicted (Cases>0)
with_cases = panel.query("Cases > 0").copy()
with_cases["cfr_obs"] = with_cases["Deaths"] / with_cases["Cases"]

plt.figure(figsize=(6,4))
plt.scatter(with_cases["cfr_obs"], with_cases["pi_hat"], s=16, alpha=0.6)
mx = max(with_cases["cfr_obs"].max(), with_cases["pi_hat"].max())*1.05 if len(with_cases)>0 else 1
plt.plot([0,mx],[0,mx], ls="--", lw=1, color="black")
plt.xlabel("Observed CFR"); plt.ylabel("Predicted CFR (Binomial GLM)")
plt.title("Observed vs Predicted CFR")
savefig(PLOTS_DIR / "plot_cfr_obs_vs_pred.png")

# (d) Top 15 by predicted cases in 2024 and 2025
def top15_for_year(df, year):
    sub = df[df["Year"]==year].copy()
    sub = sub.sort_values("mu_cases_hat", ascending=False).head(15)
    return sub

top2024 = top15_for_year(panel, 2024)
top2025 = top15_for_year(panel, 2025)

# Save tables
top2024[["State","Cases","mu_cases_hat","Deaths"]].to_csv(TABLES_DIR / "table_top15_2024_cases.csv", index=False)
top2025[["State","Cases","mu_cases_hat"]].to_csv(TABLES_DIR / "table_top15_2025_cases.csv", index=False)

# Barh for 2025
plt.figure(figsize=(7,6))
order = top2025.sort_values("mu_cases_hat")  # ascending for barh
plt.barh(order["State"], order["mu_cases_hat"])
plt.xlabel("Predicted cases (2025, exposure-adjusted)")
plt.title("Top 15 Predicted State/UT Burdens in 2025")
savefig(PLOTS_DIR / "plot_top15_2025_cases_barh.png")

# ============================================================
# 4) Save modeling outputs and LaTeX longtables
# ============================================================

# CFR ranking table
cfr_rank = with_cases.groupby("State", as_index=False).agg(
    observed_cfr=("cfr_obs","mean"),
    predicted_cfr=("pi_hat","mean"),
    avg_cases=("Cases","mean")
).sort_values(["predicted_cfr","observed_cfr"], ascending=[False, False])
cfr_rank.to_csv(TABLES_DIR / "table_state_cfr_ranking.csv", index=False)

# Per state-year predictions
cases_pred  = panel[["State","Year","Cases","Deaths","mu_cases_hat"]].copy()
deaths_pred = panel[["State","Year","Cases","Deaths","pi_hat","mu_deaths_hat"]].copy()
cases_pred.to_csv(TABLES_DIR / "cases_predictions.csv", index=False)
deaths_pred.to_csv(TABLES_DIR / "deaths_predictions.csv", index=False)
panel.to_csv(TABLES_DIR / "panel_with_predictions.csv", index=False)

# ---- LaTeX longtable writers ----
def latex_escape(s):
    if pd.isna(s): return ""
    s = str(s)
    repl = {'\\':'\\textbackslash{}','&':'\\&','%':'\\%','$':'\\$','#':'\\#','_':'\\_',
            '{':'\\{','}':'\\}','~':'\\textasciitilde{}','^':'\\textasciicircum{}'}
    for k,v in repl.items(): s = s.replace(k,v)
    return s

def to_longtable_tex(df, columns, headers, label, caption, align=None, float_fmt=None):
    df2 = df.copy()
    if float_fmt:
        for col, fmt in float_fmt.items():
            if col in df2.columns:
                df2[col] = df2[col].map(lambda x: "" if pd.isna(x) else fmt.format(x))
    if align is None:
        align = "l" + "r"*(len(columns)-1)
    lines = []
    lines += [f"\\begin{{longtable}}{{{align}}}",
              f"\\caption{{{caption}}}\\\\",
              f"\\label{{{label}}}\\\\",
              "\\toprule",
              " & ".join(headers) + " \\\\",
              "\\midrule",
              "\\endfirsthead",
              "\\toprule",
              " & ".join(headers) + " \\\\",
              "\\midrule",
              "\\endhead",
              "\\midrule",
              f"\\multicolumn{{{len(columns)}}}{{r}}{{\\emph{{Continued on next page}}}}\\\\",
              "\\bottomrule",
              "\\endfoot",
              "\\bottomrule",
              "\\endlastfoot"]
    for _, r in df2.iterrows():
        row = [latex_escape(r[c]) for c in columns]
        lines.append(" & ".join(row) + " \\\\")
    lines.append("\\end{longtable}")
    return "\n".join(lines)

# Panel longtable
panel_cols = ["State","Year","Cases","mu_cases_hat","Deaths","pi_hat","mu_deaths_hat"]
panel_hdrs = ["State/UT","Year","Cases","Pred. Cases","Deaths","Pred. CFR","Pred. Deaths"]
panel_fmt  = {"mu_cases_hat":"{:.1f}","mu_deaths_hat":"{:.1f}","pi_hat":"{:.4f}"}
panel_tex  = to_longtable_tex(
    panel[panel_cols], panel_cols, panel_hdrs,
    label="tab:panel-with-preds-manual",
    caption="State--year panel with observed counts and model predictions (manual longtable).",
    float_fmt=panel_fmt
)
(TABLES_DIR / "panel_with_predictions_table.tex").write_text(panel_tex)

# Cases longtable
cases_cols = ["State","Year","Cases","Deaths","mu_cases_hat"]
cases_hdrs = ["State/UT","Year","Cases","Deaths","Pred. Cases"]
cases_tex  = to_longtable_tex(
    cases_pred[cases_cols], cases_cols, cases_hdrs,
    label="tab:cases-preds-manual",
    caption="Cases model outputs by state--year (manual longtable).",
    float_fmt={"mu_cases_hat":"{:.1f}"}
)
(TABLES_DIR / "cases_predictions_table.tex").write_text(cases_tex)

# Deaths|cases longtable
deaths_cols = ["State","Year","Cases","Deaths","pi_hat","mu_deaths_hat"]
deaths_hdrs = ["State/UT","Year","Cases","Deaths","Pred. CFR","Pred. Deaths"]
deaths_tex  = to_longtable_tex(
    deaths_pred[deaths_cols], deaths_cols, deaths_hdrs,
    label="tab:deaths-preds-manual",
    caption="Deaths|cases model outputs (manual longtable).",
    float_fmt={"pi_hat":"{:.6f}","mu_deaths_hat":"{:.1f}"}
)
(TABLES_DIR / "deaths_predictions_table.tex").write_text(deaths_tex)

# ============================================================
# 5) Show a few tables in Colab and zip everything
# ============================================================
print("\nTop 15 by predicted cases (2024):")
display(top2024[["State","Cases","mu_cases_hat","Deaths"]])

print("\nTop 15 by predicted cases (2025):")
display(top2025[["State","Cases","mu_cases_hat"]])

print("\nCFR ranking (head):")
display(cfr_rank.head(15))

# ZIP
ZIP_PATH = "dengue_outputs_final.zip"
with zipfile.ZipFile(ZIP_PATH, "w", zipfile.ZIP_DEFLATED) as zf:
    for p in PLOTS_DIR.glob("*.png"):
        zf.write(p, arcname=f"plots/{p.name}")
    for t in TABLES_DIR.glob("*"):
        zf.write(t, arcname=f"tables/{t.name}")

print(f"\nAll outputs saved to: {OUT_DIR}")
print(f"ZIP archive ready: {ZIP_PATH}")


Loaded tidy panel: (257, 8)


Metric,State,Year,Cases,Deaths,Exposure,Cases_lag,log_cases_lag1p,log_cases_1p
0,A& N Island,2019,168.0,0.0,1.0,,,5.129899
1,A& N Island,2020,98.0,0.0,1.0,168.0,5.129899,4.59512
2,A& N Island,2021,175.0,0.0,1.0,98.0,4.59512,5.170484
3,A& N Island,2022,1014.0,3.0,1.0,175.0,5.170484,6.922644
4,A& N Island,2023,846.0,0.0,1.0,1014.0,6.922644,6.741701
5,A& N Island,2024,59.0,0.0,1.0,846.0,6.741701,4.094345
6,A& N Island,2025,254.0,0.0,0.666667,59.0,4.094345,5.541264
7,Andhra Pradesh,2019,5286.0,0.0,1.0,,,8.573006
8,Andhra Pradesh,2020,925.0,0.0,1.0,5286.0,8.573006,6.830874
9,Andhra Pradesh,2021,4760.0,0.0,1.0,925.0,6.830874,8.468213



[Cases GLM] Summary (truncated):
                 Generalized Linear Model Regression Results                  
Dep. Variable:                  Cases   No. Observations:                  220
Model:                            GLM   Df Residuals:                      178
Model Family:                 Poisson   Df Model:                           41
Link Function:                    Log   Scale:                          1.0000
Method:                          IRLS   Log-Likelihood:            -1.7965e+05
Date:                Mon, 20 Oct 2025   Deviance:                   3.5735e+05
Time:                        05:39:11   Pearson chi2:                 3.59e+05
No. Iterations:                     9   Pseudo R-squ. (CS):              1.000
Covariance Type:                  HC0                                         
                                       coef    std err          z      P>|z|      [0.025      0.975]
---------------------------------------------------------------------------

PatsyError: predict requires that you use a DataFrame when predicting from a model
that was created using the formula api.

The original error message returned by patsy is:
Error evaluating factor: NotImplementedError: some data points fall outside the outermost knots, and I'm not sure how to handle them. (Patches accepted!)
    Cases ~ C(State) + bs(Year, df=4) + log_cases_lag1p
                       ^^^^^^^^^^^^^^

In [None]:
import pandas as pd, os, re

OUTDIR = "/mnt/data/dengue_run"
tbl2024 = pd.read_csv(os.path.join(OUTDIR, "table_top15_2024_cases.csv"))
tbl2025 = pd.read_csv(os.path.join(OUTDIR, "table_top15_2025_cases.csv"))
tblcfr  = pd.read_csv(os.path.join(OUTDIR, "table_state_cfr_ranking.csv"))

def escape_latex(s):
    s = str(s)
    repl = {
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\textasciicircum{}',
        '\\': r'\textbackslash{}',
    }
    for k,v in repl.items():
        s = s.replace(k,v)
    return s

def fmt_int(x):
    try:
        return f"{int(round(x))}"
    except:
        return ""

def fmt_float1(x):
    try:
        return f"{x:,.1f}"
    except:
        return ""

def fmt_float4(x):
    try:
        return f"{x:.4f}"
    except:
        return ""

rows_2024 = []
for _, r in tbl2024.iterrows():
    st = escape_latex(r['state'])
    rows_2024.append(f"{st} & {fmt_int(r['cases'])} & {fmt_float1(r['mu_cases_hat'])} & {fmt_int(r['deaths'])} \\\\")

rows_2025 = []
for _, r in tbl2025.iterrows():
    st = escape_latex(r['state'])
    rows_2025.append(f"{st} & {fmt_int(r['cases'])} & {fmt_float1(r['mu_cases_hat'])} \\\\")

rows_cfr = []
for _, r in tblcfr.iterrows():
    st = escape_latex(r['state'])
    rows_cfr.append(f"{st} & {fmt_float4(r['observed_cfr'])} & {fmt_float4(r['predicted_cfr'])} & {fmt_int(r['avg_cases'])} \\\\")

len(rows_2024), len(rows_2025), len(rows_cfr), rows_2025


In [None]:
# Build manual LaTeX longtables from the CSVs and save as .tex files for \input{}.
import pandas as pd, numpy as np, os, re

def escape_latex(s: str) -> str:
    s = str(s)
    repl = {
        '\\': r'\textbackslash{}',
        '&': r'\&', '%': r'\%', '$': r'\$', '#': r'\#',
        '_': r'\_', '{': r'\{', '}': r'\}',
        '~': r'\textasciitilde{}', '^': r'\textasciicircum{}',
    }
    for k,v in repl.items():
        s = s.replace(k,v)
    return s

def fmt_int(x):
    try:
        if pd.isna(x): return ''
        return f"{int(round(float(x)))}"
    except Exception:
        return str(x)

def fmt_1(x):
    try:
        if pd.isna(x): return ''
        return f"{float(x):,.1f}"
    except Exception:
        return str(x)

def fmt_4(x):
    try:
        if pd.isna(x): return ''
        return f"{float(x):.4f}"
    except Exception:
        return str(x)

# Try preferred paths from the recent run; fallback to root if not found
base_candidates = ["/mnt/data/dengue_run", "/mnt/data"]
def find_csv(relname):
    for base in base_candidates:
        p = os.path.join(base, relname)
        if os.path.exists(p):
            return p
    raise FileNotFoundError(relname)

# Load CSVs
panel_csv = find_csv("panel_with_predictions.csv")
cases_csv = find_csv("cases_predictions.csv")
deaths_csv= find_csv("deaths_predictions.csv")

panel = pd.read_csv(panel_csv)
cases = pd.read_csv(cases_csv)
deaths= pd.read_csv(deaths_csv)

# Subset columns for panel
panel_cols = []
for c in ["state","year","cases","mu_cases_hat","deaths","pi_hat","mu_deaths_hat"]:
    if c in panel.columns:
        panel_cols.append(c)
panel_sub = panel[panel_cols].copy()

# Formatting functions map
fmt_map = {
    "state": escape_latex,
    "year": fmt_int,
    "cases": fmt_int,
    "deaths": fmt_int,
    "mu_cases_hat": fmt_1,
    "mu_deaths_hat": fmt_1,
    "pi_hat": fmt_4,
}

def df_to_longtable(df: pd.DataFrame, col_order, col_headers, fmts, label, caption):
    # Build header
    colspec = "l" + "r"*(len(col_order)-1)
    lines = []
    lines.append("\\begin{longtable}{" + colspec + "}")
    lines.append("\\caption{" + caption + "}\\\\")
    lines.append("\\label{" + label + "}\\\\")
    lines.append("\\toprule")
    lines.append(" & ".join(col_headers) + " \\\\")
    lines.append("\\midrule")
    lines.append("\\endfirsthead")
    lines.append("\\toprule")
    lines.append(" & ".join(col_headers) + " \\\\")
    lines.append("\\midrule")
    lines.append("\\endhead")
    lines.append("\\midrule")
    lines.append("\\multicolumn{" + str(len(col_order)) + "}{r}{\\emph{Continued on next page}}\\\\")
    lines.append("\\bottomrule")
    lines.append("\\endfoot")
    lines.append("\\bottomrule")
    lines.append("\\endlastfoot")
    # Rows
    for _, r in df.iterrows():
        vals = []
        for c in col_order:
            v = r.get(c, "")
            fmt = fmts.get(c, escape_latex)
            vals.append(fmt(v))
        lines.append(" & ".join(vals) + " \\\\")
    lines.append("\\end{longtable}")
    return "\n".join(lines)

# Ensure output dir
outdir = "/mnt/data/latex_tables"
os.makedirs(outdir, exist_ok=True)

# 1) panel_with_predictions_table.tex
panel_order = panel_cols
panel_headers = ["State/UT","Year","Cases","Pred. Cases","Deaths","Pred. CFR","Pred. Deaths"]
panel_tex = df_to_longtable(panel_sub, panel_order, panel_headers, fmt_map,
                            "tab:panel-with-preds-manual",
                            "State--year panel with observed counts and model predictions (manual longtable).")
with open(os.path.join(outdir, "panel_with_predictions_table.tex"), "w") as f:
    f.write(panel_tex)

# 2) cases_predictions_table.tex
cases_cols = []
for c in ["state","year","cases","deaths","mu_cases_hat"]:
    if c in cases.columns:
        cases_cols.append(c)
cases_headers = ["State/UT","Year","Cases","Deaths","Pred. Cases"]
cases_tex = df_to_longtable(cases[cases_cols], cases_cols, cases_headers, fmt_map,
                            "tab:cases-preds-manual",
                            "Cases model outputs by state--year (manual longtable).")
with open(os.path.join(outdir, "cases_predictions_table.tex"), "w") as f:
    f.write(cases_tex)

# 3) deaths_predictions_table.tex
deaths_cols = []
for c in ["state","year","cases","deaths","pi_hat","mu_deaths_hat"]:
    if c in deaths.columns:
        deaths_cols.append(c)
deaths_headers = ["State/UT","Year","Cases","Deaths","Pred. CFR","Pred. Deaths"]
deaths_tex = df_to_longtable(deaths[deaths_cols], deaths_cols, deaths_headers, fmt_map,
                             "tab:deaths-preds-manual",
                             "Deaths|cases model outputs (manual longtable).")
with open(os.path.join(outdir, "deaths_predictions_table.tex"), "w") as f:
    f.write(deaths_tex)

# Return paths for user to \input or download
[outdir + "/panel_with_predictions_table.tex",
 outdir + "/cases_predictions_table.tex",
 outdir + "/deaths_predictions_table.tex"]
