In [8]:
# robustness_tests.py – drop‑in experiments for your TJR back‑tester
# ---------------------------------------------------------------------
# Assumes the original helper functions, data loaders, and the `backtest` logic
# you pasted earlier live in the PYTHONPATH (helpers.config, data.preload_history,
# tjr_long_signal / tjr_short_signal, etc.).
#
# You only need to:
#   >>> from robustness_tests import run_all_experiments
#   >>> results = await run_all_experiments(hist_df)
#   >>> results["fee_slip"].summary   # or however you want to inspect
# ---------------------------------------------------------------------

from __future__ import annotations

import asyncio
import itertools
import logging
import math
import random
from dataclasses import dataclass, asdict
from typing import Callable, Dict, List, Tuple

import numpy as np
import pandas as pd

import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from helpers import config, build_htf_levels, tjr_long_signal, tjr_short_signal
from data import preload_history

# ------------------------------------------------------------------
# 1 | Lightweight position object & utilities
# ------------------------------------------------------------------
@dataclass
class Trade:
    dir: int                # +1 long / ‑1 short
    entry: float
    exit: float
    pnl: float
    risk: float
    time_entry: pd.Timestamp
    time_exit: pd.Timestamp

    def as_dict(self):
        return asdict(self)


# ------------------------------------------------------------------
# 2 | Core back‑test engine with **dependency injection** so we can
#       tweak behaviour without copy‑pasting 300 lines again.
# ------------------------------------------------------------------

def run_backtest(
    price: pd.DataFrame,
    *,
    equity0: float = 1_000.0,
    risk_pct: float = config.RISK_PCT,
    fee_bps: float = 0.0,           # 0.10 % round‑trip → fee_bps=10
    slip_bps: float = 0.0,          # 0.05 % slippage each side → slip_bps=5
    lag_bars: int = 0,              # set to 1 for the lag‑1 experiment
    signal_fn_long: Callable = tjr_long_signal,
    signal_fn_short: Callable = tjr_short_signal,
) -> Tuple[List[Trade], List[float]]:
    """Single‑pass back‑test that returns the list of Trade objects and the
        equity curve (sampled **every bar**). All heavy‑lifting (ATR, EMA, etc.)
        still delegated to the helper functions you already use.
    """

    htf = build_htf_levels(price)
    equity   = equity0
    trades: List[Trade] = []
    curve    = []
    pos: Dict | None = None
    pending: Dict | None = None    # holds a signal that will be executed next bar when lag_bars>0

    for i, (ts, bar) in enumerate(price.iterrows()):
        # ------------------------------------------------ execute pending entry
        if pending and pending["activate_at"] <= i:
            pos = pending["pos"]
            pending = None

        # ------------------------------------------------ manage open position
        if pos:
            hit_sl = (pos["dir"] == 1 and bar.l <= pos["sl"]) or (pos["dir"] == -1 and bar.h >= pos["sl"])
            hit_tp = (pos["dir"] == 1 and bar.h >= pos["tp"]) or (pos["dir"] == -1 and bar.l <= pos["tp"])
            if hit_sl or hit_tp:
                exit_price = pos["sl"] if hit_sl else pos["tp"]
                trade_fee   = (fee_bps + slip_bps) * 1e-4 * exit_price   # exit side
                pnl_raw     = (pos["risk"] * (config.ATR_MULT_TP / config.ATR_MULT_SL)) if hit_tp else -pos["risk"]
                pnl         = pnl_raw - 2 * trade_fee  # entry + exit cost (round‑trip)
                equity     += pnl
                trades.append(
                    Trade(
                        dir         = pos["dir"],
                        entry       = pos["entry"],
                        exit        = exit_price,
                        pnl         = pnl,
                        risk        = pos["risk"],
                        time_entry  = pos["time_entry"],
                        time_exit   = ts,
                    )
                )
                pos = None

        # ------------------------------------------------ look for new signals
        if pos is None and pending is None:
            htf_row = htf.loc[ts]
            if signal_fn_long(price, i, htf_row):
                stop  = config.ATR_MULT_SL * bar.atr * 1.6
                new_pos = dict(
                    dir   = 1,
                    entry = bar.c,
                    sl    = bar.c - stop - config.WICK_BUFFER * bar.atr,
                    tp    = bar.c + config.ATR_MULT_TP * bar.atr,
                    risk  = equity * risk_pct,
                    time_entry=ts,
                )
            elif signal_fn_short(price, i, htf_row):
                stop  = config.ATR_MULT_SL * bar.atr * 1.6
                new_pos = dict(
                    dir   = -1,
                    entry = bar.c,
                    sl    = bar.c + stop + config.WICK_BUFFER * bar.atr,
                    tp    = bar.c - config.ATR_MULT_TP * bar.atr,
                    risk  = equity * risk_pct,
                    time_entry=ts,
                )
            else:
                new_pos = None

            if new_pos:
                if lag_bars:
                    pending = {"activate_at": i + lag_bars, "pos": new_pos}
                else:
                    pos = new_pos

        curve.append(equity)

    return trades, curve


