
# Birth Rate & Fertility in Developed Countries — EDA + Forecasting (Advanced) 📊

This notebook expands the analysis to include **fertility rate** and **covariates**, and compares three models:
- **SARIMAX** (with optional exogenous regressors)
- **Exponential Smoothing (ETS)**
- **Prophet** (with optional regressors)

**Indicators (World Bank):**
- `SP.DYN.CBRT.IN` → Crude birth rate (per 1,000 people) **(target)**
- `SP.DYN.TFRT.IN` → Fertility rate, total (births per woman)
- `NY.GDP.PCAP.CD` → GDP per capita (current US$)
- `SL.TLF.CACT.FE.ZS` → Labor force participation rate, female (% of female population ages 15+)

> Tip: If the World Bank API is slow, run once and cache CSVs in your repo.


## 0) (Optional) Install dependencies

In [None]:

# If needed, uncomment (scikit-learn is required by the modeling helpers in section 5):
# %pip install -q pandas pandas-datareader matplotlib statsmodels prophet wbdata scikit-learn


## 1) Imports & configuration

In [None]:

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Fetchers
try:
    from pandas_datareader import wb as wbreader
    HAVE_WB_READER = True
except Exception:
    HAVE_WB_READER = False

try:
    import wbdata
    HAVE_WBDATA = True
except Exception:
    HAVE_WBDATA = False

# Models
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tools.sm_exceptions import ConvergenceWarning

try:
    from prophet import Prophet
    HAVE_PROPHET = True
except Exception:
    HAVE_PROPHET = False

# Notebook-wide plotting defaults: wide figures with the axes grid switched on.
plt.rcParams.update({"figure.figsize": (11, 6), "axes.grid": True})

# World Bank indicator code -> short column name used for that indicator
# in the rest of the notebook.
IND_MAP = dict(
    [
        ("SP.DYN.CBRT.IN", "birth_rate"),
        ("SP.DYN.TFRT.IN", "fertility_rate"),
        ("NY.GDP.PCAP.CD", "gdp_per_capita_usd"),
        ("SL.TLF.CACT.FE.ZS", "female_lfp"),
    ]
)
# Indicator codes in dict insertion order.
INDICATORS = [*IND_MAP]


## 2) Countries

In [None]:

# Three-letter country codes handed to the World Bank fetcher, one group per line.
COUNTRIES = (
    "USA CAN GBR DEU FRA ITA ESP NLD BEL LUX IRL "
    "SWE NOR DNK FIN ISL CHE AUT PRT GRC "
    "JPN KOR AUS NZL CZE SVK SVN POL HUN EST LVA LTU"
).split()
print(f"Selected countries ({len(COUNTRIES)}):", ", ".join(COUNTRIES))


## 3) Fetch World Bank indicators

In [None]:

def fetch_wb_panel(countries, indicators, start=1960, end=None):
    """Download World Bank indicators and merge them into one long panel.

    pandas-datareader is used when it was importable; otherwise wbdata is
    used. Both back ends are restricted to the same ``start``..``end`` year
    range so the result does not depend on which library is installed.

    Parameters
    ----------
    countries : list of str
        Country codes forwarded to the fetcher.
    indicators : iterable of str
        World Bank indicator codes; each must be a key of ``IND_MAP``.
    start : int
        First year to keep (inclusive).
    end : int or None
        Last year to keep (inclusive). Defaults to the previous calendar year.

    Returns
    -------
    pandas.DataFrame
        Columns ``country``, ``date`` and one column per indicator (named via
        ``IND_MAP``), outer-merged and sorted by country then date.

    Raises
    ------
    ValueError
        If ``indicators`` is empty.
    ImportError
        If neither pandas-datareader nor wbdata is available.
    """
    if end is None:
        end = pd.Timestamp.today().year - 1

    indicators = list(indicators)
    if not indicators:
        raise ValueError("At least one World Bank indicator code is required.")

    if not (HAVE_WB_READER or HAVE_WBDATA):
        raise ImportError("Install pandas-datareader or wbdata to fetch World Bank data.")

    frames = []
    for ind in indicators:
        col = IND_MAP[ind]
        if HAVE_WB_READER:
            df = wbreader.download(indicator=ind, country=countries, start=start, end=end)
            df = df.reset_index().rename(columns={ind: col})
            df["date"] = pd.to_datetime(df["year"].astype(int), format="%Y")
        else:
            # The {code: name} mapping already names the value column.
            df = wbdata.get_dataframe({ind: col}, country=countries, convert_date=True).reset_index()
            df["date"] = pd.to_datetime(df["date"])
            # This call is not given a date range, so apply start/end here
            # to match the pandas-datareader branch.
            years = df["date"].dt.year
            df = df[(years >= start) & (years <= end)]
        frames.append(df[["country", "date", col]])

    # Outer merge keeps country/year rows that only some indicators cover.
    out = frames[0]
    for frame in frames[1:]:
        out = out.merge(frame, on=["country", "date"], how="outer")
    return out.sort_values(["country", "date"]).reset_index(drop=True)

