In [1]:
# -*- coding: utf-8 -*-
import os
import backtrader as bt
import yfinance as yf
import pandas as pd
import numpy as np

# ================== Basic parameters ==================
start_date = "2025-07-01"
end_date   = "2025-12-17"   # NOTE(review): yfinance documents `end` as exclusive — confirm
initial_capital = 10000     # starting cash

# ================== Trading costs (global) ==================
COMMISSION_RATE = 0.001   # commission 0.1%
SLIPPAGE_PERC   = 0.001   # slippage   0.1%

# ================== Cash reserve (covers commission / slippage) ==================
CASH_BUFFER = 0.05        # keep 5% of portfolio value out of the market

# Indicator bounds. The strategy maps each 5-bar SMA linearly onto [0, 1]:
# the A2B value scores 0 and the B2A value scores 1 (values outside are clipped).
VIX_THRESHOLD_A2B = 19
DXY_THRESHOLD_A2B = 99
VIX_THRESHOLD_B2A = 23
DXY_THRESHOLD_B2A = 102

# ======== Score weights (must sum to 1) ========
VIX_WEIGHT  = 0.3
DXY_WEIGHT  = 0.5
MA_WEIGHT   = 0.2
GOLD_WEIGHT = 0

# Composite-score thresholds: composite > A2B favours B, composite < B2A favours A.
SCORE_THRESHOLD_A2B = 0.5
SCORE_THRESHOLD_B2A = 0.5

# Moving-average periods (in bars)
MA_SHORT_PERIOD = 20
MA_LONG_PERIOD  = 50

# Cooldown between strategy switches, counted in trading days (bars)
COOLDOWN_PERIOD = 42  # 42 trading days (bars)

# Strategy B target weights (fractions of the invested portion; must sum to 1)
STRATEGY_B_ALLOCATION = {
    "SP500":       0.2,
    "Bonds_short": 0.3,
    "Gold":        0.5
}

# Ticker symbols
assets = {
    "SP500":       "SPY",
    "Nasdaq":      "QQQ",
    "Bonds_short": "SHY",
    "Gold":        "GLD"
}

# ---- Sanity checks: fail fast on an inconsistent configuration ----
_WEIGHT_SUM = VIX_WEIGHT + DXY_WEIGHT + MA_WEIGHT + GOLD_WEIGHT
if abs(_WEIGHT_SUM - 1.0) > 1e-9:
    raise ValueError(f"Score weights must sum to 1, got {_WEIGHT_SUM}")
if abs(sum(STRATEGY_B_ALLOCATION.values()) - 1.0) > 1e-9:
    raise ValueError("STRATEGY_B_ALLOCATION weights must sum to 1")
# The A2B/B2A pair is used as a normalisation range; equal bounds would divide by zero.
if VIX_THRESHOLD_B2A == VIX_THRESHOLD_A2B or DXY_THRESHOLD_B2A == DXY_THRESHOLD_A2B:
    raise ValueError("A2B and B2A thresholds must differ for VIX and DXY")
if not 0.0 <= CASH_BUFFER < 1.0:
    raise ValueError(f"CASH_BUFFER must be in [0, 1), got {CASH_BUFFER}")


# =============== 工具函数：处理 yfinance 多重列 ===============
def prepare_ohlc(df: pd.DataFrame) -> pd.DataFrame:
    """Return a flat OHLCV copy of a yfinance download.

    If the columns are a MultiIndex, only the first level (the price field) is
    kept. Column names are capitalised, and only the Open/High/Low/Close/Volume
    columns that are present are retained, in that order. The index is converted
    with ``pd.to_datetime``.

    The input frame is never modified.

    Args:
        df: Frame as returned by ``yf.download``.

    Returns:
        A new DataFrame suitable for ``bt.feeds.PandasData``.
    """
    columns = df.columns
    if isinstance(columns, pd.MultiIndex):
        # Keep level 0 (the price field); this matches the original ``c[0]`` flattening.
        columns = columns.get_level_values(0)
    # set_axis returns a relabelled frame, so the caller's columns stay intact.
    out = df.set_axis([c.capitalize() for c in columns], axis=1)
    keep_cols = [c for c in ("Open", "High", "Low", "Close", "Volume") if c in out.columns]
    out = out[keep_cols].copy()
    out.index = pd.to_datetime(out.index)
    return out