# ------------------------------------------------------------------
# 3 | Performance evaluation (re‑uses your earlier maths)
# ------------------------------------------------------------------

def summarise(trades: List[Trade], curve: List[float], bar_index: pd.DatetimeIndex, start_equity: float) -> pd.DataFrame:
    from math import sqrt

    wins   = [t for t in trades if t.pnl > 0]
    losses = [t for t in trades if t.pnl < 0]

    p_win  = len(wins) / len(trades)
    p_loss = 1 - p_win

    mean_win_r  = np.mean([t.pnl / t.risk for t in wins]) if wins   else 0
    mean_loss_r = np.mean([t.pnl / t.risk for t in losses]) if losses else 0

    expectancy_r = mean_win_r * p_win + mean_loss_r * p_loss

    equity  = pd.Series(curve, index=bar_index)
    dd      = equity.cummax() - equity
    max_dd  = dd.max()
    max_dd_pct = (dd / equity.cummax()).max()

    ulcer = math.sqrt(np.mean(((dd / equity.cummax()) ** 2)))

    rets  = equity.resample("1h").last().pct_change().dropna()
    sharpe = rets.mean() / rets.std() * math.sqrt(365*24)

    downside = rets[rets < 0]
    sortino  = rets.mean() / (downside.std() or 1e-8) * math.sqrt(365*24)

    return pd.DataFrame({
        "metric": [
            "Expectancy (R)", "Win‑rate", "Max DD ($)", "Max DD (%)",
            "Ulcer", "Sharpe", "Sortino", "Trades"
        ],
        "value": [
            expectancy_r, f"{p_win:.1%}", f"${max_dd:,.0f}", f"{max_dd_pct:.2%}",
            ulcer, sharpe, sortino, len(trades)
        ]
    }).set_index("metric")


# ------------------------------------------------------------------
# 4 | Experiment wrappers
# ------------------------------------------------------------------

async def fee_slippage_test(price: pd.DataFrame, fee_round_trip_bps: float = 10, slip_bps_each_side: float = 5):
    trades, curve = run_backtest(price, fee_bps=fee_round_trip_bps/2, slip_bps=slip_bps_each_side)
    return summarise(trades, curve, price.index[:len(curve)], 1_000)


def lag_one_bar_test(price: pd.DataFrame):
    trades, curve = run_backtest(price, lag_bars=1)
    return summarise(trades, curve, price.index[:len(curve)], 1_000)


def monte_carlo_shuffle(price: pd.DataFrame, n_iter: int = 1000):
    trades, curve = run_backtest(price)  # baseline to capture trade list only once

    terminal_equities = []
    for _ in range(n_iter):
        shuffled = trades.copy()
        random.shuffle(shuffled)
        equity = 1_000 + sum(t.pnl for t in shuffled)
        terminal_equities.append(equity)

    median_terminal = np.median(terminal_equities)
    pct_of_baseline = median_terminal / curve[-1] - 1
    return median_terminal, pct_of_baseline