# Fetch every configured indicator for every selected country from 1960 onward,
# then preview the first rows of the long-format panel.
panel = fetch_wb_panel(countries=COUNTRIES, indicators=INDICATORS, start=1960)
panel.head()


## 4) Clean & quick overview

In [None]:

# Copy the raw panel with every indicator column coerced to numeric
# (values that cannot be parsed become NaN).
numeric_cols = ["birth_rate", "fertility_rate", "gdp_per_capita_usd", "female_lfp"]
df = panel.assign(**{name: pd.to_numeric(panel[name], errors="coerce") for name in numeric_cols})

# A row is dropped only when BOTH birth rate and fertility rate are missing.
df = df.dropna(subset=["birth_rate", "fertility_rate"], how="all")

# Wide tables: one row per date, one column per country.
_by_date = {"index": "date", "columns": "country"}
pivot_birth = df.pivot(values="birth_rate", **_by_date).sort_index()
pivot_fert = df.pivot(values="fertility_rate", **_by_date).sort_index()

# Restrict both tables to the dates they share.
common_idx = pivot_birth.index.intersection(pivot_fert.index)
pivot_birth, pivot_fert = pivot_birth.loc[common_idx], pivot_fert.loc[common_idx]

# One line chart per indicator, legend placed outside the axes.
for _wide, _title, _ylabel in (
    (pivot_birth, "Crude Birth Rate per 1,000 people — World Bank", "Birth rate"),
    (pivot_fert, "Fertility Rate (births per woman) — World Bank", "Fertility rate"),
):
    ax = _wide.plot(title=_title, xlabel="Year", ylabel=_ylabel)
    _ = ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.show()


### 4.1) Cross-sectional correlations (latest common year)

In [None]:

# Most recent year in the shared pivot index (None when the index is empty).
# NOTE(review): this is simply the maximum date; coverage is not checked.
latest_year = int(common_idx.max().year) if len(common_idx) else None
print("Latest common year:", latest_year)

if latest_year:
    latest = df[df["date"].dt.year == latest_year].copy()
else:
    latest = pd.DataFrame()

if not latest.empty:
    # (x column, panel title, x-axis label) for each scatter panel.
    _panels = (
        ("gdp_per_capita_usd", "Birth rate vs GDP per capita (latest year)", "GDP per capita (USD)"),
        ("female_lfp", "Birth rate vs Female LFP (latest year)", "Female LFP (%)"),
    )
    fig, axs = plt.subplots(1, 2, figsize=(14,5))
    for _panel_ax, (_xcol, _panel_title, _xlabel) in zip(axs, _panels):
        _panel_ax.scatter(latest[_xcol], latest["birth_rate"])
        _panel_ax.set_title(_panel_title)
        _panel_ax.set_xlabel(_xlabel)
        _panel_ax.set_ylabel("Birth rate")

    plt.tight_layout()
    plt.show()


## 5) Modeling helpers

In [None]:

from sklearn.metrics import mean_absolute_error, mean_squared_error