# =============== 策略 ===============
class CompositeSwitchStrategy(bt.Strategy):
    """Switch between an aggressive and a defensive portfolio using a composite score.

    Strategy A (signal 1): QQQ only, sized to ``MAX_TARGET`` of portfolio value.
    Strategy B (signal 2): SPY / SHY / GLD weighted by ``STRATEGY_B_ALLOCATION``,
    scaled by ``MAX_TARGET``.

    On every bar the 5-bar SMAs of VIX and DXY and the short/long SMA ratios of
    SPY and GLD are mapped to scores in [0, 1] and combined with the configured
    weights. The raw signal is taken from the previous bar's threshold conditions.
    It is ignored while fewer than ``COOLDOWN_PERIOD`` bars have passed since the
    last switch. A switch runs in two steps on later bars: first close the old
    legs, then open the new ones.

    Feed order expected by ``__init__``: QQQ, SPY, SHY, GLD, VIX, DXY.
    """
    params = dict(
        VIX_THRESHOLD_A2B  = VIX_THRESHOLD_A2B,
        DXY_THRESHOLD_A2B  = DXY_THRESHOLD_A2B,
        VIX_THRESHOLD_B2A  = VIX_THRESHOLD_B2A,
        DXY_THRESHOLD_B2A  = DXY_THRESHOLD_B2A,
        VIX_WEIGHT         = VIX_WEIGHT,
        DXY_WEIGHT         = DXY_WEIGHT,
        MA_WEIGHT          = MA_WEIGHT,
        GOLD_WEIGHT        = GOLD_WEIGHT,
        SCORE_THRESHOLD_A2B= SCORE_THRESHOLD_A2B,
        SCORE_THRESHOLD_B2A= SCORE_THRESHOLD_B2A,
        MA_SHORT_PERIOD    = MA_SHORT_PERIOD,
        MA_LONG_PERIOD     = MA_LONG_PERIOD,
        COOLDOWN_PERIOD    = COOLDOWN_PERIOD,
        STRATEGY_B_ALLOCATION = STRATEGY_B_ALLOCATION,

        # Max fraction of portfolio value to invest; the global CASH_BUFFER stays in cash.
        MAX_TARGET         = 1.0 - CASH_BUFFER,
    )

    def __init__(self):
        # Feed layout (must match the cerebro.adddata order):
        # data0: QQQ (Nasdaq)
        # data1: SPY (SP500)
        # data2: SHY (Bonds_short)
        # data3: GLD (Gold)
        # data4: VIX
        # data5: DXY
        self.qqq = self.datas[0]
        self.spy = self.datas[1]
        self.shy = self.datas[2]
        self.gld = self.datas[3]
        self.vix = self.datas[4]
        self.dxy = self.datas[5]

        # Indicators
        self.vix_sma5 = bt.ind.SMA(self.vix.close, period=5)
        self.dxy_sma5 = bt.ind.SMA(self.dxy.close, period=5)

        self.spy_ma20  = bt.ind.SMA(self.spy.close, period=self.p.MA_SHORT_PERIOD)
        self.spy_ma50  = bt.ind.SMA(self.spy.close, period=self.p.MA_LONG_PERIOD)
        self.gold_ma20 = bt.ind.SMA(self.gld.close, period=self.p.MA_SHORT_PERIOD)
        self.gold_ma50 = bt.ind.SMA(self.gld.close, period=self.p.MA_LONG_PERIOD)

        # Signal state
        self.current_signal      = 1          # strategy currently targeted: 1=A, 2=B
        self.final_signal_prev   = 1          # previous bar's final signal
        self.strategy_a_cond_prev= False      # previous bar's B->A condition
        self.strategy_b_cond_prev= False      # previous bar's A->B condition
        self.last_transition_bar = None       # bar_index of the last switch (cooldown anchor)

        # Initial entry bookkeeping
        self.bar_index        = 0            # incremented before use, so the first bar is 1
        self.init_q_buy_done  = False        # True once the initial QQQ order has been sent
        self._last_clock_len  = 0            # len(data0) at the last processed bar

        # Pending two-step switch plan (None or dict)
        self.pending_transition = None

        # ====== Recording (charts / metrics) ======
        self.dates = []
        self.nav_list = []
        self.returns = []
        self.last_value = None

        self.qqq_close_list  = []
        self.spy_close_list  = []
        self.gold_close_list = []
        self.vix_list        = []
        self.dxy_list        = []
        self.vix_ma5_list    = []
        self.dxy_ma5_list    = []
        self.sp500_ma20_list = []
        self.sp500_ma50_list = []
        self.gold_ma20_list  = []
        self.gold_ma50_list  = []

        self.composite_list  = []
        self.vix_score_list  = []
        self.dxy_score_list  = []
        self.ma_score_list   = []
        self.gold_score_list = []

        self.signal_list      = []
        self.rebalance_records= []

        # Debug: per-bar position sizes
        self.pos_qqq = []
        self.pos_spy = []
        self.pos_shy = []
        self.pos_gld = []

    # ========= Helpers =========
    def _strategy_name(self, sig):
        """Return the human-readable label for signal 1 (A) or 2 (B)."""
        if sig == 1:
            return "策略A(100%纳斯达克)"
        elif sig == 2:
            return "策略B(20%标普500+30%短期债券+50%黄金)"
        else:
            return "未知策略"

    @staticmethod
    def _tag_order(order, **info):
        """Attach bookkeeping fields to an order (or a list of orders); ``None`` is ignored."""
        if order is None:
            return
        for od in (order if isinstance(order, list) else [order]):
            od.addinfo(**info)

    def _record_rebalance(self, pt, step, op_dt, from_str, to_str):
        """Append one row (one step of a switch) to ``rebalance_records``."""
        self.rebalance_records.append({
            "SignalDate": pt["signal_date"],
            "OperationDate": op_dt,
            "Step": step,
            "From": from_str,
            "To": to_str,
            "Composite": pt["composite"],
            "VIX_5MA": pt["vix_5ma"],
            "DXY_5MA": pt["dxy_5ma"],
            "VIX_score": pt["vix_score"],
            "DXY_score": pt["dxy_score"],
            "MA_score": pt["ma_score"],
            "GOLD_score": pt["gold_score"],
        })

    # ========= Order & trade notifications =========
    def notify_order(self, order):
        """Print every final order state, together with the tags attached at submission."""
        if order.status in [order.Submitted, order.Accepted]:
            return

        exec_dt   = self.data.datetime.date(0)  # date of data0's current bar
        data_name = order.data._name
        # Tags attached via addinfo when the order was created
        signal_dt = getattr(order.info, 'signal_date', None)
        from_sig  = getattr(order.info, 'from_signal', None)
        to_sig    = getattr(order.info, 'to_signal', None)
        reason    = getattr(order.info, 'reason', '')

        if order.status == order.Completed:
            side = "BUY" if order.isbuy() else "SELL"
            print(f"[{exec_dt}] ORDER {side} {data_name} "
                  f"size={order.executed.size:.2f} "
                  f"price={order.executed.price:.2f} "
                  f"comm={order.executed.comm:.2f} "
                  f"信号日={signal_dt}, 操作日={exec_dt}, "
                  f"from={from_sig}, to={to_sig}, reason={reason}")
        elif order.status == order.Canceled:
            print(f"[{exec_dt}] ORDER CANCELED {data_name} 信号日={signal_dt}")
        elif order.status == order.Rejected:
            print(f"[{exec_dt}] ORDER REJECTED {data_name} 信号日={signal_dt}")
        elif order.status == order.Margin:
            # Insufficient cash: previously dropped silently.
            print(f"[{exec_dt}] ORDER MARGIN {data_name} 信号日={signal_dt}, reason={reason}")
        elif order.status == order.Expired:
            print(f"[{exec_dt}] ORDER EXPIRED {data_name} 信号日={signal_dt}")

    def notify_trade(self, trade):
        """Print gross / net PnL whenever a trade is closed."""
        if not trade.isclosed:
            return
        dt = self.data.datetime.date(0)
        print(f"[{dt}] TRADE {trade.data._name} CLOSED "
              f"gross={trade.pnl:.2f}, net={trade.pnlcomm:.2f}")

    # ===== Warm-up and normal phases share the same logic =====
    def prenext(self):
        self.next_core()

    def next(self):
        self.next_core()

    # ========= Main per-bar logic =========
    def next_core(self):
        """Record state, compute the composite score, update the signal, and run pending switches."""
        # NOTE(review): with several feeds, backtrader may call next/prenext when only a
        # secondary feed (e.g. VIX/DXY) produced a bar. In that case data0 still shows its
        # previous bar. Skip those calls so every recorded row corresponds to one new QQQ bar.
        clock_len = len(self.qqq)
        if clock_len == self._last_clock_len:
            return
        self._last_clock_len = clock_len

        self.bar_index += 1

        dt = self.datas[0].datetime.date(0)
        self.dates.append(dt)

        # ===== First bar: signal strategy A; the order fills on a later bar (buy only) =====
        if self.bar_index == 1 and not self.init_q_buy_done:
            signal_dt = dt
            print("\n" + "=" * 80)
            print(f"[{signal_dt}] 初始建仓信号：进入策略A(100%纳斯达克)，操作日=下一交易日(建仓)")

            o = self.order_target_percent(self.qqq, target=self.p.MAX_TARGET)
            self._tag_order(o, signal_date=signal_dt, from_signal=0,
                            to_signal=1, reason="INIT_BUY")

            self.current_signal    = 1
            self.final_signal_prev = 1
            self.init_q_buy_done   = True
            # Anchor the cooldown at the initial entry. The signal section below can never
            # run on bar 1, because bar 1 always takes the warm-up early return.
            self.last_transition_bar = self.bar_index

        # ===== Record prices =====
        self.qqq_close_list.append(float(self.qqq.close[0]))
        self.spy_close_list.append(float(self.spy.close[0]))
        self.gold_close_list.append(float(self.gld.close[0]))
        self.vix_list.append(float(self.vix.close[0]))
        self.dxy_list.append(float(self.dxy.close[0]))

        # ===== Record NAV & simple bar return =====
        value = self.broker.getvalue()
        daily_ret = 0.0 if self.last_value is None else (value / self.last_value) - 1.0
        self.returns.append(daily_ret)
        self.last_value = value
        self.nav_list.append(value)

        # ===== Record positions (debug) =====
        self.pos_qqq.append(self.getposition(self.qqq).size)
        self.pos_spy.append(self.getposition(self.spy).size)
        self.pos_shy.append(self.getposition(self.shy).size)
        self.pos_gld.append(self.getposition(self.gld).size)

        # ===== Indicators not ready: record placeholders, no new signal =====
        if (len(self.spy) < self.p.MA_LONG_PERIOD or
            len(self.gld) < self.p.MA_LONG_PERIOD or
            len(self.vix) < 5 or
            len(self.dxy) < 5):

            for series in (self.vix_ma5_list, self.dxy_ma5_list,
                           self.sp500_ma20_list, self.sp500_ma50_list,
                           self.gold_ma20_list, self.gold_ma50_list,
                           self.composite_list, self.vix_score_list,
                           self.dxy_score_list, self.ma_score_list,
                           self.gold_score_list):
                series.append(None)

            self.signal_list.append(self.current_signal)

            # Defensive: still advance any pending two-step switch.
            self._maybe_do_pending_transition()
            return

        # ===== Composite score =====
        vix_5ma      = float(self.vix_sma5[0])
        dxy_5ma      = float(self.dxy_sma5[0])
        sp500_ma20   = float(self.spy_ma20[0])
        sp500_ma50   = float(self.spy_ma50[0])
        gold_ma20    = float(self.gold_ma20[0])
        gold_ma50    = float(self.gold_ma50[0])

        self.vix_ma5_list.append(vix_5ma)
        self.dxy_ma5_list.append(dxy_5ma)
        self.sp500_ma20_list.append(sp500_ma20)
        self.sp500_ma50_list.append(sp500_ma50)
        self.gold_ma20_list.append(gold_ma20)
        self.gold_ma50_list.append(gold_ma50)

        def clip01(x):
            return max(0.0, min(1.0, x))

        # Linear map: A2B bound -> 0, B2A bound -> 1
        vix_score = clip01((vix_5ma - self.p.VIX_THRESHOLD_A2B) /
                           (self.p.VIX_THRESHOLD_B2A - self.p.VIX_THRESHOLD_A2B))
        dxy_score = clip01((dxy_5ma - self.p.DXY_THRESHOLD_A2B) /
                           (self.p.DXY_THRESHOLD_B2A - self.p.DXY_THRESHOLD_A2B))

        # SPY short/long ratio: 0.9 -> 1, 1.1 -> 0 (weaker trend = higher score)
        ma_ratio = sp500_ma20 / sp500_ma50 if sp500_ma50 != 0 else 1.0
        ma_score = clip01(1 - (ma_ratio - 0.9) / (1.1 - 0.9))

        # GLD short/long ratio: 1.0 -> 0, 1.2 -> 1
        gold_ratio = gold_ma20 / gold_ma50 if gold_ma50 != 0 else 1.0
        gold_score = clip01((gold_ratio - 1.0) / (1.2 - 1.0))

        composite = (
            vix_score  * self.p.VIX_WEIGHT  +
            dxy_score  * self.p.DXY_WEIGHT  +
            ma_score   * self.p.MA_WEIGHT   +
            gold_score * self.p.GOLD_WEIGHT
        )

        self.composite_list.append(composite)
        self.vix_score_list.append(vix_score)
        self.dxy_score_list.append(dxy_score)
        self.ma_score_list.append(ma_score)
        self.gold_score_list.append(gold_score)

        # ===== Raw / final signal =====
        strategy_b_cond = composite > self.p.SCORE_THRESHOLD_A2B   # A->B
        strategy_a_cond = composite < self.p.SCORE_THRESHOLD_B2A   # B->A

        bar_idx = self.bar_index

        # The raw signal uses the PREVIOUS bar's conditions, i.e. it lags the composite by one bar.
        if self.strategy_a_cond_prev:
            raw_signal = 1
        elif self.strategy_b_cond_prev:
            raw_signal = 2
        else:
            raw_signal = 0

        prev_final = self.final_signal_prev
        in_cooldown = (self.last_transition_bar is not None and
                       (bar_idx - self.last_transition_bar) < self.p.COOLDOWN_PERIOD)
        if not in_cooldown and raw_signal != 0 and raw_signal != prev_final:
            final_signal = raw_signal
            self.last_transition_bar = bar_idx
        else:
            final_signal = prev_final

        self.strategy_a_cond_prev = strategy_a_cond
        self.strategy_b_cond_prev = strategy_b_cond
        self.final_signal_prev   = final_signal
        self.signal_list.append(final_signal)

        # ===== Final signal changed: schedule the two-step switch =====
        if final_signal != self.current_signal:
            from_sig  = self.current_signal
            to_sig    = final_signal
            signal_dt = dt

            print("\n" + "=" * 80)
            print(f"[{signal_dt}] 生成切换信号(日)：{from_sig} → {to_sig}，"
                  f"操作日1=下一交易日(清仓旧策略)，操作日2=再下一交易日(建仓新策略)")

            self.current_signal = to_sig

            # NOTE(review): the sell order is only *submitted* on sell_bar (signal bar + 1)
            # and so fills one bar later still, whereas INIT_BUY is submitted on its signal
            # bar. Confirm whether sell_bar should be self.bar_index to match the message above.
            self.pending_transition = dict(
                signal_date = signal_dt,
                from_signal = from_sig,
                to_signal   = to_sig,
                composite   = composite,
                vix_5ma     = vix_5ma,
                dxy_5ma     = dxy_5ma,
                vix_score   = vix_score,
                dxy_score   = dxy_score,
                ma_score    = ma_score,
                gold_score  = gold_score,
                sell_done   = False,
                buy_done    = False,
                sell_bar    = self.bar_index + 1,    # step 1: close old legs
                buy_bar     = self.bar_index + 2,    # step 2: open new legs
            )

        # ===== Finally: run any pending step whose bar has arrived =====
        self._maybe_do_pending_transition()

    # ========= Run the pending two-step switch =========
    def _maybe_do_pending_transition(self):
        """Run the sell step, then the buy step, once their scheduled bars are reached."""
        if self.pending_transition is None:
            return

        pt = self.pending_transition

        # Step 1: close the old strategy
        if (not pt["sell_done"]) and (self.bar_index >= pt["sell_bar"]):
            self._do_sell_step(pt)
            pt["sell_done"] = True

        # Step 2: open the new strategy (only after step 1 has been issued)
        if pt["sell_done"] and (not pt["buy_done"]) and (self.bar_index >= pt["buy_bar"]):
            self._do_buy_step(pt)
            pt["buy_done"] = True

        # Both steps issued: clear the plan
        if pt["sell_done"] and pt["buy_done"]:
            self.pending_transition = None

    # ========= Step 1: close the old strategy =========
    def _do_sell_step(self, pt):
        """Submit target-0 orders for every non-zero leg of the strategy being left."""
        op_dt     = self.datas[0].datetime.date(0)
        from_sig  = pt["from_signal"]
        to_sig    = pt["to_signal"]
        signal_dt = pt["signal_date"]

        from_str = self._strategy_name(from_sig)
        to_str   = self._strategy_name(to_sig)

        print("-" * 80)
        print(f"[{op_dt}] 操作日1-清仓：从 {from_str} 切换到 {to_str}，信号日={signal_dt}")

        if from_sig == 1:      # strategy A: QQQ only
            sell_list = [self.qqq]
        elif from_sig == 2:    # strategy B: SPY + SHY + GLD
            sell_list = [self.spy, self.shy, self.gld]
        else:                  # fallback: close everything
            sell_list = [self.qqq, self.spy, self.shy, self.gld]

        for data in sell_list:
            cur_size = self.getposition(data).size
            if cur_size == 0:
                continue
            cur_val = cur_size * data.close[0]

            print(f"    清仓 {data._name}: 当前市值={cur_val:,.2f} → 目标=0")
            o = self.order_target_percent(data=data, target=0.0)
            self._tag_order(o, signal_date=signal_dt, from_signal=from_sig,
                            to_signal=to_sig, reason="SELL_STEP1")

        self._record_rebalance(pt, "SELL", op_dt, from_str, to_str)

    # ========= Step 2: open the new strategy =========
    def _do_buy_step(self, pt):
        """Submit target-percent orders for each leg of the strategy being entered."""
        op_dt     = self.datas[0].datetime.date(0)
        from_sig  = pt["from_signal"]
        to_sig    = pt["to_signal"]
        signal_dt = pt["signal_date"]

        from_str = self._strategy_name(from_sig)
        to_str   = self._strategy_name(to_sig)

        print("-" * 80)
        print(f"[{op_dt}] 操作日2-建仓：从 {from_str} 切换到 {to_str}，信号日={signal_dt}")

        max_total = self.p.MAX_TARGET

        # Target weights for the new strategy's legs only
        if to_sig == 1:
            # Strategy A: QQQ up to MAX_TARGET
            target_map = {self.qqq: max_total}
        elif to_sig == 2:
            # Strategy B: split per STRATEGY_B_ALLOCATION
            alloc = self.p.STRATEGY_B_ALLOCATION
            target_map = {
                self.spy: alloc["SP500"]       * max_total,
                self.shy: alloc["Bonds_short"] * max_total,
                self.gld: alloc["Gold"]        * max_total,
            }
        else:
            target_map = {}

        # Snapshot for logging only; order_target_percent sizes against the broker itself.
        port_val = self.broker.getvalue()
        cash     = self.broker.getcash()

        for data, tgt in target_map.items():
            cur_size  = self.getposition(data).size
            cur_val   = cur_size * data.close[0]
            tgt_val   = port_val * tgt
            diff_val  = tgt_val - cur_val

            if diff_val > 0:
                action = "BUY"
            elif diff_val < 0:
                action = "SELL"
            else:
                action = "HOLD"

            print(f"    {data._name}: 当前市值={cur_val:,.2f}, "
                  f"目标市值={tgt_val:,.2f}, 差额={diff_val:,.2f} → {action} "
                  f"(总资产={port_val:,.2f}, 现金={cash:,.2f})")

            o = self.order_target_percent(data=data, target=tgt)
            self._tag_order(o, signal_date=signal_dt, from_signal=from_sig,
                            to_signal=to_sig, reason="BUY_STEP2")

        self._record_rebalance(pt, "BUY", op_dt, from_str, to_str)


