In [1]:
%cd ..

c:\Users\esper\Desktop\injury_forecasting


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import pandas as pd

In [5]:
from src.aggregation import (
    aggregate_panel,
    check_panel_balance
)

from src.config import (
    DATA_DIR,
    RESULTS_DIR,
    FIGURES_DIR
)

from src.plotting import (
    set_plot_style
)

set_plot_style()

In [6]:
df = pd.read_csv(
    DATA_DIR / "processed" / "federal_df.csv", 
    parse_dates=['EventDate'], 
    low_memory=False)

df.shape

(96393, 27)

In [7]:
monthly_panel = aggregate_panel(
    df,
    date_col='EventDate',
    group_col='State',
    target_cols=("Hospitalized",),
    freq="MS",
    agg='sum',
    complete_panel=True
)

monthly_panel.head(), monthly_panel.shape

(     State       Date  Hospitalized
 0  ALABAMA 2015-01-01          14.0
 1  ALABAMA 2015-02-01          15.0
 2  ALABAMA 2015-03-01          22.0
 3  ALABAMA 2015-04-01          21.0
 4  ALABAMA 2015-05-01          23.0,
 (3660, 3))

In [8]:
from src.features import build_panel_features

X, y, meta = build_panel_features(
    panel_df=monthly_panel,
    target="Hospitalized",
    group_col="State",
    date_col="Date",
    freq="MS",
    add_calendar=True,
    add_lags=True,
    add_rolling=True,
    add_ewm=True,
    lags=(1, 2, 3, 6, 12),
    rolling_windows=(3, 6, 12),
    ewm_spans=(3, 6, 12),
    dropna=True,
)

X.head(), y.head(), meta.head()


(   year  month  quarter  weekofyear  Hospitalized_lag1  Hospitalized_lag2  \
 0  2016      1        1          53               12.0               18.0   
 1  2016      2        1           5               29.0               12.0   
 2  2016      3        1           9               17.0               29.0   
 3  2016      4        2          13               17.0               17.0   
 4  2016      5        2          17               21.0               17.0   
 
    Hospitalized_lag3  Hospitalized_lag6  Hospitalized_lag12  \
 0               26.0               25.0                14.0   
 1               18.0               23.0                15.0   
 2               12.0               15.0                22.0   
 3               29.0               26.0                21.0   
 4               17.0               18.0                23.0   
 
    Hospitalized_rollmean3  Hospitalized_rollmean6  Hospitalized_rollmean12  \
 0               18.666667               19.833333               

In [9]:
from src.splitting import temporal_panel_split

splits = temporal_panel_split(
    X=X,
    y=y,
    meta=meta,
    date_col='Date',
    train_end='2023-12-01',
    test_start='2024-01-01',
    test_size=12
)


In [10]:

state_col = "State"
date_col = "Date"

X_train = splits["train"]["X"]
y_train = splits["train"]["y"]
meta_train = splits["train"]["meta"]

X_test = splits["test"]["X"]
y_test = splits["test"]["y"]
meta_test = splits["test"]["meta"]

# Sanity checks
assert len(X_train) == len(y_train) == len(meta_train)
assert len(X_test) == len(y_test) == len(meta_test)
assert state_col in meta_train.columns and state_col in meta_test.columns
assert date_col in meta_train.columns and date_col in meta_test.columns

splits_by_state = {}

all_states = sorted(set(meta_train[state_col].unique()) | set(meta_test[state_col].unique()))

for st in all_states:
    tr_mask = meta_train[state_col].eq(st).to_numpy()
    te_mask = meta_test[state_col].eq(st).to_numpy()

    # Some states might be missing from train/test if data is sparse after lagging
    if tr_mask.sum() == 0 or te_mask.sum() == 0:
        continue

    splits_by_state[st] = {
        "train": {
            "X": X_train.loc[tr_mask].reset_index(drop=True),
            "y": y_train.loc[tr_mask].reset_index(drop=True),
            "meta": meta_train.loc[tr_mask].reset_index(drop=True),
        },
        "test": {
            "X": X_test.loc[te_mask].reset_index(drop=True),
            "y": y_test.loc[te_mask].reset_index(drop=True),
            "meta": meta_test.loc[te_mask].reset_index(drop=True),
        },
    }

len(splits_by_state), list(splits_by_state.keys())[:10]


(30,
 ['ALABAMA',
  'ARKANSAS',
  'COLORADO',
  'CONNECTICUT',
  'DELAWARE',
  'DISTRICT OF COLUMBIA',
  'FLORIDA',
  'GEORGIA',
  'IDAHO',
  'ILLINOIS'])

In [11]:
test_counts = pd.Series({st: len(d["test"]["y"]) for st, d in splits_by_state.items()})
test_counts.describe(), test_counts.value_counts().head()


(count    30.0
 mean     12.0
 std       0.0
 min      12.0
 25%      12.0
 50%      12.0
 75%      12.0
 max      12.0
 dtype: float64,
 12    30
 Name: count, dtype: int64)

In [12]:
from src.models import get_model_configs, instantiate_models
from sklearn.base import clone



In [13]:
import numpy as np

In [14]:
train_sizes = [
    len(s["train"]["y"])
    for s in splits_by_state.values()
]

n_samples_local = int(np.median(train_sizes))

model_configs = get_model_configs(
    n_samples=n_samples_local,
    use_linear=True,
    use_tree=True,
    random_state=0,
)

base_models = instantiate_models(model_configs)

list(base_models.keys())


['Ridge', 'Lasso', 'ElasticNet', 'PLS', 'XGBoost', 'LightGBM', 'CatBoost']

In [15]:
def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))

