In [None]:
# ===== Birth/Fertility by Age: ARIMA for ALL AGE COLUMNS → CSVs + Plots (to 2043) =====
# Copy–paste and run. Adjust SHEET_NAME or RESULTS_DIR if you want.

import re, warnings
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from statsmodels.tsa.arima.model import ARIMA


# Silence every warning for the rest of the run; comment this out when a fit
# or a plot looks suspicious and you want to see what the libraries report.
warnings.filterwarnings("ignore")

# ─────────────────────────── CONFIG ───────────────────────────
# Excel workbook to read (one YEAR column plus one column per age group).
FILE_PATH   = Path(r"F:\medical projects\project i\data\LIVE birth research\GRAPH\Race year Age GRAPH.xlsx")
SHEET_NAME  = None   # set to "BLACK 2" if you want to force that sheet; otherwise auto-pick
# Folder that receives every CSV and PNG written by this script.
RESULTS_DIR = Path(r"D:\arima_birth_forecasts\age")
END_YEAR    = 2043   # last forecast year (inclusive)
XAXIS_RIGHT = 2045   # right-hand x-axis limit of the plots

# Create the output folder (and missing parents); does nothing if it exists.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# ───────────────────── PLOTTING STYLE ─────────────────────
# Global matplotlib defaults applied to every figure created below:
# large bold fonts, a faint grid, thick lines and a framed legend.
plt.rcParams.update({
    "figure.figsize": (13.5, 8.0),
    "figure.dpi": 160,
    "axes.grid": True,
    "grid.alpha": 0.25,
    "font.size": 14,
    "font.weight": "bold",
    "axes.labelsize": 18,
    "axes.labelweight": "bold",
    "axes.titlesize": 22,
    "axes.titleweight": "bold",
    "legend.fontsize": 12,
    "legend.frameon": True,
    "legend.title_fontsize": 13,
    "lines.linewidth": 2.7,
    "lines.markersize": 6.7,
})
def style_axes(ax):
    """Thicken the frame and the tick marks of *ax* in place.

    Every spine is drawn with line width 2.2; ticks on both axes get label
    size 15, width 2.0 and length 7.  Returns ``None``.
    """
    for spine in ax.spines.values():
        spine.set_linewidth(2.2)
    ax.tick_params(axis="both", labelsize=15, width=2.0, length=7)

# ─────────────────────────── HELPERS ───────────────────────────
def normalize_cols(cols):
    """Return *cols* as a ``pd.Index`` of cleaned-up string labels.

    Each label is converted with ``str``; the Unicode dash characters
    U+2010..U+2015 become an ASCII hyphen, non-breaking spaces become plain
    spaces, and leading/trailing whitespace is stripped.
    """
    def _ascii_dashes(label):
        # unicode dashes → '-'
        return re.sub(r'[\u2010-\u2015]', '-', str(label))

    labels = pd.Index(cols).map(_ascii_dashes)
    # NBSP → space, then trim both ends
    without_nbsp = labels.str.replace('\u00a0', ' ', regex=False)
    return without_nbsp.str.strip()

