import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Parameters
FREQ = 365  # annualization factor: default `freq` of perf_stats (periods per year)
COST_BPS = 10  # transaction cost in basis points, charged per unit of turnover in run_backtest
WIN_VOL = 20  # rolling window length (in rows) of the volatility estimate

def max_drawdown(eq):
    """Return the largest relative drop of `eq` below its running maximum.

    Args:
        eq: pandas Series holding an equity curve.

    Returns:
        The minimum of ``eq / running_max - 1`` as a Python float
        (0.0 when the curve never falls below a previous high, NaN when
        `eq` is empty).
    """
    running_peak = eq.cummax()
    drawdown = eq.div(running_peak).sub(1.0)
    return float(drawdown.min())

def perf_stats(r, freq=FREQ):
    """Compute annualized performance statistics for a series of returns.

    Args:
        r: pandas Series of simple per-period returns; NaNs are dropped.
        freq: number of periods per year used for annualization.

    Returns:
        dict with keys ``ann_return`` (geometric, from the compounded
        equity curve), ``ann_vol``, ``sharpe`` (arithmetic mean over
        volatility, no risk-free rate) and ``max_dd``. ``ann_return`` is
        NaN when no observation remains; ``sharpe`` is NaN unless the
        standard deviation is strictly positive.
    """
    r = r.dropna()
    n_obs = len(r)
    eq = (1 + r).cumprod()
    vol = r.std()

    # Geometric annualization of the final compounded value.
    if n_obs > 0:
        ann_ret = float(eq.iloc[-1] ** (freq / n_obs) - 1)
    else:
        ann_ret = np.nan

    # A NaN or zero standard deviation fails `vol > 0`, giving a NaN Sharpe.
    if vol > 0:
        sharpe = float((r.mean() * freq) / (vol * np.sqrt(freq)))
    else:
        sharpe = np.nan

    return {
        "ann_return": ann_ret,
        "ann_vol": float(vol * np.sqrt(freq)),
        "sharpe": sharpe,
        "max_dd": max_drawdown(eq),
    }

def run_backtest(weights, rets, cost_bps=COST_BPS, normalize=True):
    """Backtest a long-only weight schedule with proportional trading costs.

    Args:
        weights: DataFrame of target weights (rows = dates, columns = assets).
        rets: DataFrame of per-period returns; its index drives the backtest.
        cost_bps: cost in basis points charged per unit of turnover.
        normalize: when True (default, original behavior) every row with a
            positive weight sum is rescaled to sum to 1. This removes any
            row-level exposure scaling present in `weights` (e.g. a risk
            gate multiplying a row by 0.2); pass False to keep the gross
            exposure of the input weights.

    Returns:
        Tuple ``(net, to, w)``: net return Series, turnover Series and the
        DataFrame of weights actually applied.
    """
    # Lag by one row so a weight row is applied to the following row's
    # returns; rows without a weight (including the first) hold nothing.
    w = weights.shift(1).reindex(rets.index).fillna(0.0)

    # Long only: negative weights are floored at zero.
    w = w.clip(lower=0.0)

    if normalize:
        # Rows summing to 0 would divide by zero; they are mapped to NaN
        # and then back to all-zero weights.
        row_total = w.sum(axis=1).replace(0, np.nan)
        w = w.div(row_total, axis=0).fillna(0.0)

    # Gross portfolio return per row.
    gross = (w * rets).sum(axis=1)

    # Turnover = sum of absolute weight changes. The first row of diff()
    # is all-NaN and sums to 0, so no cost is charged on the first row.
    to = w.diff().abs().sum(axis=1).fillna(0.0)
    cost = to * (cost_bps / 10000.0)

    net = gross - cost
    return net, to, w

# 1) Low-volatility reference: per row, equal-weight the columns whose
#    rolling standard deviation (window WIN_VOL) is at or below the
#    cross-sectional median. Rows where no column qualifies (e.g. while the
#    rolling window is still NaN) end up with all-zero weights.
vol_roll = rets.rolling(WIN_VOL).std()
q_cs = vol_roll.quantile(0.5, axis=1)
w_low = vol_roll.le(q_cs, axis=0).astype(float)
w_low = w_low.div(w_low.sum(axis=1).replace(0, np.nan), axis=0).fillna(0.0)

# 2) ML weights (W_ml is built earlier, outside this section); align them on
#    the return dates and treat missing rows/values as zero weight.
W_ml_use = W_ml.reindex(rets.index).fillna(0.0)

# Run both strategies through the same backtest.
# NOTE(review): run_backtest rescales each row to sum to 1, so any exposure
# reduction encoded in W_ml is not reflected in r_ml — confirm this is intended.
r_low, to_low, w_low_n = run_backtest(w_low, rets, cost_bps=COST_BPS)
r_ml,  to_ml,  w_ml_n  = run_backtest(W_ml_use, rets, cost_bps=COST_BPS)