def param_sweep(price: pd.DataFrame,
                tp_sl_grid: List[Tuple[float,float]] = [(2,1)],
                adx_thr: List[int] = [20,25,30],
                stoch_thr: List[int] = [20,30,40]):
    """Very coarse sweep – replace with your own preferred ranges."""

    results = []
    original_vals = (config.ATR_MULT_TP, config.ATR_MULT_SL)

    for tp, sl, adx, stoch in itertools.product([t for t,_ in tp_sl_grid],[s for _,s in tp_sl_grid], adx_thr, stoch_thr):
        # patch globals (ick, but quick):
        config.ATR_MULT_TP = tp
        config.ATR_MULT_SL = sl
        config.ADX_MIN     = adx
        config.STOCH_MIN   = stoch

        trades, curve = run_backtest(price)
        r = summarise(trades, curve, price.index[:len(curve)], 1_000)
        results.append((tp,sl,adx,stoch, r.loc["Expectancy (R)", "value"]))

    # restore
    config.ATR_MULT_TP, config.ATR_MULT_SL = original_vals

    sweep_df = pd.DataFrame(results, columns=["TP","SL","ADX","Stoch","Expectancy_R"])
    return sweep_df


# ------------------------------------------------------------------
# 5 | Convenience orchestrator
# ------------------------------------------------------------------

async def run_all_experiments(price: pd.DataFrame):
    fee_res   = await fee_slippage_test(price)
    lag_res   = lag_one_bar_test(price)  # Removed 'await' as it's not an async function
    med_eq, pct = monte_carlo_shuffle(price)
    sweep_df = param_sweep(price)

    return {
        "fee_slip": fee_res,
        "lag1"    : lag_res,
        "mc_median": (med_eq, pct),
        "sweep"   : sweep_df,
    }
    
    

print("Robustness tests ready. Call `await run_all_experiments(hist_df)` to execute.")
logging.basicConfig(level=logging.INFO)
# Entrypoint for Jupyter Notebook
hist = await preload_history(limit=3000)
hist_30d = hist[hist.index >= hist.index[-1] - pd.Timedelta(days=30)]
# trades,curve = run_backtest(hist_30d)   # pass acces
await run_all_experiments(hist_30d)   # pass accessor



# ------------------------------------------------------------------
# 6 | Draw‑down mitigation ideas
# ------------------------------------------------------------------
# (not executable – just as reference comments)
#
# • Volatility scaling          – shrink risk_pct when rolling σ of equity rises
# • Equity curve stop           – pause trading after X‑σ dip from new high
# • Trailing stop on each trade – step SL to breakeven at +1R, trail by ATR
# • Partial profit              – take half at +1R, let rest run to 2R/3R
# • Reduce position size after large wins (anti‑Kelly)
# • Diversify across pairs      – same edge, uncorrelated instruments

Robustness tests ready. Call `await run_all_experiments(hist_df)` to execute.


{'fee_slip':                       value
 metric                     
 Expectancy (R)      1.73423
 Win‑rate              91.4%
 Max DD ($)             $104
 Max DD (%)            3.98%
 Ulcer              0.005677
 Sharpe            18.938288
 Sortino         5850.951081
 Trades                   35,
 'lag1':                           value
 metric                         
 Expectancy (R)         1.742857
 Win‑rate                  91.4%
 Max DD ($)                 $103
 Max DD (%)                3.96%
 Ulcer                  0.005639
 Sharpe                18.954764
 Sortino         15859131.030326
 Trades                       35,
 'mc_median': (np.float64(3301.7568280612777),
  np.float64(2.220446049250313e-16)),
 'sweep':    TP  SL  ADX  Stoch  Expectancy_R
 0   2   1   20     20      1.918919
 1   2   1   20     30      1.918919
 2   2   1   20     40      1.918919
 3   2   1   25     20      1.918919
 4   2   1   25     30      1.918919
 5   2   1   25     40      1.918919
 6   