# =============== Download market data and cache it as CSV ===============
print("正在下载数据...")

# Pass auto_adjust=True explicitly. yfinance already uses it as the default (it
# prints a notice that the default changed), so the prices are the same but no
# longer depend on the library default.
_yf_kwargs = dict(start=start_date, end=end_date, auto_adjust=True)

spy_raw = yf.download(assets["SP500"],       **_yf_kwargs)
qqq_raw = yf.download(assets["Nasdaq"],      **_yf_kwargs)
shy_raw = yf.download(assets["Bonds_short"], **_yf_kwargs)
gld_raw = yf.download(assets["Gold"],        **_yf_kwargs)

vix_raw = yf.download("^VIX",     **_yf_kwargs)
dxy_raw = yf.download("DX-Y.NYB", **_yf_kwargs)

print("数据下载完成")

# === Save to ./data/<start>_<end>/<ticker>.csv (existing files are overwritten) ===
data_dir = os.path.join("./data", f"{start_date}_{end_date}")
os.makedirs(data_dir, exist_ok=True)

save_map = {
    "SPY": spy_raw,
    "QQQ": qqq_raw,
    "SHY": shy_raw,
    "GLD": gld_raw,
    "VIX": vix_raw,
    "DXY": dxy_raw,
}

# Fail fast on an empty download (bad ticker, network or rate-limit problem).
# Otherwise the failure only surfaces later as an obscure backtrader error.
_missing = [key for key, frame in save_map.items() if frame.empty]
if _missing:
    raise RuntimeError(f"yfinance returned no rows for: {', '.join(_missing)}")

