# In [None]:
# run.ipynb - execution/backtest using trained model + MCTS
# Output: Sharpe ratio only

import numpy as np
import torch

from data import MarketFeatureBuilder
from spaces import PortfolioEnv
from models import ModelConfig, AlphaZeroPortfolioModel
from search import MCTS, MCTSConfig


class CandidateActionGenerator:
    """
    Map a discrete action_id to a vector of target portfolio weights.

    Action ids 0-3 are fixed base actions computed from the environment's
    current weights:

        0  hold          -> copy of the current weights
        1  equal-weight  -> 1/N for every asset
        2  de-risk       -> 0.8 * current weights
        3  re-risk       -> 1.2 * current weights

    Every other id yields a momentum tilt: assets are ranked by their
    log-price change over ``lookback`` steps, a per-asset amount is removed
    from the lowest-ranked group (floored at zero) and the same per-asset
    amount is added to the highest-ranked group.

    The returned vector is a raw proposal; nothing here renormalizes it.
    The env is expected to enforce constraints/projection/turnover/costs.

    Args:
        env: Environment exposing ``weights``, ``N``, ``t`` and ``logp``
            (``logp`` is indexed positionally with ``.iloc``).
        num_actions: Size of the discrete action space. Stored for callers;
            not consulted when generating weights.
        max_tilt: Upper bound on the total weight moved by a tilt action.
        top_k: Desired size of the bottom and top momentum groups. It is
            capped at ``N // 2`` so the two groups never overlap.
        seed: Seed of the private random generator that sizes each tilt.
        lookback: Number of steps used to measure momentum.

    Raises:
        ValueError: If ``lookback`` is smaller than 1.
    """

    def __init__(
        self,
        env: "PortfolioEnv",
        num_actions: int,
        max_tilt: float = 0.02,
        top_k: int = 5,
        seed: int = 7,
        lookback: int = 20,
    ):
        if lookback < 1:
            raise ValueError("lookback must be a positive number of steps")
        self.env = env
        self.num_actions = num_actions
        self.max_tilt = max_tilt
        self.top_k = top_k
        self.lookback = lookback
        self.rng = np.random.default_rng(seed)

    def __call__(self, action_id: int) -> np.ndarray:
        """Return the target weights proposed by ``action_id``.

        The environment's own weight array is never modified; all results
        are fresh arrays. Tilt sizes are drawn from the generator's RNG on
        every call, so the same tilt id can map to different weights on
        successive calls.
        """
        w = self.env.weights.copy()
        n_assets = self.env.N

        # Base actions
        if action_id == 0:  # hold
            return w
        if action_id == 1:  # equal-weight
            return np.ones(n_assets) / n_assets
        if action_id == 2:  # de-risk
            return 0.8 * w
        if action_id == 3:  # re-risk
            return 1.2 * w

        # Momentum tilt (uses env prices). Without enough history every
        # asset gets zero momentum.
        t = self.env.t
        if t - self.lookback < 0:
            mom = np.zeros(n_assets)
        else:
            logp = self.env.logp
            mom = logp.iloc[t].to_numpy() - logp.iloc[t - self.lookback].to_numpy()

        # Draw the tilt size before any early exit so the random stream
        # advances exactly once per tilt request.
        tilt = float(self.max_tilt * self.rng.uniform(0.25, 1.0))

        idx = np.argsort(mom)

        # Cap the group size at half the universe so the bottom and top
        # groups stay disjoint, and never let it go negative.
        k = min(max(int(self.top_k), 0), len(idx) // 2)
        if k == 0:
            # No groups to tilt between. (A bare ``idx[-0:]`` would select
            # every asset, so this case must be handled explicitly.)
            return w

        bottom = idx[:k]
        top = idx[-k:]  # safe: k >= 1 here

        per_asset = tilt / k
        # The floor at zero means less than ``per_asset`` may actually be
        # removed from a bottom asset, while the top group still receives
        # the full amount; the env's projection absorbs any excess.
        w[bottom] = np.maximum(0.0, w[bottom] - per_asset)
        w[top] += per_asset
        return w


def annualized_sharpe(returns: np.ndarray, periods_per_year: int = 252, eps: float = 1e-12) -> float:
    """Return the annualized Sharpe ratio of a series of per-period returns.

    The ratio is ``mean / sample_std * sqrt(periods_per_year)`` computed on
    the finite entries only (NaN and +/-inf are dropped). No risk-free rate
    is subtracted.

    Args:
        returns: Per-period returns. Any array-like is accepted (ndarray,
            list, tuple, ...); multi-dimensional input is flattened.
        periods_per_year: Number of return periods in a year.
        eps: Standard deviations below this threshold are treated as zero.

    Returns:
        The annualized Sharpe ratio as a builtin ``float``, or ``0.0`` when
        fewer than two finite returns remain or their sample standard
        deviation is below ``eps``.
    """
    # Coerce first so plain sequences can be boolean-masked like arrays.
    clean = np.asarray(returns, dtype=float).ravel()
    clean = clean[np.isfinite(clean)]
    if clean.size < 2:
        # The sample standard deviation (ddof=1) needs at least two points.
        return 0.0
    mu = float(np.mean(clean))
    sd = float(np.std(clean, ddof=1))
    if sd < eps:
        return 0.0
    return float((mu / sd) * np.sqrt(periods_per_year))


def run_backtest_sharpe_only(
    prices_df,
    volumes_df,
    checkpoint_path: str = "checkpoint.pt",
    initial_value: float = 2_000_000.0,
    mcts_sims: int = 120,
    mcts_depth: int = 3,
):
    """Backtest the trained model with MCTS planning and print the Sharpe ratio.

    Builds market features from the price/volume data, creates a single
    ``PortfolioEnv`` trajectory, restores the model from ``checkpoint_path``
    and then alternates ``mcts.run()`` / ``env.step()`` until the env reports
    ``done`` or ``env.t`` reaches ``env.T - 1``. The per-step rewards are
    collected and their annualized Sharpe ratio is printed; that printed
    number is the function's only output (it returns ``None``).

    Args:
        prices_df: Price data passed to ``MarketFeatureBuilder`` and
            ``PortfolioEnv``.
        volumes_df: Volume data passed to ``MarketFeatureBuilder``.
        checkpoint_path: File read with ``torch.load``. It must contain a
            ``"model_state"`` entry and may contain a ``"model_cfg"`` dict of
            ``ModelConfig`` keyword arguments.
        initial_value: Starting portfolio value handed to the environment.
        mcts_sims: Number of simulations per MCTS decision.
        mcts_depth: Maximum search depth per MCTS decision.

    Raises:
        KeyError: If the checkpoint has no ``"model_state"`` entry.
    """
    # 1) Compute market features once for the full dataset
    builder = MarketFeatureBuilder(prices_df, volumes_df)
    market_features_df = builder.batch_features()

    # 2) Create environment and reset ONCE (single continuous trajectory)
    env = PortfolioEnv(prices_df, market_features_df, initial_value=initial_value)
    env.reset()

    # 3) Load trained model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider passing weights_only=True if the
    # checkpoint holds nothing but tensors and plain containers.
    ckpt = torch.load(checkpoint_path, map_location=device)
    cfg_dict = ckpt.get("model_cfg", {})
    # Fall back to the default configuration when none was saved.
    cfg = ModelConfig(**cfg_dict) if cfg_dict else ModelConfig()

    model = AlphaZeroPortfolioModel(cfg).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # 4) Build action generator and MCTS planner
    action_gen = CandidateActionGenerator(env, num_actions=cfg.num_actions)
    mcts = MCTS(
        model=model,
        env=env,
        action_generator=action_gen,
        cfg=MCTSConfig(num_simulations=mcts_sims, max_depth=mcts_depth),
    )

    # 5) Iterate through ALL timestamps, collecting realized returns.
    #    This is pure inference, so autograd is switched off for the whole
    #    loop to avoid building graphs for every network evaluation.
    realized_returns = []
    done = False
    with torch.no_grad():
        while not done and env.t < env.T - 1:
            w_target = mcts.run()                    # choose action at time t
            _, reward, done, _ = env.step(w_target)  # realize return for t->t+1
            realized_returns.append(float(reward))

    # 6) Compute Sharpe on ALL realized returns and print only that
    sharpe = annualized_sharpe(np.array(realized_returns, dtype=float))
    print(sharpe)


# Usage:
# run_backtest_sharpe_only(prices_df, volumes_df, checkpoint_path="checkpoint.pt")