def flatten_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Give *df* flat, normalised and unique string column labels.

    Steps, in order:

    1. A ``MultiIndex`` header is collapsed to one string per column by
       joining the level values with a space (empty and ``"nan"`` parts are
       dropped).
    2. Labels are cleaned with ``normalize_cols``.
    3. Repeated labels get a numeric suffix: the first occurrence keeps its
       name, later ones become ``"<label>.1"``, ``"<label>.2"``, ...  A
       suffix is skipped when the resulting name is already a column label
       or was already handed out, so the result never contains duplicates
       (e.g. ``["a", "a", "a.1"]`` becomes ``["a", "a.2", "a.1"]``).

    The frame is modified in place and also returned.
    """
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = [
            " ".join(
                part for part in map(str, levels) if part and part != "nan"
            ).strip()
            for levels in df.columns.to_list()
        ]
    df.columns = normalize_cols(df.columns)

    reserved = set(df.columns)   # labels present before de-duplication
    used = set()                 # labels already emitted
    last_suffix = {}             # label -> highest suffix tried so far
    unique = []
    for label in df.columns:
        candidate = label
        if label in used:
            n = last_suffix.get(label, 0)
            while True:
                n += 1
                candidate = f"{label}.{n}"
                # Skip names that would clash with a real column or with a
                # name generated earlier.
                if candidate not in reserved and candidate not in used:
                    break
            last_suffix[label] = n
        used.add(candidate)
        unique.append(candidate)
    df.columns = unique
    return df

def choose_sheet(xls):
    """Return the sheet name of *xls* that best matches the expected data sheet.

    A name scores 3 points for containing "black", 3 for "age" and 2 for
    containing "graph" or "data" (case-insensitive).  The highest score
    wins; ties go to the sheet that comes first in ``xls.sheet_names``.
    """
    keyword_weights = (
        (("black",), 3),
        (("age",), 3),
        (("graph", "data"), 2),
    )

    def relevance(sheet_name):
        lowered = (sheet_name or "").lower().strip()
        return sum(
            weight
            for keywords, weight in keyword_weights
            if any(word in lowered for word in keywords)
        )

    ranked = sorted(
        enumerate(xls.sheet_names),
        key=lambda item: (-relevance(item[1]), item[0]),
    )
    return ranked[0][1]

def find_year_col(df):
    """Return the first column label of *df* that starts with the word "year".

    Matching is case-insensitive and ignores leading whitespace; "Yearly"
    does not count because "year" must end at a word boundary.

    Raises:
        KeyError: if no column qualifies.
    """
    for label in df.columns:
        if re.match(r'\s*year\b', str(label).lower()):
            return label
    raise KeyError("Could not find a YEAR column.")

# Match '10-14', '10–14', plus optional 'year(s)' suffix
# Group 1 is the starting age, group 2 the ending age.  The trailing
# `.*?(year|yrs|yea|y)?` part is lazy and optional, so it never prevents a
# match: only the "NN-NN" range (hyphen or en dash) is actually required.
AGE_RX = re.compile(r'\b(\d{1,2})\s*[-–]\s*(\d{1,2})\b.*?(year|yrs|yea|y)?', flags=re.IGNORECASE)

def detect_age_cols(df, year_col):
    """List the age-group columns of *df*, ordered by starting age.

    A column counts as an age group when its label (as a string) matches
    ``AGE_RX``; *year_col* is always excluded.  Columns that share a starting
    age keep their original relative order.

    Raises:
        KeyError: if no column matches.
    """
    ranked = []
    for label in df.columns:
        if label == year_col:
            continue
        hit = AGE_RX.search(str(label))
        if hit:
            # remember the starting age next to the label for sorting
            ranked.append((int(hit.group(1)), label))

    if not ranked:
        raise KeyError("No age-group columns detected (e.g., '15–19').")

    # list.sort is stable, so equal starting ages stay in sheet order
    ranked.sort(key=lambda pair: pair[0])
    return [label for _, label in ranked]

def sanitize_series(y: pd.Series):
    """Return *y* on a gap-free run of consecutive integer years.

    The series is sorted by index and re-indexed onto every integer from its
    smallest to its largest index value.  Any resulting NaN is filled by
    ``Series.interpolate`` (default method), allowed to fill in both
    directions.
    """
    ordered = y.sort_index()
    first_year = int(ordered.index.min())
    final_year = int(ordered.index.max())
    every_year = pd.Index(range(first_year, final_year + 1), name=ordered.index.name)
    filled = ordered.reindex(every_year)
    if not filled.isna().any():
        return filled
    return filled.interpolate(limit_direction="both")

def select_arima_order(y):
    """Grid-search ARIMA(p, d, q) on *y* and return the lowest-AIC fit.

    Every combination of ``p, q in 0..3`` and ``d in 0..2`` is tried except
    (0, 0, 0).  Models with ``d == 0`` include a constant (``trend="c"``);
    differenced models have no trend term (``trend="n"``).  Stationarity and
    invertibility are not enforced.

    Candidates that fail to fit, or whose AIC is NaN or infinite, are
    skipped.  Without the finiteness check a NaN AIC on the first successful
    fit would be kept forever, because ``aic < nan`` is always False.

    Returns:
        ``((p, d, q, trend), fitted_results)`` for the best candidate.  If no
        candidate is usable, ARIMA(1, 1, 0) without trend is fitted directly
        and any exception from that fit propagates to the caller.

    NOTE(review): AIC values are compared across different ``d``; confirm
    this is the intended selection rule.
    """
    best = None
    for d in (0, 1, 2):
        trend = "n" if d > 0 else "c"
        for p in range(4):
            for q in range(4):
                if (p, d, q) == (0, 0, 0):
                    continue
                try:
                    res = ARIMA(
                        y, order=(p, d, q), trend=trend,
                        enforce_stationarity=False, enforce_invertibility=False,
                    ).fit(method_kwargs={"warn_convergence": False})
                    aic = res.aic
                except Exception:
                    # Best-effort search: a candidate that cannot be fitted
                    # is simply not considered.
                    continue
                if not np.isfinite(aic):
                    # Degenerate fit; its AIC cannot be ranked.
                    continue
                if best is None or aic < best[0]:
                    best = (aic, (p, d, q, trend), res)

    if best is None:
        res = ARIMA(
            y, order=(1, 1, 0), trend="n",
            enforce_stationarity=False, enforce_invertibility=False,
        ).fit(method_kwargs={"warn_convergence": False})
        return (1, 1, 0, "n"), res
    return best[1], best[2]

def forecast_to(y, end_year, conf=0.95):
    """Forecast *y* from the year after its last index value up to *end_year*.

    The model is chosen by ``select_arima_order``.

    Args:
        y: series indexed by integer year.
        end_year: last year to forecast (inclusive).
        conf: confidence level of the interval, strictly between 0 and 1.

    Returns:
        ``(order, frame)`` where ``order`` is ``(p, d, q, trend)`` and
        ``frame`` has the columns ``Year``, ``Point.Forecast``,
        ``Lo.<pct>``, ``Hi.<pct>`` and ``Order``.  ``<pct>`` is *conf* as a
        percentage, so the default gives ``Lo.95`` / ``Hi.95``; the labels
        follow *conf* instead of always reading ".95".

    Raises:
        ValueError: if *conf* is not in (0, 1) or *y* already reaches
            *end_year*.
    """
    if not 0 < conf < 1:
        raise ValueError(f"conf must be strictly between 0 and 1, got {conf!r}.")

    last_year = int(y.index.max())
    steps = max(0, end_year - last_year)
    if steps == 0:
        raise ValueError("Observed already reaches END_YEAR.")

    order, res = select_arima_order(y)
    fc = res.get_forecast(steps=steps)
    ci = fc.conf_int(alpha=1 - conf)   # column 0 = lower bound, column 1 = upper

    pct = f"{conf * 100:g}"            # 0.95 -> "95", 0.975 -> "97.5"
    p, d, q, trend = order
    return order, pd.DataFrame({
        "Year": list(range(last_year + 1, end_year + 1)),
        "Point.Forecast": fc.predicted_mean.values,
        f"Lo.{pct}": ci.iloc[:, 0].values,
        f"Hi.{pct}": ci.iloc[:, 1].values,
        "Order": [f"{p},{d},{q} ({trend})"] * steps,
    })

def safe_name(s: str) -> str:
    """Turn *s* into a string that is safe to embed in a file name.

    Each run of characters other than ASCII letters, digits, ``_`` and ``-``
    is replaced by a single underscore; underscores at either end are
    removed.
    """
    text = str(s)
    collapsed = re.sub(r"[^a-zA-Z0-9_-]+", "_", text)
    return collapsed.strip("_")

# ─────────────────────────── LOAD ───────────────────────────
# Open the workbook inside a context manager so the file handle is released
# as soon as the sheet has been read.  The engine is set once on the
# ExcelFile; read_excel then reuses it.
with pd.ExcelFile(FILE_PATH, engine="openpyxl") as xls:
    sheet = SHEET_NAME or choose_sheet(xls)
    df    = pd.read_excel(xls, sheet_name=sheet)
df    = flatten_columns(df)

year_col = find_year_col(df)
age_cols = detect_age_cols(df, year_col)

# Coerce to numbers; anything unparsable becomes NaN and is dropped later,
# column by column, before fitting.
df[year_col] = pd.to_numeric(df[year_col], errors="coerce")
for c in age_cols:
    df[c] = pd.to_numeric(df[c], errors="coerce")

print(f"\nUsing sheet: {sheet}")
print(f"YEAR column: {year_col}")
print(f"Detected age columns ({len(age_cols)}): {age_cols[:8]}{' ...' if len(age_cols)>8 else ''}")

# Latest year present in the sheet (the column is already numeric here, and
# max() ignores NaN); used for the observed/forecast divider and the titles.
last_obs_year = int(df[year_col].max())

# ─────────────────────────── FIT + SAVE ───────────────────────────
# For every age-group column: fit an ARIMA model, write the forecast CSV,
# draw the per-age plot, and collect the rows for the consolidated CSV.
all_rows = []
# Left edge of every plot: first year in the sheet (identical for all columns,
# so it is computed once instead of once per iteration).
min_year = int(df[year_col].min())
for col in age_cols:
    sub = df[[year_col, col]].dropna()
    sub = sub.rename(columns={year_col:"Year", col:"Rate"}).sort_values("Year")
    if sub.empty:
        continue

    # Series indexed by integer year, with gaps filled by sanitize_series.
    y = pd.Series(sub["Rate"].values, index=sub["Year"].astype(int), name="Rate")
    y = sanitize_series(y)

    order, fc = forecast_to(y, end_year=END_YEAR, conf=0.95)

    # CSV per age (values rounded to 2 decimals; the plot below uses the
    # unrounded forecast)
    csv_path = RESULTS_DIR / f"{safe_name(col)}_forecast_to_{END_YEAR}.csv"
    out = fc.copy()
    value_cols = ["Point.Forecast", "Lo.95", "Hi.95"]
    out[value_cols] = out[value_cols].round(2)
    out.insert(0, "Age_Group", col)
    out.to_csv(csv_path, index=False)
    print(f"[{col}] ARIMA order={order} → CSV: {csv_path}")

    # plot per age
    fig, ax = plt.subplots()
    ax.plot(y.index, y.values, marker="o", label=f"{col} (obs)")
    ln, = ax.plot(fc["Year"], fc["Point.Forecast"], linestyle="--", marker="o", label=f"{col} (fc)")
    lo = np.maximum(fc["Lo.95"].values, 0.0)  # rates ≥ 0
    ax.fill_between(fc["Year"], lo, fc["Hi.95"].values, alpha=0.15, color=ln.get_color(), label="95% CI (fc)")
    # dotted divider between the observed and the forecast period
    ax.axvline(x=last_obs_year+0.5, linestyle=":", linewidth=2.2)
    # extend x-axis to 2045 (or whatever XAXIS_RIGHT is)
    ax.set_xlim(min_year, XAXIS_RIGHT)
    ax.set_xticks(np.arange(min_year, XAXIS_RIGHT + 1, 5))  # ticks every 5 years (optional)

    ax.set_title(f"{col} — observed (≤{last_obs_year}) forecast ({last_obs_year+1}–{END_YEAR})")
    ax.set_xlabel("Year"); ax.set_ylabel("Birth/Fertility Rate (per 1,000)")
    ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
    style_axes(ax)
    ax.legend(ncols=2, bbox_to_anchor=(0.5, -0.15), loc="upper center")
    # Work on the explicit figure handle rather than pyplot's "current figure".
    fig.tight_layout(rect=[0, 0.07, 1, 1])
    png_path = RESULTS_DIR / f"{safe_name(col)}_timeseries_to_{END_YEAR}.png"
    fig.savefig(png_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("   Plot:", png_path)

    # keep for consolidated (out already carries Age_Group and is not
    # modified afterwards, so no extra copy is needed)
    all_rows.append(out)

# consolidated CSV for all ages
if all_rows:
    all_df = pd.concat(all_rows, ignore_index=True)
    all_csv_path = RESULTS_DIR / f"ALL_AGE_FORECASTS_to_{END_YEAR}.csv"
    all_df.to_csv(all_csv_path, index=False)
    print("Consolidated CSV:", all_csv_path)

# ───────────────────── COMBINED PLOT (legend below) ─────────────────────
# One figure with every age group: observed series, forecasts read back from
# the per-age CSV files, and their (zero-clipped) 95% bands.
fig, ax = plt.subplots()

# observed lines
for col in age_cols:
    sub = (
        df[[year_col, col]]
        .dropna()
        .rename(columns={year_col:"Year", col:"Rate"})
        .sort_values("Year")
    )
    if sub.empty:
        continue
    ax.plot(sub["Year"], sub["Rate"], marker="o", label=f"{col} (obs)")

# forecast lines + bands; only the first band gets a legend entry
added_ci = False
for col in age_cols:
    fc_path = RESULTS_DIR / f"{safe_name(col)}_forecast_to_{END_YEAR}.csv"
    if not fc_path.exists():
        continue
    fc = pd.read_csv(fc_path)
    ln, = ax.plot(fc["Year"], fc["Point.Forecast"], linestyle="--", marker="o", label=f"{col} (fc)")
    lo = np.maximum(fc["Lo.95"].values, 0.0)
    band_style = {"alpha": 0.12, "color": ln.get_color()}
    if not added_ci:
        band_style["label"] = "95% CI (fc)"
        added_ci = True
    ax.fill_between(fc["Year"], lo, fc["Hi.95"].values, **band_style)

# divider between observed and forecast years, then the shared x-range
ax.axvline(x=last_obs_year+0.5, linestyle=":", linewidth=2.2)
min_year = int(df[year_col].min())
ax.set_xlim(min_year, XAXIS_RIGHT)
ax.set_xticks(np.arange(min_year, XAXIS_RIGHT + 1, 5))  # optional

ax.set_title(f"Observed (≤{last_obs_year}) + forecasts ({last_obs_year+1}–{END_YEAR})")
ax.set_xlabel("Year")
ax.set_ylabel("Birth/Fertility Rate (per 1,000)")
ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
style_axes(ax)

# legend below the axes, centred, in three columns
ax.legend(ncols=3, bbox_to_anchor=(0.5, -0.15), loc="upper center")
plt.tight_layout(rect=[0, 0.07, 1, 1])
combined_png = RESULTS_DIR / f"COMBINED_AGE_timeseries_to_{END_YEAR}.png"
plt.savefig(combined_png, dpi=300, bbox_inches="tight")
plt.close()
print("Combined plot:", combined_png)

print("\n✅ Done → Results saved in:", RESULTS_DIR)



Using sheet: BLACK
YEAR column: YEAR
Detected age columns (8): ['10-14 years', '15-19 years', '20-24 years', '25-29 years', '30-34 years', '35-39 years', '40-44 years', '45-54 years']
[10-14 years] ARIMA order=(3, 0, 1, 'c') → CSV: D:\arima_birth_forecasts\age\10-14_years_forecast_to_2043.csv
   Plot: D:\arima_birth_forecasts\age\10-14_years_timeseries_to_2043.png
[15-19 years] ARIMA order=(2, 2, 3, 'n') → CSV: D:\arima_birth_forecasts\age\15-19_years_forecast_to_2043.csv
   Plot: D:\arima_birth_forecasts\age\15-19_years_timeseries_to_2043.png
[20-24 years] ARIMA order=(2, 2, 3, 'n') → CSV: D:\arima_birth_forecasts\age\20-24_years_forecast_to_2043.csv
   Plot: D:\arima_birth_forecasts\age\20-24_years_timeseries_to_2043.png
[25-29 years] ARIMA order=(2, 1, 3, 'n') → CSV: D:\arima_birth_forecasts\age\25-29_years_forecast_to_2043.csv
   Plot: D:\arima_birth_forecasts\age\25-29_years_timeseries_to_2043.png
[30-34 years] ARIMA order=(0, 2, 3, 'n') → CSV: D:\arima_birth_forecasts\age\30-34_