for name, df in save_map.items():
    file_path = os.path.join(data_dir, f"{name}.csv")
    df.to_csv(file_path)
    print(f"已保存数据: {file_path}")

# =============== Convert to backtrader-ready OHLCV frames ===============
spy_df = prepare_ohlc(spy_raw)
qqq_df = prepare_ohlc(qqq_raw)
shy_df = prepare_ohlc(shy_raw)
gld_df = prepare_ohlc(gld_raw)
vix_df = prepare_ohlc(vix_raw)
dxy_df = prepare_ohlc(dxy_raw)

# =============== Build the backtrader engine ===============
cerebro = bt.Cerebro()

# Broker: starting cash plus the global commission / slippage settings.
cerebro.broker.setcash(initial_capital)
cerebro.broker.setcommission(commission=COMMISSION_RATE)
cerebro.broker.set_slippage_perc(perc=SLIPPAGE_PERC)

# cheat-on-close stays off (default T+1 execution). The strategy tags each order
# with its signal date so the logs can tell the signal day from the execution day.

# One PandasData feed per instrument, created in the same order as before.
data_qqq, data_spy, data_shy, data_gld, data_vix, data_dxy = (
    bt.feeds.PandasData(dataname=frame)
    for frame in (qqq_df, spy_df, shy_df, gld_df, vix_df, dxy_df)
)