def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_pred - y_true)))


local_fitted = {}     # (model, state) -> fitted estimator
pred_rows = []
metric_rows = []

state_col = "State"
date_col = "Date"

for state, s in splits_by_state.items():
    Xtr, ytr, mtr = s["train"]["X"], s["train"]["y"], s["train"]["meta"]
    Xte, yte, mte = s["test"]["X"],  s["test"]["y"],  s["test"]["meta"]

    dates_tr = pd.to_datetime(mtr[date_col])
    dates_te = pd.to_datetime(mte[date_col])

    for model_name, base_model in base_models.items():
        model = clone(base_model)  # critical: fresh model per state
        model.fit(Xtr, ytr)

        local_fitted[(model_name, state)] = model

        yhat_tr = model.predict(Xtr)
        yhat_te = model.predict(Xte)

        # store predictions (train)
        pred_rows.append(pd.DataFrame({
            "model": model_name,
            "State": state,
            "Date": dates_tr.values,
            "split": "train",
            "y_true": ytr.values,
            "y_pred": yhat_tr,
        }))

        # store predictions (test)
        pred_rows.append(pd.DataFrame({
            "model": model_name,
            "State": state,
            "Date": dates_te.values,
            "split": "test",
            "y_true": yte.values,
            "y_pred": yhat_te,
        }))

        # per-state metrics (test)
        metric_rows.append({
            "model": model_name,
            "State": state,
            "rmse": rmse(yte.values, yhat_te),
            "mae": mae(yte.values, yhat_te),
            "n_test": len(yte),
        })


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

In [16]:
preds_df = pd.concat(pred_rows, ignore_index=True)

metrics_df = (
    pd.DataFrame(metric_rows)
    .sort_values(["model", "State"])
    .reset_index(drop=True)
)

overall_df = (
    preds_df[preds_df["split"] == "test"]
    .groupby("model", as_index=False)
    .apply(lambda g: pd.Series({
        "rmse": rmse(g["y_true"].values, g["y_pred"].values),
        "mae": mae(g["y_true"].values, g["y_pred"].values),
        "n_obs": len(g),
    }))
    .reset_index(drop=True)
)

