In [None]:
"""
Reinforcement Learning for Monetary Policy Optimization
========================================================
A DQN agent trained on monthly US macroeconomic data (1971–2025) to set
interest rates, benchmarked against the Taylor Rule and actual Fed decisions.

Authors: Leonardo Luksic, Krisha Chandnani, Ignacio Orueta
LSE — February 2026

Pipeline overview:
  1. Load monthly FRED data (CPI, unemployment, capacity utilisation,
     fed funds rate, 10-year Treasury, NFCI)
  2. Engineer lagged features — realistic central-bank information lags
     (18/24/30 months) plus shorter comparison lags (3/6/12 months) — and a
     12-month-ahead inflation target
  3. Train neural-network inflation forecasters under four specifications
     (SIMPLE, EXPANDED, FULL, INFORMATIVE), compared with expanding-window
     time-series cross-validation
  4. Build a Gymnasium environment where the agent's rate choice feeds
     back into the transition dynamics
  5. Train a DQN agent (PyTorch) and compare its policy against the
     standard Taylor Rule and actual Federal Reserve decisions
  6. Generate portfolio-ready visualisations

Key methodological choices:
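  - The SIMPLE, EXPANDED and FULL specifications use only information
    realistically available to a central bank (publication lags of ~6 months
    plus a 12-month lookback); INFORMATIVE additionally uses 3/6/12-month
    lags as a less realistic benchmark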
  - The forecast horizon is 12 months ahead, matching actual policy horizons
  - Time-series CV avoids data leakage across economic regimes
  - The environment propagates the agent's rate choice into future features
    via a rolling lag buffer, so actions have real consequences
"""

import os
import argparse
import warnings
import random
from collections import deque, namedtuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

import gymnasium as gym
from gymnasium import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Suppress warnings whose message starts with "X has feature names", and all
# FutureWarnings, to keep console output readable.
warnings.filterwarnings("ignore", message="X has feature names")
warnings.filterwarnings("ignore", category=FutureWarning)


# ============================================================================
# Configuration
# ============================================================================

# Default run configuration.  main() reads its settings from a dict with these
# keys; the __main__ block at the bottom of the file builds on it and
# overrides the paths and some training options.
DEFAULT_CFG = {
    "data_path": "./data",        # directory holding the FRED CSV files
    "output_path": "./outputs",   # directory the figures are written to

    # Targets
    "inflation_target": 2.0,      # pi* in the environment reward (% YoY)
    "unemployment_natural": 5.0,  # u* in the reward; natural rate for unemp_gap

    # Environment
    "max_steps": 36,          # 3-year episodes (one step = one month)
    "n_actions": 41,          # 0–20 % in 0.5 pp increments
    "min_rate": 0.0,          # lowest rate on the action grid (%)
    "max_rate": 20.0,         # highest rate on the action grid (%)
    "omega_pi": 1.0,          # inflation-deviation weight
    "omega_u": 0.5,           # unemployment-deviation weight
    "omega_smooth": 0.1,      # rate-smoothing weight (penalises rate changes)

    # DQN agent
    "buffer_capacity": 15_000,     # replay-buffer size (oldest evicted first)
    "batch_size": 64,              # minibatch size per gradient update
    "gamma": 0.99,                 # discount factor
    "epsilon_start": 1.0,          # initial exploration probability
    "epsilon_end": 0.02,           # floor for the exploration probability
    "epsilon_decay_steps": 8_000,  # decay_epsilon() calls to go start -> end
    "lr": 3e-4,                    # Adam learning rate
    "target_update_freq": 400,     # gradient updates between target-net syncs
    "hidden_dim": 128,             # width of each QNetwork hidden layer
    "train_start_step": 500,       # env steps collected before updates begin
    "n_episodes": 600,             # number of training episodes

    # Cross-validation folds (expanding window): each fold trains on data up
    # to train_end and is scored on test_start..test_end.
    "cv_folds": [
        {"train_end": "1990-12", "test_start": "1991-01",
         "test_end": "1995-12", "name": "Early 1990s"},
        {"train_end": "1995-12", "test_start": "1996-01",
         "test_end": "2000-12", "name": "Late 1990s"},
        {"train_end": "2000-12", "test_start": "2001-01",
         "test_end": "2007-12", "name": "2000s"},
        {"train_end": "2007-12", "test_start": "2008-01",
         "test_end": "2015-12", "name": "GFC & aftermath"},
        {"train_end": "2015-12", "test_start": "2016-01",
         "test_end": "2024-12", "name": "Recent"},
    ],

    "fig_dpi": 200,   # resolution of saved figures
    "seed": 42,       # random seed
}


# ============================================================================
# 1. Data loading
# ============================================================================

def _try_read(data_path, options):
    """Read the first CSV among *options* that exists inside *data_path*.

    The file is expected in FRED export format: its ``observation_date``
    column is parsed as dates and becomes the index.

    Raises:
        FileNotFoundError: when none of the candidate filenames exist.
    """
    candidates = (os.path.join(data_path, fname) for fname in options)
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is None:
        raise FileNotFoundError(f"None of {options} found in {data_path}")
    return pd.read_csv(found, parse_dates=["observation_date"],
                       index_col="observation_date")


def load_data(data_path: str) -> pd.DataFrame:
    """
    Read the six monthly FRED series from *data_path* and join them on date.

    Output columns (in order): ``cpi``, ``unemployment``, ``fed_funds``,
    ``capacity_util``, ``treasury_10y``, ``fin_conditions``.  The weekly
    NFCI series is reduced to one value per month (the last weekly reading,
    labelled at the month start) before the join.  Months missing from any
    series are dropped.
    """
    rule = "=" * 70
    print(rule)
    print("LOADING MONTHLY FRED DATA")
    print(rule)

    # (output column, candidate filenames), in final column order
    sources = [
        ("cpi", ["CPIAUCSL.csv"]),
        ("unemployment", ["UNRATE.csv"]),
        ("fed_funds", ["FEDFUNDS.csv", "FEDFUNDS-1.csv"]),
        ("capacity_util", ["TCU.csv"]),
        ("treasury_10y", ["GS10.csv"]),
        ("fin_conditions", ["NFCI.csv"]),
    ]

    frames = []
    for column, filenames in sources:
        series = _try_read(data_path, filenames)
        if column == "fin_conditions":
            # NFCI is weekly: keep the last reading of each month
            series = series.resample("MS").last()
        frames.append(series.rename(columns={series.columns[0]: column}))

    merged = pd.concat(frames, axis=1).dropna()

    print(f"  Merged panel: {len(merged)} monthly obs  "
          f"({merged.index.min():%Y-%m} to {merged.index.max():%Y-%m})")
    return merged


# ============================================================================
# 2. Feature engineering
# ============================================================================

def engineer_features(data: pd.DataFrame,
                      natural_rate: float = 5.0) -> pd.DataFrame:
    """
    Add derived indicators, lagged copies and the forecast target.

    Derived columns: ``inflation`` (12-month % change in CPI), ``unemp_gap``
    (unemployment minus *natural_rate*), ``capacity_gap`` (capacity
    utilisation minus 100) and ``term_spread`` (10-year yield minus fed
    funds).

    Each of inflation, unemp_gap, capacity_gap, fed_funds, term_spread and
    fin_conditions is lagged at two horizons, named ``L{lag}_{var}``:

    - Realistic lags (18, 24, 30 months): approximate the information
      set available to a central bank after publication delays.
    - Intermediate lags (3, 6, 12 months): provide more recent signal
      that improves forecast quality.  Less realistic about data
      availability, but useful for comparing model performance.

    The target ``inflation_12m_ahead`` is inflation 12 months ahead.  Rows
    with any missing value (lag warm-up at the start, missing target at the
    end) are dropped.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("FEATURE ENGINEERING")
    print(rule)

    df = data.copy()

    # Derived indicators
    df["inflation"] = 100 * df["cpi"].pct_change(12)
    df["unemp_gap"] = df["unemployment"] - natural_rate
    df["capacity_gap"] = df["capacity_util"] - 100.0
    df["term_spread"] = df["treasury_10y"] - df["fed_funds"]

    # Lagged copies: grouped by variable, lags varying fastest
    lags = [3, 6, 12, 18, 24, 30]
    base_vars = ("inflation", "unemp_gap", "capacity_gap",
                 "fed_funds", "term_spread", "fin_conditions")
    lagged = {f"L{k}_{name}": df[name].shift(k)
              for name in base_vars for k in lags}
    df = df.assign(**lagged)

    # Forward target: inflation 12 months from now
    df["inflation_12m_ahead"] = df["inflation"].shift(-12)

    df = df.dropna()
    print(f"  Final dataset: {len(df)} obs  "
          f"({df.index.min():%Y-%m} to {df.index.max():%Y-%m})")
    print(f"  Lags: {lags}  |  Target: inflation at t+12")
    return df


# ============================================================================
# 3. Model specifications
# ============================================================================

# Lag horizons used by the forecaster specifications below.
_REALISTIC_LAGS = (18, 24, 30)      # ~6-month publication delay + 12-month lookback
_INTERMEDIATE_LAGS = (3, 6, 12)     # more recent signal, less realistic availability
_ALL_LAG_VARS = ("inflation", "unemp_gap", "capacity_gap",
                 "fed_funds", "term_spread", "fin_conditions")


def _lag_features(variables, lags):
    """Return ``L{lag}_{var}`` column names grouped by variable (lags vary fastest)."""
    return [f"L{lag}_{var}" for var in variables for lag in lags]


# Forecaster specifications: feature columns (order matters for the scaler)
# plus a human-readable description.
SPECS = {
    "SIMPLE": {
        "features": _lag_features(("inflation", "fed_funds"), _REALISTIC_LAGS),
        "desc": "Core variables, realistic lags only",
    },
    "EXPANDED": {
        "features": _lag_features(
            ("inflation", "unemp_gap", "capacity_gap", "fed_funds"),
            _REALISTIC_LAGS),
        "desc": "With real-economy measures, realistic lags only",
    },
    "FULL": {
        "features": _lag_features(_ALL_LAG_VARS, _REALISTIC_LAGS),
        "desc": "Full specification, realistic lags only",
    },
    "INFORMATIVE": {
        # Intermediate lags first (recent signal), then realistic lags (context)
        "features": (_lag_features(_ALL_LAG_VARS, _INTERMEDIATE_LAGS)
                     + _lag_features(_ALL_LAG_VARS, _REALISTIC_LAGS)),
        "desc": "All variables, intermediate + realistic lags (L3–L30)",
    },
}

TARGET_COL = "inflation_12m_ahead"


# ============================================================================
# 4. Time-series cross-validation
# ============================================================================

def time_series_cv(historical: pd.DataFrame, folds: list,
                   horizon: int = 12) -> tuple:
    """
    Expanding-window CV across economic regimes.

    Each fold trains on data up to `train_end` and tests on the subsequent
    window.  Because the target is inflation `horizon` months ahead, the
    last `horizon` training months have targets that are only observed
    inside the test window; they are purged so that every training label is
    known by `train_end`, avoiding look-ahead bias.  The folds check whether
    the model generalises across structurally different periods
    (e.g. Volcker disinflation, Great Moderation, GFC, post-COVID).

    Args:
        historical: output of `engineer_features` (datetime index).
        folds: dicts with keys ``train_end``, ``test_start``, ``test_end``
            and ``name``.
        horizon: forecast horizon in months used for purging; 0 disables
            purging (the pre-purge behaviour).

    Returns:
        (cv_results, best): ``cv_results[spec]`` holds the per-fold metrics
        under ``"folds"`` and their averages under ``"mse"``, ``"mae"`` and
        ``"r2"``; ``best`` is the spec with the lowest average test MSE.

    Raises:
        ValueError: if no fold could be evaluated for any specification.
    """
    print("\n" + "=" * 70)
    print("TIME-SERIES CROSS-VALIDATION")
    print("=" * 70)

    metric_keys = ("mse", "mae", "r2")
    cv_results = {}

    for spec_name, spec in SPECS.items():
        feats = spec["features"]
        print(f"\n  --- {spec_name} ({len(feats)} features) ---")
        fold_metrics = []

        for fold in folds:
            train = historical.loc[:fold["train_end"]]
            if horizon > 0 and len(train):
                # Purge rows whose t+horizon target lies after the last
                # training month (i.e. inside the test window).
                cutoff = train.index.max() - pd.DateOffset(months=horizon)
                train = train.loc[:cutoff]
            test = historical.loc[fold["test_start"]:fold["test_end"]]
            if len(test) < 10 or len(train) < 10:
                continue

            scaler = StandardScaler()
            X_tr = scaler.fit_transform(train[feats])
            X_te = scaler.transform(test[feats])
            y_tr, y_te = train[TARGET_COL], test[TARGET_COL]

            model = MLPRegressor(
                hidden_layer_sizes=(128, 64, 32), activation="relu",
                solver="adam", alpha=0.001, batch_size=32,
                learning_rate="adaptive", learning_rate_init=0.001,
                max_iter=500, early_stopping=True,
                validation_fraction=0.15, n_iter_no_change=20,
                random_state=42, verbose=False,
            )
            model.fit(X_tr, y_tr)
            y_pred = model.predict(X_te)

            mse = mean_squared_error(y_te, y_pred)
            mae = mean_absolute_error(y_te, y_pred)
            r2 = r2_score(y_te, y_pred)
            fold_metrics.append({"fold": fold["name"], "mse": mse,
                                 "mae": mae, "r2": r2})
            print(f"    {fold['name']:18s}  MSE={mse:.4f}  "
                  f"MAE={mae:.4f}  R²={r2:.4f}")

        if fold_metrics:
            avg = {k: np.mean([f[k] for f in fold_metrics])
                   for k in metric_keys}
        else:
            # No evaluable fold: mark the spec as unscored instead of
            # calling np.mean([]) (RuntimeWarning + nan).
            avg = {k: float("nan") for k in metric_keys}
        cv_results[spec_name] = {"folds": fold_metrics, **avg}
        print(f"    {'Average':18s}  MSE={avg['mse']:.4f}  "
              f"MAE={avg['mae']:.4f}  R²={avg['r2']:.4f}")

    # NaN compares False with everything, so exclude unscored specs explicitly.
    scored = [s for s in cv_results if np.isfinite(cv_results[s]["mse"])]
    if not scored:
        raise ValueError(
            "time_series_cv: no fold could be evaluated for any specification")
    best = min(scored, key=lambda s: cv_results[s]["mse"])
    print(f"\n  Best specification: {best}  "
          f"(avg CV MSE = {cv_results[best]['mse']:.4f})")
    return cv_results, best


# ============================================================================
# 5. Final model training
# ============================================================================

def train_final_model(historical, spec_name, horizon=12):
    """
    Fit the inflation forecaster for *spec_name* and report hold-out metrics.

    The model is trained on the chronologically first 80 % of rows.  When
    scoring, the first `horizon` rows after the split are skipped (an
    embargo): the last training rows have targets up to `horizon - 1` months
    past the split, so without the gap the validation score would benefit
    from overlapping future outcomes.  The embargo affects only the reported
    validation metrics, not the fitted model.

    Args:
        historical: output of `engineer_features`.
        spec_name: key into ``SPECS``.
        horizon: forecast horizon in months used for the embargo; 0 disables
            it.

    Returns:
        (model, scaler, feats): the fitted MLPRegressor, the StandardScaler
        fitted on the training rows, and the list of feature columns.

    Raises:
        ValueError: if the embargo leaves no validation rows.
    """
    print("\n" + "=" * 70)
    print(f"TRAINING FINAL MODEL: {spec_name}")
    print("=" * 70)

    feats = SPECS[spec_name]["features"]
    X, y = historical[feats], historical[TARGET_COL]
    split = int(len(X) * 0.8)
    val_start = split + max(int(horizon), 0)
    if val_start >= len(X):
        raise ValueError(
            f"train_final_model: no validation rows left after a "
            f"{horizon}-row embargo ({len(X)} rows in total)")

    scaler = StandardScaler()
    X_tr = scaler.fit_transform(X.iloc[:split])
    X_val = scaler.transform(X.iloc[val_start:])
    y_tr, y_val = y.iloc[:split], y.iloc[val_start:]

    model = MLPRegressor(
        hidden_layer_sizes=(128, 64, 32), activation="relu",
        solver="adam", alpha=0.001, batch_size=32,
        learning_rate="adaptive", learning_rate_init=0.001,
        max_iter=500, early_stopping=True,
        validation_fraction=0.15, n_iter_no_change=20,
        random_state=42, verbose=False,
    )
    model.fit(X_tr, y_tr)

    y_pred = model.predict(X_val)
    mse = mean_squared_error(y_val, y_pred)
    r2 = r2_score(y_val, y_pred)
    print(f"  Validation MSE: {mse:.4f}  |  R²: {r2:.4f}")
    return model, scaler, feats


# ============================================================================
# 6. Gymnasium environment
# ============================================================================

class MonetaryPolicyEnv(gym.Env):
    """
    Monthly monetary-policy environment.

    Observation: float32 (inflation, unemployment, capacity utilisation,
    current rate).  Action: index into ``rate_grid``, ``n_actions`` evenly
    spaced rates from ``min_rate`` to ``max_rate``.

    Each step, the chosen rate replaces the historical fed funds rate for
    the current month in a rolling buffer.  ``L{k}_fed_funds`` forecaster
    inputs are read from that buffer, so they become the agent's own choices
    once k months of the episode have elapsed (18 months for the
    realistic-lag specifications, earlier for specs with shorter lags).
    Before that they are the historical rates.  All other forecaster inputs,
    and next month's unemployment and capacity, come from the historical
    record (the agent does not control these directly).

    Reward = -(w_pi * (pi - pi*)^2 + w_u * (u - u*)^2 + w_s * (di)^2)

    NOTE(review): in main() the forecaster is trained on the 12-month-ahead
    inflation target, yet its prediction is used here as next month's
    inflation in the state and reward -- confirm this is intended.
    """

    metadata = {"render_modes": []}

    # Rates for lags 0..30 months (30 is the longest lag in SPECS).
    _RATE_BUFFER_LEN = 31

    def __init__(self, model, scaler, feature_cols, historical_df, cfg):
        super().__init__()
        self.model = model
        self.scaler = scaler
        self.feature_cols = feature_cols
        self.hist = historical_df.reset_index(drop=True)

        self.pi_target = cfg["inflation_target"]
        self.u_target = cfg["unemployment_natural"]
        self.w_pi = cfg["omega_pi"]
        self.w_u = cfg["omega_u"]
        self.w_smooth = cfg["omega_smooth"]
        self.max_steps = cfg["max_steps"]

        self.n_actions = cfg["n_actions"]
        self.rate_grid = np.linspace(cfg["min_rate"], cfg["max_rate"],
                                     self.n_actions)
        self.action_space = spaces.Discrete(self.n_actions)
        self.observation_space = spaces.Box(-np.inf, np.inf, (4,),
                                            dtype=np.float32)
        self.reset()

    @staticmethod
    def _fed_funds_lag(col):
        """Return k for an ``L{k}_fed_funds``-style column name, else None."""
        head = col.split("_", 1)[0]
        if "fed_funds" in col and head[:1] == "L" and head[1:].isdigit():
            return int(head[1:])
        return None

    def reset(self, seed=None, options=None):
        """Start a new episode at a random historical month.

        Returns:
            (state, info) per the Gymnasium API; info is empty.
        """
        # Gymnasium convention: always forward the seed to the base class
        # (it reseeds self.np_random only when a seed is given).
        super().reset(seed=seed)
        if seed is not None:
            # The start index below is drawn from NumPy's global RNG.
            np.random.seed(seed)

        min_start = 30
        max_start = len(self.hist) - self.max_steps - 12 - 1
        if max_start <= min_start:
            self.start_idx = min_start
        else:
            self.start_idx = np.random.randint(min_start, max_start)

        self.idx = self.start_idx
        self.step_count = 0
        self._done = False

        row = self.hist.iloc[self.idx]
        self.state = np.array([row["inflation"], row["unemployment"],
                               row["capacity_util"], row["fed_funds"]],
                              dtype=np.float32)
        self.prev_rate = row["fed_funds"]

        # Historical rates for months t-1, t-2, ... (clamped at row 0).  The
        # current month is deliberately excluded: step() pushes the agent's
        # rate for month t to the front, after which _rate_buffer[k] is the
        # rate for month t-k, matching the L{k}_fed_funds columns.
        n = self._RATE_BUFFER_LEN
        self._rate_buffer = deque(
            (self.hist.iloc[max(0, self.idx - k)]["fed_funds"]
             for k in range(1, n)),
            maxlen=n,
        )

        self.episode_history = []
        self.cumulative_reward = 0.0
        return self.state, {}

    def step(self, action):
        """Apply the rate at grid index *action* and advance one month.

        Returns:
            (state, reward, terminated, truncated, info); the episode
            terminates after ``max_steps`` steps and is never truncated.

        Raises:
            RuntimeError: if called after the episode has finished.
        """
        if self._done:
            raise RuntimeError("Episode finished — call reset()")

        rate = float(self.rate_grid[int(action)])
        self.step_count += 1

        # The agent's rate becomes this month's fed funds rate.
        self._rate_buffer.appendleft(rate)

        # Historical features, except fed-funds lags taken from the buffer.
        row = self.hist.iloc[self.idx]
        feat_dict = {}
        for col in self.feature_cols:
            if col not in row.index:
                # NOTE(review): unknown feature silently set to 0.0 (pre-scaling)
                feat_dict[col] = [0.0]
                continue
            lag = self._fed_funds_lag(col)
            if lag is not None and lag < len(self._rate_buffer):
                feat_dict[col] = [self._rate_buffer[lag]]
            else:
                feat_dict[col] = [row[col]]

        features_scaled = self.scaler.transform(pd.DataFrame(feat_dict))
        next_pi = float(self.model.predict(features_scaled)[0])
        next_pi = np.clip(next_pi, -2.0, 15.0)

        # Unemployment & capacity from historical record
        next_idx = min(self.idx + 1, len(self.hist) - 1)
        next_u = self.hist.iloc[next_idx]["unemployment"]
        next_cap = self.hist.iloc[next_idx]["capacity_util"]

        reward = -(self.w_pi * (next_pi - self.pi_target) ** 2
                   + self.w_u * (next_u - self.u_target) ** 2
                   + self.w_smooth * (rate - self.prev_rate) ** 2)

        self.episode_history.append({
            "inflation": next_pi, "unemployment": next_u,
            "capacity": next_cap, "rate": rate, "reward": reward})
        self.cumulative_reward += reward

        self.state = np.array([next_pi, next_u, next_cap, rate],
                              dtype=np.float32)
        self.prev_rate = rate
        self.idx = next_idx

        self._done = self.step_count >= self.max_steps
        return self.state, reward, self._done, False, {}


# ============================================================================
# 7. DQN agent (PyTorch)
# ============================================================================

# One replay-memory entry: (s, a, r, s', done).
Transition = namedtuple("Transition",
                        ["state", "action", "reward", "next_state", "done"])


class ReplayBuffer:
    """Fixed-capacity FIFO store of Transitions with uniform sampling."""

    def __init__(self, capacity):
        # deque(maxlen=...) evicts the oldest entry once `capacity` is reached
        self.buf = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buf)

    def push(self, *fields):
        """Append one transition from positional (s, a, r, s', done)."""
        entry = Transition(*fields)
        self.buf.append(entry)

    def sample(self, n):
        """Return `n` distinct stored transitions drawn with Python's `random`."""
        return random.sample(self.buf, n)


class QNetwork(nn.Module):
    """MLP with two ReLU hidden layers mapping a state to one Q-value per action."""

    def __init__(self, state_dim, action_dim, hidden=128):
        super().__init__()
        # Layers are created and registered in order fc1, fc2, fc3; these
        # attribute names define the state_dict keys.
        widths = (state_dim, hidden, hidden, action_dim)
        self.fc1, self.fc2, self.fc3 = (
            nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:]))

    def forward(self, x):
        for layer in (self.fc1, self.fc2):
            x = F.relu(layer(x))
        return self.fc3(x)


class DQNAgent:
    """
    Deep Q-Network agent.

    Uses an epsilon-greedy behaviour policy, uniform experience replay and a
    Huber (SmoothL1) TD loss with gradient-norm clipping at 10.  A target
    network is hard-synced every ``target_update_freq`` gradient updates.

    ``decay_epsilon`` multiplies epsilon by a constant factor chosen so that
    ``epsilon_decay_steps`` calls take it from ``epsilon_start`` to
    ``epsilon_end``; epsilon is floored at ``epsilon_end``.
    """

    def __init__(self, env, cfg):
        self.env = env
        n_inputs = env.observation_space.shape[0]
        n_outputs = env.action_space.n

        self.buffer = ReplayBuffer(cfg["buffer_capacity"])
        self.batch_size = cfg["batch_size"]
        self.gamma = cfg["gamma"]

        eps_start, eps_end = cfg["epsilon_start"], cfg["epsilon_end"]
        self.epsilon = eps_start
        self.eps_min = eps_end
        # Per-call factor: eps_start * eps_decay ** epsilon_decay_steps == eps_end
        self.eps_decay = (eps_end / eps_start) ** (1.0 / cfg["epsilon_decay_steps"])

        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        # Online network first, then a target copy with identical weights.
        self.policy_net = QNetwork(n_inputs, n_outputs,
                                   cfg["hidden_dim"]).to(self.device)
        self.target_net = QNetwork(n_inputs, n_outputs,
                                   cfg["hidden_dim"]).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimiser = optim.Adam(self.policy_net.parameters(), lr=cfg["lr"])
        self.loss_fn = nn.SmoothL1Loss()
        self._steps = 0                      # gradient updates performed
        self._target_freq = cfg["target_update_freq"]

    def choose_action(self, state, greedy=False):
        """Return an action index.

        Unless `greedy`, a uniformly random action is returned with
        probability epsilon; otherwise the argmax of the policy network.
        """
        explore = not greedy and np.random.random() < self.epsilon
        if explore:
            return self.env.action_space.sample()
        with torch.no_grad():
            batch = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            return self.policy_net(batch).argmax().item()

    def store(self, s, a, r, s2, done):
        """Add one transition to the replay buffer."""
        self.buffer.push(s, a, r, s2, done)

    def update(self):
        """Run one gradient step on a random minibatch.

        Returns:
            The loss as a float, or None while the buffer holds fewer than
            ``batch_size`` transitions.
        """
        if len(self.buffer) < self.batch_size:
            return None

        batch = Transition(*zip(*self.buffer.sample(self.batch_size)))
        dev = self.device
        states = torch.FloatTensor(np.array(batch.state)).to(dev)
        actions = torch.LongTensor(batch.action).unsqueeze(1).to(dev)
        rewards = torch.FloatTensor(batch.reward).unsqueeze(1).to(dev)
        next_states = torch.FloatTensor(np.array(batch.next_state)).to(dev)
        not_done = ~torch.BoolTensor(batch.done).unsqueeze(1).to(dev)

        q_taken = self.policy_net(states).gather(1, actions)
        with torch.no_grad():
            q_next = self.target_net(next_states).max(dim=1).values.unsqueeze(1)
            td_target = rewards + self.gamma * q_next * not_done

        loss = self.loss_fn(q_taken, td_target)
        self.optimiser.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10.0)
        self.optimiser.step()

        self._steps += 1
        sync_due = self._steps % self._target_freq == 0
        if sync_due:
            self.target_net.load_state_dict(self.policy_net.state_dict())
        return loss.item()

    def decay_epsilon(self):
        """Shrink epsilon geometrically, never below ``eps_min``."""
        self.epsilon = max(self.eps_min, self.epsilon * self.eps_decay)


# ============================================================================
# 8. Training loop
# ============================================================================

def train_dqn(env, agent, cfg):
    """
    Train *agent* on *env* for ``cfg["n_episodes"]`` episodes.

    Gradient updates start once more than ``cfg["train_start_step"]``
    environment steps have been taken.  Epsilon is decayed once per
    environment step, not once per episode.  `DQNAgent` sizes its decay
    factor so that ``epsilon_decay_steps`` calls reach ``epsilon_end``.
    Decaying per episode left epsilon around 0.75 after 600 episodes, so
    the agent stayed mostly random.

    Returns:
        (ep_rewards, ep_losses): per-episode cumulative reward and mean TD
        loss (0.0 for episodes without any gradient update).
    """
    n_ep = cfg["n_episodes"]
    warmup = cfg["train_start_step"]

    print("\n" + "=" * 70)
    print("TRAINING DQN AGENT")
    print("=" * 70)
    print(f"  Episodes: {n_ep}  |  Steps/ep: {env.max_steps}  "
          f"|  Warmup: {warmup}")

    ep_rewards, ep_losses = [], []
    total_steps = 0

    for ep in range(n_ep):
        s, _ = env.reset()
        ep_r, ep_l, n_upd = 0.0, 0.0, 0

        for _ in range(env.max_steps):
            a = agent.choose_action(s)
            s2, r, terminated, truncated, _ = env.step(a)
            # Only true termination stops bootstrapping in the TD target.
            agent.store(s, a, r, s2, terminated)

            if total_steps > warmup:
                loss = agent.update()
                if loss is not None:
                    ep_l += loss
                    n_upd += 1

            agent.decay_epsilon()   # per-step schedule
            s = s2
            ep_r += r
            total_steps += 1
            if terminated or truncated:
                break

        ep_rewards.append(ep_r)
        ep_losses.append(ep_l / max(n_upd, 1))

        if (ep + 1) % 50 == 0:
            avg = np.mean(ep_rewards[-50:])
            print(f"  ep {ep+1:4d}  |  avg reward {avg:8.2f}  |  "
                  f"eps {agent.epsilon:.4f}")

    print(f"\n  Final 50-episode avg reward: "
          f"{np.mean(ep_rewards[-50:]):.2f}")
    return ep_rewards, ep_losses


# ============================================================================
# 9. Taylor Rule baseline
# ============================================================================

def taylor_rule(pi, u, r_star=2.0, pi_star=2.0, u_star=5.0,
                a_pi=1.5, a_u=0.5):
    """
    Taylor-type policy rule with an unemployment gap.

        i = r* + pi* + a_pi * (pi - pi*) - a_u * (u - u*)

    The result is clipped to [0, 20] %.  `a_pi` is the total response of
    the nominal rate to inflation.  The default 1.5 equals Taylor's (1993)
    ``r* + pi + 0.5 (pi - pi*)`` and satisfies the Taylor principle.  The
    unemployment gap enters with a negative sign because higher unemployment
    (a negative output gap under Okun's law) calls for lower rates.

    Works element-wise on NumPy arrays / pandas Series as well as scalars.
    """
    rate = r_star + pi_star + a_pi * (pi - pi_star) - a_u * (u - u_star)
    return np.clip(rate, 0, 20)


# ============================================================================
# 10. Historical policy generation
# ============================================================================

def generate_historical_policies(historical, agent, env, warmup=30):
    """
    Record the rate each policy would have set in every historical month.

    Returns a copy of `historical` with two extra columns:

    - ``taylor_rate``: `taylor_rule` (default parameters) applied to that
      month's inflation and unemployment.
    - ``dqn_rate``: the agent's greedy action, mapped through
      ``env.rate_grid``, for the state (inflation, unemployment,
      capacity_util, fed_funds) of that month.  The first `warmup` rows are
      NaN, matching the environment's earliest episode start.

    The agent is queried with ``greedy=True``, which ignores epsilon, so the
    agent's exploration state is never modified.
    """
    df = historical.copy()
    df["taylor_rate"] = taylor_rule(df["inflation"].to_numpy(),
                                    df["unemployment"].to_numpy())

    states = df[["inflation", "unemployment", "capacity_util",
                 "fed_funds"]].to_numpy(dtype=np.float32)
    dqn = np.full(len(df), np.nan)
    for i in range(max(int(warmup), 0), len(df)):
        action = agent.choose_action(states[i], greedy=True)
        dqn[i] = env.rate_grid[action]
    df["dqn_rate"] = dqn
    return df


# ============================================================================
# 11. Visualisation
# ============================================================================

# Palette
C1, C2, C3 = "#2563eb", "#dc2626", "#7c3aed"   # blue, red, purple
C_BG = "#fafafa"

# Global matplotlib style for every figure produced below.
plt.rcParams.update({
    "figure.facecolor": C_BG, "axes.facecolor": C_BG,
    "axes.grid": True, "grid.color": "#e5e7eb", "grid.linewidth": 0.5,
    "font.size": 11, "axes.spines.top": False, "axes.spines.right": False,
})

# NBER US business-cycle contractions as (peak month, trough month), used
# for shading.  1980 and 1981–82 are separate recessions; shading them as one
# span would mark the 1980–81 expansion as a recession.
RECESSIONS = [
    ("1973-11", "1975-03"),
    ("1980-01", "1980-07"), ("1981-07", "1982-11"),
    ("1990-07", "1991-03"),
    ("2001-03", "2001-11"), ("2007-12", "2009-06"),
    ("2020-02", "2020-04"),
]


def plot_data_overview(hist, save=None, dpi=200):
    """
    Six-panel overview of the raw economic indicators.

    The year range in the title is derived from ``hist.index`` rather than
    hard-coded, so it stays correct when the sample changes.

    Args:
        hist: frame with columns inflation, unemployment, capacity_util,
            fed_funds, treasury_10y, term_spread and fin_conditions, indexed
            by date.
        save: optional output path for the figure.
        dpi: resolution used when saving.

    Returns:
        The matplotlib Figure.
    """
    if len(hist):
        period = f"Monthly, {hist.index.min():%Y}–{hist.index.max():%Y}"
    else:
        period = "Monthly"

    fig, axes = plt.subplots(3, 2, figsize=(15, 11))
    fig.suptitle(f"US Economic Indicators ({period})",
                 fontweight="bold", fontsize=14, y=0.995)

    panels = [
        ("inflation", "Inflation (YoY %)", C1, {"hline": 2.0, "hl": "Target"}),
        ("unemployment", "Unemployment rate (%)", C2, {"hline": 5.0, "hl": "Natural rate"}),
        ("capacity_util", "Capacity utilisation (%)", "#c2410c", {"hline": 100, "hl": "Full capacity"}),
        (["fed_funds", "treasury_10y"], "Interest rates (%)", None, {}),
        ("term_spread", "Term spread (10Y − FFR, pp)", "#7c3aed", {"hline": 0}),
        ("fin_conditions", "Financial Conditions Index", "#c2410c", {"hline": 0, "hl": "Neutral"}),
    ]

    for ax, (col, title, color, opts) in zip(axes.flat, panels):
        if isinstance(col, list):
            ax.plot(hist.index, hist[col[0]], lw=1.5, color=C1, label="Fed funds")
            ax.plot(hist.index, hist[col[1]], lw=1.5, color=C2,
                    alpha=0.7, label="10Y Treasury")
            ax.legend(fontsize=9, frameon=False)
        else:
            ax.plot(hist.index, hist[col], lw=1.5, color=color)
        if "hline" in opts:
            ax.axhline(opts["hline"], color="grey", ls="--", lw=1, alpha=0.6,
                       label=opts.get("hl"))
            if "hl" in opts:
                ax.legend(fontsize=9, frameon=False)
        ax.set_title(title, fontweight="bold", fontsize=11)

    fig.tight_layout()
    if save:
        fig.savefig(save, dpi=dpi, bbox_inches="tight")
    return fig


def plot_cv_results(cv_results, save=None, dpi=200):
    """
    Grouped bar chart of test MSE: one group per CV fold, one bar per spec.

    Fold labels come from the first specification; every spec is assumed to
    have been scored on the same folds in the same order.
    """
    spec_names = list(cv_results.keys())
    fold_names = [entry["fold"] for entry in cv_results[spec_names[0]]["folds"]]

    fig, ax = plt.subplots(figsize=(13, 5.5))
    centres = np.arange(len(fold_names))
    bar_w = 0.8 / len(spec_names)
    palette = [C1, C2, C3, "#f59e0b"]  # blue, red, purple, amber

    for k, name in enumerate(spec_names):
        heights = [entry["mse"] for entry in cv_results[name]["folds"]]
        ax.bar(centres + k * bar_w, heights, bar_w, label=name,
               color=palette[k % len(palette)], alpha=0.7,
               edgecolor="white", lw=1)

    # Centre each tick under its group of bars
    ax.set_xticks(centres + bar_w * (len(spec_names) - 1) / 2)
    ax.set_xticklabels(fold_names, fontsize=10)
    ax.set_ylabel("Test MSE")
    ax.set_title("Time-series CV: forecast error by specification and period",
                 fontweight="bold")
    ax.legend(frameon=False, fontsize=9)
    fig.tight_layout()
    if save:
        fig.savefig(save, dpi=dpi, bbox_inches="tight")
    return fig


def plot_training(rewards, losses, save=None, dpi=200, window=30):
    """
    Training reward and loss curves.

    Args:
        rewards: per-episode cumulative rewards.
        losses: per-episode mean TD loss.  Episodes with a zero or falsy
            loss (e.g. before gradient updates start) are left out of the
            loss curve.
        save: optional output path for the figure.
        dpi: resolution used when saving.
        window: moving-average window in episodes (values below 1 become 1).

    Returns:
        The matplotlib Figure.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 7.5))
    w = max(int(window), 1)
    kernel = np.ones(w) / w

    ax1.plot(rewards, alpha=0.2, color=C1, lw=0.8)
    if len(rewards) >= w:
        sm = np.convolve(rewards, kernel, mode="valid")
        ax1.plot(range(w - 1, len(rewards)), sm, color=C1, lw=2.5,
                 label=f"{w}-episode avg")
        # Legend only when a labelled line exists (avoids an empty legend).
        ax1.legend(frameon=False)
    ax1.set_ylabel("Episode reward")
    ax1.set_title("Training: cumulative reward per episode",
                  fontweight="bold")

    valid = [(i, loss) for i, loss in enumerate(losses) if loss and loss > 0]
    if len(valid) >= w:
        ix, vals = zip(*valid)
        sm = np.convolve(vals, kernel, mode="valid")
        # sm[k] averages vals[k:k + w]; plot it at the episode of its last
        # sample so gaps in `valid` do not shift the curve.
        ax2.plot(ix[w - 1:], sm, color=C2, lw=2.5)
    ax2.set_ylabel("TD loss")
    ax2.set_xlabel("Episode")
    ax2.set_title("Training: Huber loss", fontweight="bold")

    fig.tight_layout(h_pad=2.5)
    if save:
        fig.savefig(save, dpi=dpi, bbox_inches="tight")
    return fig


def plot_policy_comparison(df, save=None, dpi=200):
    """
    Historical interest-rate comparison + deviation panel.

    The top panel shows the actual fed funds rate, the Taylor-rule rate and
    the DQN rate, with RECESSIONS shaded.  Spans are clipped to the sample so
    that out-of-range recessions do not widen the x-axis.  The bottom panel
    shows each policy's deviation from the actual rate.  The year range in
    the title is derived from ``df.index``.

    Args:
        df: frame with fed_funds, taylor_rate and dqn_rate columns, indexed
            by date.
        save: optional output path for the figure.
        dpi: resolution used when saving.

    Returns:
        The matplotlib Figure.
    """
    if len(df):
        lo, hi = df.index.min(), df.index.max()
        years = f" ({lo:%Y}–{hi:%Y})"
    else:
        lo = hi = None
        years = ""

    fig, axes = plt.subplots(2, 1, figsize=(16, 9), sharex=True)
    fig.suptitle("Monetary policy comparison: DQN vs Taylor Rule vs "
                 f"Federal Reserve{years}",
                 fontweight="bold", fontsize=13, y=0.995)

    # Rates
    axes[0].plot(df.index, df["fed_funds"], lw=2, color=C1,
                 label="Federal Reserve (actual)", alpha=0.9)
    axes[0].plot(df.index, df["taylor_rate"], lw=1.8, color=C2,
                 ls="--", label="Taylor Rule", alpha=0.8)
    axes[0].plot(df.index, df["dqn_rate"], lw=2, color=C3,
                 ls=":", label="DQN agent", alpha=0.85)
    if lo is not None:
        for start, end in RECESSIONS:
            s, e = pd.Timestamp(start), pd.Timestamp(end)
            if e < lo or s > hi:
                continue  # entirely outside the sample
            axes[0].axvspan(max(s, lo), min(e, hi), alpha=0.12,
                            color="#fca5a5", zorder=0)
    axes[0].set_ylabel("Nominal interest rate (%)")
    axes[0].set_ylim(-1, 22)
    axes[0].legend(loc="upper right", frameon=True, fontsize=10)
    axes[0].set_title("Interest-rate policies", fontweight="bold")

    # Deviations
    t_dev = df["taylor_rate"] - df["fed_funds"]
    d_dev = df["dqn_rate"] - df["fed_funds"]
    axes[1].fill_between(df.index, t_dev, 0, alpha=0.15, color=C2)
    axes[1].fill_between(df.index, d_dev, 0, alpha=0.15, color=C3)
    axes[1].plot(df.index, t_dev, lw=1.5, color=C2,
                 label="Taylor deviation", alpha=0.8)
    axes[1].plot(df.index, d_dev, lw=1.5, color=C3,
                 label="DQN deviation", alpha=0.8)
    axes[1].axhline(0, color=C1, lw=1, alpha=0.4)
    axes[1].set_ylabel("Deviation from Fed rate (pp)")
    axes[1].set_xlabel("Year")
    axes[1].legend(loc="upper right", frameon=True, fontsize=10)
    axes[1].set_title("Policy deviations from actual Federal Reserve decisions",
                      fontweight="bold")

    fig.tight_layout()
    if save:
        fig.savefig(save, dpi=dpi, bbox_inches="tight")
    return fig


def compute_deviation_metrics(df, common_sample=True):
    """
    Compute and print MAD / RMSE of each policy vs the Fed.

    By default both policies are scored on the same months: those where
    ``fed_funds``, ``taylor_rate`` and ``dqn_rate`` are all present.  This
    keeps the comparison like-for-like, since ``dqn_rate`` is NaN for the
    first rows produced by `generate_historical_policies`.  With
    ``common_sample=False``, each policy is scored on every month where its
    own deviation is defined.

    Returns:
        dict mapping policy name ("Taylor Rule", "DQN Agent") to
        ``{"mad": float, "rmse": float}`` in percentage points.
    """
    print("\n" + "=" * 70)
    print("DEVIATION FROM ACTUAL FED DECISIONS")
    print("=" * 70)

    policies = {"Taylor Rule": "taylor_rate", "DQN Agent": "dqn_rate"}
    sample = df
    if common_sample:
        sample = df.dropna(subset=["fed_funds", *policies.values()])

    metrics = {}
    for name, col in policies.items():
        dev = (sample[col] - sample["fed_funds"]).dropna()
        mad = float(np.abs(dev).mean())
        rmse = float(np.sqrt((dev ** 2).mean()))
        metrics[name] = {"mad": mad, "rmse": rmse}
        print(f"  {name:15s}  MAD = {mad:.3f} pp  |  RMSE = {rmse:.3f} pp")
    return metrics


# ============================================================================
# 12. Main
# ============================================================================

def main(cfg=None):
    """
    Run the full pipeline.

    The steps are: load data, engineer features, cross-validate the
    forecaster specifications, fit the final forecaster, train the DQN
    agent, then write the figures and print deviation metrics.

    Args:
        cfg: configuration dict.  Keys it does not provide fall back to
            DEFAULT_CFG, so partial overrides (e.g. only the paths) work.
            If ``cfg["seed"]`` is not None it seeds Python's `random`,
            NumPy's global RNG, PyTorch, and the environment and its action
            space, to make runs more repeatable.

    Returns:
        dict with the policy-comparison frame ("historical"), the
        environment and trained agent, the CV results, the best spec name
        and the per-episode training rewards.
    """
    cfg = {**DEFAULT_CFG, **(cfg or {})}
    out = cfg["output_path"]
    dpi = cfg["fig_dpi"]
    os.makedirs(out, exist_ok=True)

    seed = cfg.get("seed")
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # --- Data ---
    data = load_data(cfg["data_path"])
    historical = engineer_features(data, cfg["unemployment_natural"])

    # --- Descriptive plots ---
    plot_data_overview(historical,
                       save=os.path.join(out, "data_overview.png"), dpi=dpi)

    # --- Cross-validation ---
    cv_results, best_spec = time_series_cv(historical, cfg["cv_folds"])
    plot_cv_results(cv_results,
                    save=os.path.join(out, "cv_results.png"), dpi=dpi)

    # --- Final model ---
    model, scaler, feats = train_final_model(historical, best_spec)

    # --- Environment + agent ---
    env = MonetaryPolicyEnv(model, scaler, feats, historical, cfg)
    if seed is not None:
        env.reset(seed=seed)
        env.action_space.seed(seed)   # used for epsilon-greedy exploration
    agent = DQNAgent(env, cfg)

    # --- Train ---
    ep_rewards, ep_losses = train_dqn(env, agent, cfg)
    plot_training(ep_rewards, ep_losses,
                  save=os.path.join(out, "training_curves.png"), dpi=dpi)

    # --- Historical policy comparison ---
    df = generate_historical_policies(historical, agent, env)
    plot_policy_comparison(df,
                           save=os.path.join(out, "policy_comparison.png"),
                           dpi=dpi)
    compute_deviation_metrics(df)

    plt.close("all")

    print("\n" + "=" * 70)
    print(f"All outputs saved to {out}/")
    print("=" * 70)

    return {
        "historical": df, "env": env, "agent": agent,
        "cv_results": cv_results, "best_spec": best_spec,
        "episode_rewards": ep_rewards,
    }


# ============================================================================
# Entry point
# ============================================================================

def _in_notebook():
    """Return True when an IPython shell is active (e.g. a Jupyter kernel).

    Returns False if IPython is not installed or no shell is running.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        return False
    shell = get_ipython()
    return shell is not None


if __name__ == "__main__":
    if _in_notebook():
        # Running in Jupyter — edit paths here directly.
        # NOTE: these paths are specific to the author's machine.
        cfg = {
            **DEFAULT_CFG,
            "data_path": "/Users/leoss/Downloads",
            "output_path": "/Users/leoss/Desktop/Portfolio/Website-/Central bank/Outputs",
        }
        results = main(cfg)
    else:
        # Running from terminal — use CLI arguments
        parser = argparse.ArgumentParser(
            description="RL for Monetary Policy Optimisation")
        parser.add_argument("--data", default=DEFAULT_CFG["data_path"],
                            help="directory containing the FRED CSV files")
        parser.add_argument("--output", default=DEFAULT_CFG["output_path"],
                            help="directory the figures are written to")
        parser.add_argument("--episodes", type=int,
                            default=DEFAULT_CFG["n_episodes"],
                            help="number of DQN training episodes")
        parser.add_argument("--seed", type=int,
                            default=DEFAULT_CFG["seed"],
                            help="random seed (stored as cfg['seed'])")
        args = parser.parse_args()

        cfg = {**DEFAULT_CFG, "data_path": args.data,
               "output_path": args.output, "n_episodes": args.episodes,
               "seed": args.seed}
        main(cfg)

SyntaxError: invalid character '–' (U+2013) (3491157562.py, line 6)