# The adddata order defines datas[0..5] in CompositeSwitchStrategy.__init__.
for feed, feed_name in (
    (data_qqq, "QQQ"),
    (data_spy, "SPY"),
    (data_shy, "SHY"),
    (data_gld, "GLD"),
    (data_vix, "VIX"),
    (data_dxy, "DXY"),
):
    cerebro.adddata(feed, name=feed_name)

cerebro.addstrategy(CompositeSwitchStrategy)

# =============== Run the backtest ===============
print("开始回测...")
results = cerebro.run()
strat, final_value = results[0], cerebro.broker.getvalue()
print("回测结束。")

# =============== Performance metrics (annualised with 252 trading days) ===============
returns = pd.Series(strat.returns, index=pd.to_datetime(strat.dates))
returns.name = "Strategy_Return"
cum = (1 + returns).cumprod()
cumulative_return = cum.iloc[-1] - 1

total_days = len(returns)
years = total_days / 252.0 if total_days > 0 else 0

# Backtest >= 1 year: standard compound annualised return.
# Backtest <  1 year: report the plain period return (no power annualisation).
if years >= 1:
    annual_return = (1 + cumulative_return) ** (1 / years) - 1
    annual_ret_label = "年化收益率(252交易日)"
else:
    annual_return = cumulative_return
    annual_ret_label = "YTD收益率(未满1年，不做年化)"

# Annualised volatility / Sharpe always scale by 252.
annual_vol = returns.std() * np.sqrt(252)

risk_free_rate = 0.02
excess = returns - risk_free_rate / 252.0
excess_std = excess.std()
# `> 0` (rather than `!= 0`) also maps a NaN std (fewer than two bars) to 0.
sharpe = excess.mean() / excess_std * np.sqrt(252) if excess_std > 0 else 0.0

# Drawdown series
running_max = cum.cummax()
drawdown = (cum - running_max) / running_max    # <= 0
max_drawdown = drawdown.min()                   # most negative drawdown

# ===== Ulcer Index =====
# sqrt(mean(drawdown^2)) over ALL periods; bars at a new high contribute 0.
# Averaging only the underwater bars would overstate the index.
ulcer_index = np.sqrt((drawdown ** 2).mean()) if not drawdown.empty else 0.0

# ===== CDaR 95% (Conditional Drawdown at Risk) =====
# Simplified definition: among the absolute values of the negative drawdowns,
# average those at or above the 95% quantile. The result is positive and means
# "average drawdown in the extreme tail".
dd_neg = drawdown[drawdown < 0]
if not dd_neg.empty:
    dd_abs = -dd_neg
    cdar_level = 0.95   # kept separate from `alpha` (Jensen's alpha) below
    threshold = dd_abs.quantile(cdar_level)
    tail = dd_abs[dd_abs >= threshold]
    cdar_95 = tail.mean() if not tail.empty else 0.0
else:
    cdar_95 = 0.0

# ===== Alpha / Beta versus SPY =====
spy_close = pd.Series(strat.spy_close_list, index=returns.index)
spy_ret   = spy_close.pct_change().fillna(0.0)
valid     = (~returns.isna()) & (~spy_ret.isna())
# np.cov defaults to the sample estimator (ddof=1) but np.var defaults to ddof=0.
# Use ddof=1 for both so beta is not inflated by n/(n-1).
cov       = np.cov(returns[valid], spy_ret[valid])[0, 1]
var_spy   = np.var(spy_ret[valid], ddof=1)
beta      = cov / var_spy if var_spy > 0 else 0.0
alpha     = ((returns[valid].mean() - risk_free_rate / 252.0)
             - beta * (spy_ret[valid].mean() - risk_free_rate / 252.0)) * 252.0

yearly_returns = returns.groupby(returns.index.year).apply(lambda x: (1 + x).prod() - 1)