# Summary table: one row of performance and turnover statistics per strategy.
strategies = [
    ("low_vol_ref", r_low, to_low),
    ("ml_risk_gate", r_ml, to_ml),
]

rows = []
for name, r, to in strategies:
    st = {
        **perf_stats(r, freq=FREQ),
        "turnover_mean": float(to.mean()),
        "turnover_med": float(to.median()),
        "name": name,
    }
    rows.append(st)

# Index by strategy name and rank by Sharpe ratio, best first.
summary = pd.DataFrame(rows).set_index("name")
summary = summary.sort_values("sharpe", ascending=False)
print(summary)

# Plot the net equity curves of both strategies (missing returns count as 0).
plt.figure(figsize=(12, 5))
for label, net_ret in (("low_vol_ref", r_low), ("ml_risk_gate", r_ml)):
    equity = (1 + net_ret.fillna(0)).cumprod()
    plt.plot(equity, label=label)
plt.legend()
plt.title("Equity curves (net)")
plt.tight_layout()
plt.show()

# Gate diagnostics (only when a global `gate_ml` is available)
if "gate_ml" in globals():
    # Align the gate on the return dates; dates without a value count as
    # gate = 1.0.
    g = gate_ml.reindex(rets.index).fillna(1.0)

    print("\nGate diag")
    print("avg:", float(g.mean()))
    print("min:", float(g.min()), "max:", float(g.max()))
    # "stress days" = rows where the gate is strictly below 1.
    print("stress days:", int((g < 1.0).sum()), "/", len(g))

    # Simple next-day check: gross return of the low-vol reference weights
    # (no costs), shifted so each row holds the following row's return.
    base_port = (w_low_n * rets).sum(axis=1)
    base_next = base_port.shift(-1)

    # Split next-row returns by gate state; the last row (NaN after the
    # shift) is dropped. Rows with g > 1 would fall in neither group.
    a = base_next[g < 1.0].dropna()
    b = base_next[g == 1.0].dropna()

    # Only report when both groups have more than 30 observations.
    if len(a) > 30 and len(b) > 30:
        print("\nNext day base ret")
        print("gate<1 mean", float(a.mean()), "median", float(a.median()), "n", len(a))
        print("gate=1 mean", float(b.mean()), "median", float(b.median()), "n", len(b))

    # NOTE(review): the gate_ml plot below is only drawn when `p_stress`
    # also exists, although it does not use it — confirm this is intended.
    if "p_stress" in globals():
        # Align p_stress on the return dates, carrying the last value forward.
        p = p_stress.reindex(rets.index).ffill()

        plt.figure(figsize=(12,3))
        plt.plot(p.index, p, linewidth=1)
        plt.title("p_stress")
        plt.tight_layout()
        plt.show()

        plt.figure(figsize=(12,2.5))
        plt.plot(g.index, g, linewidth=1)
        plt.title("gate_ml")
        plt.tight_layout()
        plt.show()


              ann_return   ann_vol    sharpe    max_dd  turnover_mean  \

name                                                                    
low_vol_ref     0.102579  0.649064  0.481898 -0.816893       0.088950   
ml_risk_gate   -0.039288  0.662760  0.281045 -0.860716       0.029944   


              turnover_med  
              
name                        
low_vol_ref       0.000000  
ml_risk_gate      0.019853  

![Image test 19](image/image_test_19.png)


Gate diag

avg: 0.879013698630137

min: 0.2 max: 1.0

stress days: 276 / 1825

Next day base ret

gate<1 mean 0.0006442474981637147 median 0.005409334970279042 n 276

gate=1 mean 0.0010002802086446396 median 0.0010632983437200605 n 1548

![Image test 19_1](image/image_test_19_1.png)


![Image test 19_2](image/image_test_19_2.png)

low_vol_ref fait mieux : ann_return ~ +10.3%, Sharpe ~ 0.48.

ml_risk_gate fait pire : ann_return ~ -3.9%, Sharpe ~ 0.28, drawdown un peu plus bas (plus négatif).

Le gate s’active ~ 15% du temps (276 jours / 1825) et baisse l’expo (min gate = 0.2).

Les stats « lendemain » montrent que les jours où gate<1 ne sont pas spécialement suivis de mauvais rendements : le lendemain, le rendement moyen du portefeuille de base est un peu plus faible que quand gate=1 (0,064 % contre 0,100 %), alors que la médiane est même plus élevée (0,54 % contre 0,11 %) ; ce n’est donc pas un gros signal. Le filtre coupe donc parfois « pour rien », et ça peut expliquer la sous-performance.