def train_test_split_series(series, test_years=10):
    """Split a time-ordered series into a leading train part and a trailing test part.

    Parameters
    ----------
    series : pandas.Series or pandas.DataFrame
        Observations in chronological order (anything supporting ``.iloc``).
    test_years : int
        Number of trailing observations to hold out. When the series has at
        most ``test_years + 5`` observations, roughly one fifth of it (at
        least one observation) is held out instead.

    Returns
    -------
    tuple
        ``(train, test)`` slices of ``series``.

    Raises
    ------
    ValueError
        If ``test_years`` is negative.
    """
    if test_years < 0:
        raise ValueError("test_years must be non-negative.")

    n_obs = len(series)
    if n_obs <= test_years + 5:
        test_years = max(1, n_obs // 5)

    # Split on an explicit position: slicing with ``-test_years`` breaks for
    # test_years == 0 because ``-0`` is ``0`` (empty train, everything in test).
    split_at = max(n_obs - test_years, 0)
    return series.iloc[:split_at], series.iloc[split_at:]

def evaluate_forecast(actual, forecast):
    """Return ``(MAE, RMSE)`` of a forecast against the observed values.

    Values are compared by position, not by index label. The computation uses
    NumPy directly rather than ``mean_squared_error(..., squared=False)``,
    whose ``squared`` argument recent scikit-learn releases deprecate and
    no longer accept.

    Parameters
    ----------
    actual, forecast : array-like
        One-dimensional sequences of equal, non-zero length.

    Returns
    -------
    tuple of float
        Mean absolute error and root mean squared error.

    Raises
    ------
    ValueError
        If the inputs are empty, differ in length, or contain NaN/inf.
    """
    actual_arr = np.asarray(actual, dtype=float).ravel()
    forecast_arr = np.asarray(forecast, dtype=float).ravel()

    if actual_arr.size == 0 or actual_arr.shape != forecast_arr.shape:
        raise ValueError(
            f"actual and forecast must be non-empty and of equal length "
            f"(got {actual_arr.size} and {forecast_arr.size})."
        )
    if not (np.isfinite(actual_arr).all() and np.isfinite(forecast_arr).all()):
        raise ValueError("actual and forecast must not contain NaN or infinite values.")

    errors = actual_arr - forecast_arr
    mae = float(np.mean(np.abs(errors)))
    rmse = float(np.sqrt(np.mean(errors ** 2)))
    return mae, rmse

def zscore(s):
    """Standardize values to zero mean and unit population standard deviation.

    The input is converted to a float ``pandas.Series``. A standard deviation
    of exactly zero is replaced by 1.0, so constant input maps to zeros
    instead of dividing by zero.
    """
    values = pd.Series(s).astype(float)
    spread = values.std(ddof=0)
    if spread == 0:
        spread = 1.0
    centered = values - values.mean()
    return centered / spread


## 6) Choose a focus country

In [None]:

FOCUS_COUNTRY = "United States"  # edit as desired; must match a value in df["country"]

# Exogenous regressors: fertility plus the two covariates.
exog_cols = ["fertility_rate","gdp_per_capita_usd","female_lfp"]

g = df[df["country"]==FOCUS_COUNTRY].sort_values("date").copy()
if g.empty:
    raise ValueError(f"No rows found for FOCUS_COUNTRY={FOCUS_COUNTRY!r}; check df['country'].unique().")

# Index by annual Period. The dates are year-start stamps, so calling
# .asfreq("Y") on the datetime index would reindex to year-END dates, match no
# row and turn every value into NaN. The later cells (to_timestamp(),
# period_range from y.index[-1] + 1) also expect a PeriodIndex.
g.index = pd.DatetimeIndex(g["date"]).to_period("Y")

# Keep only years where the target AND every regressor are observed: the
# regressors are passed to SARIMAX and Prophet below, and a NaN among them
# would make those fits fail.
g = g[["birth_rate"] + exog_cols].dropna()
if g.empty:
    raise ValueError(f"{FOCUS_COUNTRY!r} has no year with birth_rate and all of {exog_cols} present.")

y = g["birth_rate"]
# Regressors are z-scored over the retained sample.
X = g[exog_cols].apply(zscore)

y.tail(), X.tail()


## 7) Fit & compare three models

In [None]:

# Hold out the last TEST_YEARS observations, fit SARIMAX, ETS and Prophet on the
# rest, and record each model's test MAE/RMSE and forecast in `results`.
# A model that is unavailable or fails gets NaN metrics and an all-NaN forecast,
# and the reason is printed instead of being swallowed silently.
TEST_YEARS = 10
train_y, test_y = train_test_split_series(y, TEST_YEARS)
train_X, test_X = X.loc[train_y.index], X.loc[test_y.index]


def _nan_forecast():
    """Return an all-NaN series over the test index (placeholder for a failed model)."""
    return pd.Series(index=test_y.index, dtype=float)


def _as_ds(index):
    """Return *index* as timestamps for Prophet's ``ds`` column (handles PeriodIndex)."""
    return index.to_timestamp() if isinstance(index, pd.PeriodIndex) else pd.DatetimeIndex(index)


results = {}

# 7.1 SARIMAX (with exog): grid over (p, d, q), keeping the lowest-AIC fit.
best_aic, best_order, best_model = np.inf, None, None
sarimax_errors = []
for p in range(0, 4):
    for d in range(0, 3):
        for q in range(0, 4):
            try:
                m = SARIMAX(train_y, exog=train_X, order=(p, d, q),
                            enforce_stationarity=False, enforce_invertibility=False)
                r = m.fit(disp=False)
                if r.aic < best_aic:
                    best_aic, best_order, best_model = r.aic, (p, d, q), r
            except Exception as exc:
                # Individual orders may fail to fit; remember why and move on.
                sarimax_errors.append(f"{(p, d, q)}: {exc!r}")

results["SARIMAX"] = {"order": None, "mae": np.nan, "rmse": np.nan, "fc": _nan_forecast()}
if best_model is None:
    last_error = sarimax_errors[-1] if sarimax_errors else "no candidate produced a usable AIC"
    print(f"SARIMAX: no order could be fitted ({len(sarimax_errors)} failures; last: {last_error})")
else:
    try:
        sarimax_fc = best_model.get_forecast(steps=len(test_y), exog=test_X).predicted_mean
        mae, rmse = evaluate_forecast(test_y, sarimax_fc)
        results["SARIMAX"] = {"order": best_order, "mae": mae, "rmse": rmse, "fc": sarimax_fc}
    except Exception as exc:
        print(f"SARIMAX: forecasting/evaluation failed: {exc!r}")

# 7.2 ETS (no exog)
try:
    ets = ExponentialSmoothing(train_y, trend="add", seasonal=None, initialization_method="estimated").fit()
    ets_fc = ets.forecast(len(test_y))
    mae, rmse = evaluate_forecast(test_y, ets_fc)
    results["ETS"] = {"params": "trend=add", "mae": mae, "rmse": rmse, "fc": ets_fc}
except Exception as exc:
    print(f"ETS: fit/forecast failed: {exc!r}")
    results["ETS"] = {"params": None, "mae": np.nan, "rmse": np.nan, "fc": _nan_forecast()}

# 7.3 Prophet (with the z-scored regressors), only when the package is installed
results["Prophet"] = {"mae": np.nan, "rmse": np.nan, "fc": _nan_forecast()}
if HAVE_PROPHET:
    try:
        train_df = pd.DataFrame({"ds": _as_ds(train_y.index), "y": train_y.values})
        test_df = pd.DataFrame({"ds": _as_ds(test_y.index), "y": test_y.values})
        for c in exog_cols:
            train_df[c] = train_X[c].values
            test_df[c] = test_X[c].values

        m = Prophet(yearly_seasonality=False, weekly_seasonality=False, daily_seasonality=False)
        for c in exog_cols:
            m.add_regressor(c)
        m.fit(train_df)
        future = test_df[["ds"] + exog_cols].copy()
        prophet_fc = pd.Series(m.predict(future)["yhat"].values, index=test_y.index)
        mae, rmse = evaluate_forecast(test_y, prophet_fc)
        results["Prophet"] = {"mae": mae, "rmse": rmse, "fc": prophet_fc}
    except Exception as exc:
        print(f"Prophet: fit/forecast failed: {exc!r}")
else:
    print("Prophet: package not installed; skipping.")

results


### 7.4) Plot test-period forecasts

In [None]:

def _plot_x(index):
    """Return x-values for ``plt.plot``: a PeriodIndex is converted to timestamps."""
    return index.to_timestamp() if isinstance(index, pd.PeriodIndex) else index


# Train/test history plus every model forecast that actually produced values.
plt.figure()
plt.plot(_plot_x(train_y.index), train_y.values, label="Train")
plt.plot(_plot_x(test_y.index),  test_y.values,  label="Test")
for name, res in results.items():
    fc = res["fc"]
    # Failed or unavailable models carry an all-NaN placeholder of full length,
    # so a length check would still add an empty legend entry for them.
    if fc.notna().any():
        plt.plot(_plot_x(fc.index), fc.values, linestyle="--", marker="o", label=f"{name} forecast")
plt.title(f"{FOCUS_COUNTRY} — Test forecasts comparison")
plt.xlabel("Year"); plt.ylabel("Birth rate (per 1,000)")
plt.legend()
plt.show()

# Error metrics side by side (one column per model).
pd.DataFrame({k: {"MAE": v["mae"], "RMSE": v["rmse"]} for k,v in results.items()})


## 8) Refit best model on full data & forecast ahead

In [None]:

# Pick the model with the lowest finite test RMSE, refit it on the full series
# and forecast STEPS_AHEAD years. best_name is None when every model failed.
_scores = {name: res["rmse"] for name, res in results.items() if np.isfinite(res["rmse"])}
best_name = min(_scores, key=_scores.get) if _scores else None
print("Best model by RMSE:", best_name)

STEPS_AHEAD = 10


def _to_timestamps(index):
    """Return *index* as timestamps when it is a PeriodIndex, otherwise unchanged."""
    return index.to_timestamp() if isinstance(index, pd.PeriodIndex) else index


def _future_index(index, steps):
    """Return the next *steps* annual labels after ``index[-1]``, in the same index type."""
    last_period = pd.Period(index[-1], freq="Y")
    periods = pd.period_range(last_period + 1, periods=steps, freq="Y")
    return periods if isinstance(index, pd.PeriodIndex) else periods.to_timestamp()


fc_mean, fc_ci = pd.Series(dtype=float), None

if best_name == "SARIMAX" and results["SARIMAX"]["order"] is not None:
    future_index = _future_index(y.index, STEPS_AHEAD)
    full_m = SARIMAX(y, exog=X, order=results["SARIMAX"]["order"], enforce_stationarity=False, enforce_invertibility=False).fit(disp=False)
    # naive exog hold: repeat the last observed (z-scored) regressor row
    last_row = X.iloc[[-1]].values
    future_X = pd.DataFrame(np.repeat(last_row, STEPS_AHEAD, axis=0), index=future_index, columns=X.columns)
    fc = full_m.get_forecast(steps=STEPS_AHEAD, exog=future_X)
    fc_mean = pd.Series(np.asarray(fc.predicted_mean, dtype=float), index=future_index)
    # alpha=0.2 -> 80% interval; first column is the lower bound, second the upper
    fc_ci = pd.DataFrame(np.asarray(fc.conf_int(alpha=0.2), dtype=float), index=future_index, columns=["lower", "upper"])
elif best_name == "ETS":
    future_index = _future_index(y.index, STEPS_AHEAD)
    full_m = ExponentialSmoothing(y, trend="add", seasonal=None, initialization_method="estimated").fit()
    fc_mean = pd.Series(np.asarray(full_m.forecast(STEPS_AHEAD), dtype=float), index=future_index)
elif best_name == "Prophet" and HAVE_PROPHET:
    future_index = _future_index(y.index, STEPS_AHEAD)
    m = Prophet(yearly_seasonality=False, weekly_seasonality=False, daily_seasonality=False)
    # Register exactly the regressor columns present in X.
    for c in X.columns:
        m.add_regressor(c)
    df_p = pd.DataFrame({"ds": _to_timestamps(y.index), "y": y.values})
    for c in X.columns:
        df_p[c] = X[c].values
    m.fit(df_p)
    future = pd.DataFrame({"ds": _to_timestamps(future_index)})
    # naive regressors hold at last value
    last_vals = X.iloc[-1]
    for c in X.columns:
        future[c] = float(last_vals[c])
    fc_mean = pd.Series(m.predict(future)["yhat"].to_numpy(), index=future_index)
else:
    print("No model produced a usable test forecast; nothing to refit.")

plt.figure()
plt.plot(_to_timestamps(y.index), y.values, label="Historical")
if len(fc_mean) > 0:
    fc_x = _to_timestamps(fc_mean.index)
    plt.plot(fc_x, fc_mean.values, linestyle="--", marker="o", label=f"{best_name} forecast (+{STEPS_AHEAD}y)")
    if fc_ci is not None:
        plt.fill_between(fc_x, fc_ci["lower"].to_numpy(), fc_ci["upper"].to_numpy(), alpha=0.2, label="80% interval")
if best_name is None:
    plt.title(f"{FOCUS_COUNTRY} — no forecast available")
else:
    plt.title(f"{FOCUS_COUNTRY} — {best_name} {STEPS_AHEAD}-year forecast")
plt.xlabel("Year"); plt.ylabel("Birth rate (per 1,000)")
plt.legend()
plt.show()


## 9) Save outputs

In [None]:

# Write the cleaned long panel (without its index) and both wide pivots
# (with their date index) to CSV files in the working directory.
out_panel = df.reset_index(drop=True)
out_panel.to_csv("results_birth_fert_covariates_panel.csv", index=False)
for _wide_frame, _csv_path in (
    (pivot_birth, "results_birth_rate_pivot.csv"),
    (pivot_fert, "results_fertility_rate_pivot.csv"),
):
    _wide_frame.to_csv(_csv_path)
print("Saved: results_birth_fert_covariates_panel.csv, results_birth_rate_pivot.csv, results_fertility_rate_pivot.csv")