print("\n=== 策略回测结果（Backtrader） ===")
print(f"回测期间: {start_date} 至 {end_date}")
print(f"初始资金: ${initial_capital:,.2f}")
print(f"最终价值: ${final_value:,.2f}")
print(f"累计收益率: {cumulative_return * 100:.2f}%")
print(f"{annual_ret_label}: {annual_return * 100:.2f}%")
print(f"年化波动率: {annual_vol * 100:.2f}%")
print(f"夏普比率: {sharpe:.2f}")
print(f"最大回撤: {max_drawdown * 100:.2f}%")
print(f"Ulcer Index(溃疡指数): {ulcer_index * 100:.2f}%")
print(f"CDaR 95%: {cdar_95 * 100:.2f}%")
print(f"Alpha: {alpha * 100:.2f}%")
print(f"Beta: {beta:.2f}")

print("\n年度收益:")
for year, ret in yearly_returns.items():
    print(f"{year}: {ret * 100:.2f}%")

# =============== Export rebalance / position logs ===============
# utf-8-sig adds a BOM so the Chinese labels open correctly in Excel.
rebalance_path = os.path.join(data_dir, "rebalance_log.csv")
rebalance_df = pd.DataFrame(strat.rebalance_records)
rebalance_df.to_csv(rebalance_path, index=False, encoding="utf-8-sig")
print(f"\n已导出调仓记录: {rebalance_path}")

pos_path = os.path.join(data_dir, "position_log.csv")
pos_df = pd.DataFrame(
    dict(
        Date=strat.dates,
        Pos_QQQ=strat.pos_qqq,
        Pos_SPY=strat.pos_spy,
        Pos_SHY=strat.pos_shy,
        Pos_GLD=strat.pos_gld,
    )
)
pos_df.to_csv(pos_path, index=False, encoding="utf-8-sig")
print(f"已导出每日持仓记录: {pos_path}")


正在下载数据...
YF.download() has changed argument auto_adjust default to True


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed

数据下载完成
已保存数据: ./data/2025-07-01_2025-12-17/SPY.csv
已保存数据: ./data/2025-07-01_2025-12-17/QQQ.csv
已保存数据: ./data/2025-07-01_2025-12-17/SHY.csv
已保存数据: ./data/2025-07-01_2025-12-17/GLD.csv
已保存数据: ./data/2025-07-01_2025-12-17/VIX.csv
已保存数据: ./data/2025-07-01_2025-12-17/DXY.csv
开始回测...

[2025-07-01] 初始建仓信号：进入策略A(100%纳斯达克)，操作日=下一交易日(建仓)
[2025-07-02] ORDER BUY QQQ size=17.00 price=546.07 comm=9.28 信号日=2025-07-01, 操作日=2025-07-02, from=0, to=1, reason=INIT_BUY

[2025-11-21] 生成切换信号(日)：1 → 2，操作日1=下一交易日(清仓旧策略)，操作日2=再下一交易日(建仓新策略)
--------------------------------------------------------------------------------
[2025-11-24] 操作日1-清仓：从 策略A(100%纳斯达克) 切换到 策略B(20%标普500+30%短期债券+50%黄金)，信号日=2025-11-21
    清仓 QQQ: 当前市值=10,287.72 → 目标=0
[2025-11-25] ORDER SELL QQQ size=-17.00 price=602.91 comm=10.25 信号日=2025-11-21, 操作日=2025-11-25, from=1, to=2, reason=SELL_STEP1
[2025-11-25] TRADE QQQ CLOSED gross=966.17, net=946.64
--------------------------------------------------------------------------------
[2025-11-25] 操作日2




In [2]:
import os
from pyecharts import options as opts
from pyecharts.charts import Line, Scatter, Pie, Grid

# Create the chart output directory; an existing directory is not an error.
os.makedirs(name="result", exist_ok=True)

def date_str_list(dates):
    return [d.strftime("%Y-%m-%d") for d in dates]

# X-axis category labels shared by every chart below: one "YYYY-MM-DD"
# string per entry of strat.dates.
x = date_str_list(strat.dates)

# Shared style: line width applied to every plotted series.
LINE_WIDTH = 2


# ================== 1. NAV curve (strategy + QQQ + SPY + switch markers) ==================
strategy_nav = strat.nav_list
qqq_px = strat.qqq_close_list
spy_px = strat.spy_close_list

# Buy-and-hold benchmarks: rescale each ETF close series so it starts at
# initial_capital, making it directly comparable with the strategy NAV.
if len(qqq_px) > 0 and len(spy_px) > 0:
    base_qqq = qqq_px[0]
    base_spy = spy_px[0]
    qqq_nav = [p / base_qqq * initial_capital for p in qqq_px]
    spy_nav = [p / base_spy * initial_capital for p in spy_px]
else:
    # No price history recorded: plot empty benchmark series of matching length.
    qqq_nav = [None] * len(x)
    spy_nav = [None] * len(x)