overall_df.sort_values("rmse")


  .apply(lambda g: pd.Series({


Unnamed: 0,model,rmse,mae,n_obs
4,PLS,5.753669,4.077606,360.0
5,Ridge,5.802805,4.104763,360.0
1,ElasticNet,5.848323,4.142362,360.0
2,Lasso,5.855713,4.146899,360.0
0,CatBoost,6.031707,4.202812,360.0
6,XGBoost,6.205807,4.303937,360.0
3,LightGBM,6.339373,4.443232,360.0


In [17]:
from statsmodels.tsa.holtwinters import ExponentialSmoothing

ets_fitted = {}     # state -> fitted ETS model
ets_pred_rows = []
ets_metric_rows = []

date_col = "Date"

for state, s in splits_by_state.items():
    ytr = s["train"]["y"].astype(float).reset_index(drop=True)
    yte = s["test"]["y"].astype(float).reset_index(drop=True)

    dtr = pd.to_datetime(s["train"]["meta"][date_col]).reset_index(drop=True)
    dte = pd.to_datetime(s["test"]["meta"][date_col]).reset_index(drop=True)

    # Build proper time-indexed series for statsmodels
    ytr_ts = pd.Series(ytr.values, index=dtr, name="y_train").sort_index()
    yte_ts = pd.Series(yte.values, index=dte, name="y_test").sort_index()

    # --- Choose a simple ETS spec for monthly ---
    # Keep it conservative to avoid instability on short state series.
    # If you *know* you have enough data per state, you can switch on seasonality.
    use_seasonal = (len(ytr_ts) >= 36)  # heuristic: 3 years of monthly data
    seasonal = "add" if use_seasonal else None
    seasonal_periods = 12 if use_seasonal else None

    try:
        model = ExponentialSmoothing(
            ytr_ts,
            trend="add",
            seasonal=seasonal,
            seasonal_periods=seasonal_periods,
            initialization_method="estimated",
        ).fit(optimized=True)
    except Exception as e:
        # Fallback: no seasonality if seasonal fit fails
        model = ExponentialSmoothing(
            ytr_ts,
            trend="add",
            seasonal=None,
            initialization_method="estimated",
        ).fit(optimized=True)

    ets_fitted[state] = model

    # In-sample fitted values (train) aligned to train dates
    yhat_tr = model.fittedvalues.reindex(ytr_ts.index)

    # Out-of-sample forecast for the test horizon (length of yte)
    yhat_te = model.forecast(len(yte_ts))
    yhat_te.index = yte_ts.index  # align to test dates

    # Store predictions (train)
    ets_pred_rows.append(pd.DataFrame({
        "model": "ETS",
        "State": state,
        "Date": ytr_ts.index.values,
        "split": "train",
        "y_true": ytr_ts.values,
        "y_pred": yhat_tr.values,
    }))

    # Store predictions (test)
    ets_pred_rows.append(pd.DataFrame({
        "model": "ETS",
        "State": state,
        "Date": yte_ts.index.values,
        "split": "test",
        "y_true": yte_ts.values,
        "y_pred": yhat_te.values,
    }))

    # Metrics on test
    ets_metric_rows.append({
        "model": "ETS",
        "State": state,
        "rmse": rmse(yte_ts.values, yhat_te.values),
        "mae": mae(yte_ts.values, yhat_te.values),
        "n_test": len(yte_ts),
        "used_seasonal": bool(use_seasonal),
    })

ets_preds_df = pd.concat(ets_pred_rows, ignore_index=True)
ets_metrics_df = pd.DataFrame(ets_metric_rows).sort_values("State").reset_index(drop=True)

# Overall ETS pooled metrics on test
ets_overall = (
    ets_preds_df[ets_preds_df["split"] == "test"]
    .pipe(lambda g: pd.Series({
        "rmse": rmse(g["y_true"].values, g["y_pred"].values),
        "mae": mae(g["y_true"].values, g["y_pred"].values),
        "n_obs": len(g),
    }))
)

ets_overall, ets_metrics_df.head()

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


(rmse       5.243713
 mae        3.744875
 n_obs    360.000000
 dtype: float64,
   model        State      rmse       mae  n_test  used_seasonal
 0   ETS      ALABAMA  5.690450  4.069811      12           True
 1   ETS     ARKANSAS  3.860104  2.672618      12           True
 2   ETS     COLORADO  4.290615  3.413713      12           True
 3   ETS  CONNECTICUT  3.627967  3.165865      12           True
 4   ETS     DELAWARE  1.601489  1.139918      12           True)