line_nav = (
    Line(init_opts=opts.InitOpts(width="1200px", height="700px"))
    .add_xaxis(x)
    .add_yaxis(
        "Strategy",
        strategy_nav,
        is_smooth=True,
        is_symbol_show=False,  # draw the line without point markers
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        "Nasdaq(QQQ)",
        qqq_nav,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        "S&P500(SPY)",
        spy_nav,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .set_global_opts(
        title_opts=opts.TitleOpts(
            title="NAV Curve (Strategy vs QQQ vs SPY)"
        ),
        xaxis_opts=opts.AxisOpts(
            type_="category",
            axislabel_opts=opts.LabelOpts(rotate=45)
        ),
        yaxis_opts=opts.AxisOpts(name="Net Asset Value"),
        tooltip_opts=opts.TooltipOpts(trigger="axis"),
        legend_opts=opts.LegendOpts(pos_top="5%"),
        datazoom_opts=[opts.DataZoomOpts(), opts.DataZoomOpts(type_="inside")],
    )
)

# Switch markers (triangles) are placed on the signal date ("SignalDate"),
# not on the execution date. One signal produces several rebalance records
# (e.g. a sell leg and a buy leg), so only the first record per signal date
# is plotted.
nav_by_date = dict(zip(strat.dates, strategy_nav))
ab_dates, ab_values = [], []
ba_dates, ba_values = [], []
seen_signals = set()

for r in strat.rebalance_records:
    signal_dt = r.get("SignalDate")
    if signal_dt is None or signal_dt in seen_signals:
        continue
    seen_signals.add(signal_dt)

    # NOTE(review): this lookup needs SignalDate to be the same type as the
    # entries of strat.dates (date vs datetime vs Timestamp); on a mismatch
    # the marker is silently skipped — confirm against the strategy class.
    nav_ = nav_by_date.get(signal_dt)
    if nav_ is None:
        continue

    d_str = signal_dt.strftime("%Y-%m-%d")

    # str() so a missing/None/non-string value cannot make the substring
    # test below raise TypeError; such records simply match neither branch.
    from_str = str(r.get("From", ""))
    to_str = str(r.get("To", ""))

    if "策略A" in from_str and "策略B" in to_str:
        # A -> B (switch to the defensive allocation)
        ab_dates.append(d_str)
        ab_values.append(nav_)
    elif "策略B" in from_str and "策略A" in to_str:
        # B -> A (switch back to the offensive allocation)
        ba_dates.append(d_str)
        ba_values.append(nav_)

scatter_ab = (
    Scatter()
    .add_xaxis(ab_dates)
    .add_yaxis(
        "A→B",
        ab_values,
        symbol="triangle",
        symbol_size=10,
        itemstyle_opts=opts.ItemStyleOpts(color="red"),
        label_opts=opts.LabelOpts(is_show=False),
    )
)

scatter_ba = (
    Scatter()
    .add_xaxis(ba_dates)
    .add_yaxis(
        "B→A",
        ba_values,
        symbol="triangle-down",
        symbol_size=10,
        itemstyle_opts=opts.ItemStyleOpts(color="green"),
        label_opts=opts.LabelOpts(is_show=False),
    )
)

nav_chart = line_nav.overlap(scatter_ab).overlap(scatter_ba)
# NOTE(review): the first output path looks like a leftover/typo duplicate of
# the second; kept so existing consumers of either file keep working.
nav_chart.render("result/-result-nav_curve_bt.html")
nav_chart.render("result/nav_curve_bt.html")

# ================== 2. Composite score + component scores ==================
comp = strat.composite_list
vix_sc = strat.vix_score_list
dxy_sc = strat.dxy_score_list
ma_sc = strat.ma_score_list
gold_sc = strat.gold_score_list

# Upper panel: composite score against the A->B / B->A switching thresholds.
# Its zoom controls are bound to both x axes (0 = upper, 1 = lower panel) so
# zooming keeps the two panels of the Grid on the same date range.
line_composite = (
    Line(init_opts=opts.InitOpts(width="1200px", height="900px"))
    .add_xaxis(x)
    .add_yaxis(
        "Composite Score",
        comp,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        f"A2B Threshold({SCORE_THRESHOLD_A2B})",
        [SCORE_THRESHOLD_A2B] * len(x),
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH, type_="dashed"),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        f"B2A Threshold({SCORE_THRESHOLD_B2A})",
        [SCORE_THRESHOLD_B2A] * len(x),
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH, type_="dashed"),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .set_global_opts(
        title_opts=opts.TitleOpts(title="Composite Score (VIX / DXY / SP500-MA / GOLD-MA)"),
        xaxis_opts=opts.AxisOpts(type_="category",
                                 axislabel_opts=opts.LabelOpts(rotate=45)),
        yaxis_opts=opts.AxisOpts(name="Score", min_=0, max_=1),
        tooltip_opts=opts.TooltipOpts(trigger="axis"),
        legend_opts=opts.LegendOpts(pos_top="3%"),
        datazoom_opts=[
            opts.DataZoomOpts(xaxis_index=[0, 1]),
            opts.DataZoomOpts(type_="inside", xaxis_index=[0, 1]),
        ],
    )
)

# Lower panel: the individual component scores. Its title and legend are
# moved to mid-height; at their default positions they would be drawn on top
# of the upper panel's title and legend. Zooming is driven by the upper panel.
line_components = (
    Line(init_opts=opts.InitOpts(width="1200px", height="900px"))
    .add_xaxis(x)
    .add_yaxis(
        "VIX Score",
        vix_sc,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        "DXY Score",
        dxy_sc,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        "SP500 MA Score",
        ma_sc,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        "GOLD MA Score",
        gold_sc,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .set_global_opts(
        title_opts=opts.TitleOpts(title="Component Scores", pos_top="50%"),
        xaxis_opts=opts.AxisOpts(type_="category",
                                 axislabel_opts=opts.LabelOpts(rotate=45)),
        yaxis_opts=opts.AxisOpts(name="Score", min_=0, max_=1),
        tooltip_opts=opts.TooltipOpts(trigger="axis"),
        legend_opts=opts.LegendOpts(pos_top="53%"),
    )
)

# Upper plot area ends 60% above the bottom, leaving room for its rotated date
# labels above the lower title. The lower plot area starts below the lower
# legend and stops 15% above the bottom for its labels and the zoom slider.
grid_scores = Grid(init_opts=opts.InitOpts(width="1200px", height="900px"))
grid_scores.add(line_composite, grid_opts=opts.GridOpts(pos_bottom="60%"))
grid_scores.add(line_components, grid_opts=opts.GridOpts(pos_top="58%", pos_bottom="15%"))
grid_scores.render("result/composite_scores_bt.html")


# ================== 3. VIX / DXY and their moving averages ==================
vix_vals = strat.vix_list
vix_ma5 = strat.vix_ma5_list
dxy_vals = strat.dxy_list
dxy_ma5 = strat.dxy_ma5_list

# Upper panel: VIX with its 5-day MA and the A->B / B->A VIX thresholds.
# Its zoom controls are bound to both x axes (0 = upper, 1 = lower panel) so
# zooming keeps the two panels of the Grid on the same date range.
line_vix = (
    Line(init_opts=opts.InitOpts(width="1200px", height="900px"))
    .add_xaxis(x)
    .add_yaxis(
        "VIX Index",
        vix_vals,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        "VIX 5-Day MA",
        vix_ma5,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        f"A2B VIX({VIX_THRESHOLD_A2B})",
        [VIX_THRESHOLD_A2B] * len(x),
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH, type_="dashed"),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        f"B2A VIX({VIX_THRESHOLD_B2A})",
        [VIX_THRESHOLD_B2A] * len(x),
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH, type_="dashed"),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .set_global_opts(
        title_opts=opts.TitleOpts(title="VIX Index and 5-Day Moving Average"),
        xaxis_opts=opts.AxisOpts(type_="category",
                                 axislabel_opts=opts.LabelOpts(rotate=45)),
        yaxis_opts=opts.AxisOpts(name="Index Value"),
        tooltip_opts=opts.TooltipOpts(trigger="axis"),
        legend_opts=opts.LegendOpts(pos_top="3%"),
        datazoom_opts=[
            opts.DataZoomOpts(xaxis_index=[0, 1]),
            opts.DataZoomOpts(type_="inside", xaxis_index=[0, 1]),
        ],
    )
)

# Lower panel: DXY with its 5-day MA and the DXY thresholds. Its title and
# legend are moved to mid-height; at their default positions they would be
# drawn on top of the upper panel's. Zooming is driven by the upper panel.
line_dxy = (
    Line(init_opts=opts.InitOpts(width="1200px", height="900px"))
    .add_xaxis(x)
    .add_yaxis(
        "Dollar Index(DXY)",
        dxy_vals,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        "DXY 5-Day MA",
        dxy_ma5,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        f"A2B DXY({DXY_THRESHOLD_A2B})",
        [DXY_THRESHOLD_A2B] * len(x),
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH, type_="dashed"),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .add_yaxis(
        f"B2A DXY({DXY_THRESHOLD_B2A})",
        [DXY_THRESHOLD_B2A] * len(x),
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH, type_="dashed"),
        label_opts=opts.LabelOpts(is_show=False),
    )
    .set_global_opts(
        title_opts=opts.TitleOpts(title="Dollar Index and 5-Day Moving Average", pos_top="50%"),
        xaxis_opts=opts.AxisOpts(type_="category",
                                 axislabel_opts=opts.LabelOpts(rotate=45)),
        yaxis_opts=opts.AxisOpts(name="Index Value"),
        tooltip_opts=opts.TooltipOpts(trigger="axis"),
        legend_opts=opts.LegendOpts(pos_top="53%"),
    )
)

# Same layout as the score Grid: the upper plot area ends 60% above the
# bottom, and the lower plot area sits between the lower legend and a 15%
# bottom margin reserved for its date labels and the zoom slider.
grid_vix_dxy = Grid(init_opts=opts.InitOpts(width="1200px", height="900px"))
grid_vix_dxy.add(line_vix, grid_opts=opts.GridOpts(pos_bottom="60%"))
grid_vix_dxy.add(line_dxy, grid_opts=opts.GridOpts(pos_top="58%", pos_bottom="15%"))
grid_vix_dxy.render("result/vix_dxy_bt.html")


# ================== 4. S&P500 price with short/long moving averages ==================
# (legend label, values) pairs; every series gets the same smooth, marker-free style.
sp500_series = (
    ("S&P500", strat.spy_close_list),
    (f"MA{MA_SHORT_PERIOD}", strat.sp500_ma20_list),
    (f"MA{MA_LONG_PERIOD}", strat.sp500_ma50_list),
)

line_sp500 = Line(init_opts=opts.InitOpts(width="1200px", height="700px"))
line_sp500.add_xaxis(x)
for series_label, series_values in sp500_series:
    line_sp500.add_yaxis(
        series_label,
        series_values,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
line_sp500.set_global_opts(
    title_opts=opts.TitleOpts(title="S&P500 Price and Moving Averages"),
    xaxis_opts=opts.AxisOpts(type_="category",
                             axislabel_opts=opts.LabelOpts(rotate=45)),
    yaxis_opts=opts.AxisOpts(name="Price"),
    tooltip_opts=opts.TooltipOpts(trigger="axis"),
    legend_opts=opts.LegendOpts(pos_top="5%"),
    datazoom_opts=[opts.DataZoomOpts(), opts.DataZoomOpts(type_="inside")],
)
line_sp500.render("result/sp500_ma_bt.html")


# ================== 5. Gold (GLD) price with moving averages ==================
# (legend label, values) pairs; every series gets the same smooth, marker-free style.
gold_series = (
    ("Gold (GLD)", strat.gold_close_list),
    ("Gold MA20", strat.gold_ma20_list),
    ("Gold MA50", strat.gold_ma50_list),
)

line_gold = Line(init_opts=opts.InitOpts(width="1200px", height="700px"))
line_gold.add_xaxis(x)
for series_label, series_values in gold_series:
    line_gold.add_yaxis(
        series_label,
        series_values,
        is_smooth=True,
        is_symbol_show=False,
        linestyle_opts=opts.LineStyleOpts(width=LINE_WIDTH),
        label_opts=opts.LabelOpts(is_show=False),
    )
line_gold.set_global_opts(
    title_opts=opts.TitleOpts(title="Gold (GLD) Price and Moving Averages"),
    xaxis_opts=opts.AxisOpts(type_="category",
                             axislabel_opts=opts.LabelOpts(rotate=45)),
    yaxis_opts=opts.AxisOpts(name="Price"),
    tooltip_opts=opts.TooltipOpts(trigger="axis"),
    legend_opts=opts.LegendOpts(pos_top="5%"),
    datazoom_opts=[opts.DataZoomOpts(), opts.DataZoomOpts(type_="inside")],
)
line_gold.render("result/gold_ma_bt.html")


# ================== 6. Holding distribution pie chart ==================
from collections import Counter

# Number of bars spent under each raw signal code (insertion order = first
# occurrence, same as a plain dict).
signal_counts = Counter(strat.signal_list)

# Aggregate by display label rather than by raw code: 1 -> Strategy A,
# 2 -> Strategy B, anything else -> Cash. Counting per code would emit one
# separate "Cash" slice for every distinct non-1/2 code (e.g. 0 and None).
label_counts = Counter()
for s, cnt in signal_counts.items():
    if s == 1:
        label = "Strategy A"
    elif s == 2:
        label = "Strategy B"
    else:
        label = "Cash"
    label_counts[label] += cnt

labels = list(label_counts)
values = list(label_counts.values())

pie_hold = (
    Pie(init_opts=opts.InitOpts(width="700px", height="700px"))
    .add("", [list(z) for z in zip(labels, values)])
    .set_global_opts(
        title_opts=opts.TitleOpts(title="Strategy Holding Distribution"),
        legend_opts=opts.LegendOpts(
            orient="vertical",
            pos_top="15%", pos_left="2%"
        ),
    )
    # "{b}: {d}%" -> slice name and its percentage of the total.
    .set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {d}%"))
)
pie_hold.render("result/holding_distribution_bt.html")


'/home/luany/develop/201_nsdq_avoid/result/holding_distribution_bt.html'