In [16]:
import pandas as pd
import backtrader as bt
import numpy as np
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

# -------------------------- 1. 数据预处理 --------------------------
def preprocess_data(data_path):
    # 读取数据（支持parquet和csv格式）
    if data_path.endswith('.parquet'):
        df = pd.read_parquet(data_path, engine="pyarrow")
    else:
        df = pd.read_csv(data_path)
    
    # 必要字段检查
    needed_cols = [
        "date", "stock_code", "open", "close", "volume", "pre_close",
        "high_limit", "low_limit", "paused"
    ]
    missing_cols = [col for col in needed_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"数据缺少必要字段：{missing_cols}")
    
    # 日期处理
    df["date"] = pd.to_datetime(df["date"])
    df = df[(df["date"] >= "2023-01-01") & (df["date"] <= "2025-10-27")]
    
    # 股票代码统一格式
    df["stock_code"] = df["stock_code"].astype(str).str.strip()
    valid_stocks = df["stock_code"].unique()  # 取前50只股票
    df = df[df["stock_code"].isin(valid_stocks)]
    
    # 数据清洗
    df = df[
        (df["volume"] > 100) &
        (df["open"] > 0.01) &
        (df["close"] > 0.01)
    ].reset_index(drop=True)
    
    # 填充停牌字段
    df["paused"] = df["paused"].fillna(1.0).astype("float32")
    
    # 计算选股指标
    df["price_change"] = (df["close"] - df["pre_close"]) / df["pre_close"] * 100
    df["is_limit_up"] = (df["close"] >= df["high_limit"] * 0.99) & (df["price_change"] >= 9.0)
    df["vol_3d_avg"] = df.groupby("stock_code")["volume"].rolling(window=3, min_periods=1).mean().reset_index(level=0, drop=True)
    df["volume_ratio"] = df["volume"] / df["vol_3d_avg"].clip(lower=1)
    
    # 预计算每日符合条件股票
    daily_eligible = {}
    market_days = sorted(df["date"].unique())
    for date in market_days:
        date_df = df[df["date"] == date]
        if len(date_df) < 2:
            daily_eligible[date] = []
            continue
        eligible = date_df[
            (date_df["volume_ratio"] > 0.8) &
            ((date_df["is_limit_up"]) | (date_df["price_change"] > 1))
        ].sort_values("price_change", ascending=False)["stock_code"].tolist()
        daily_eligible[date] = eligible[:4]
    
    # 生成Backtrader数据源
    class CustomData(bt.feeds.PandasData):
        lines = ("paused", "price_change", "volume_ratio")
        params = (
            ("paused", "paused"),
            ("price_change", "price_change"),
            ("volume_ratio", "volume_ratio"),
            ("openinterest", -1)
        )
    
    data_feeds = {}
    for code in tqdm(valid_stocks, desc="生成数据源"):
        stock_df = df[df["stock_code"] == code].sort_values("date").set_index("date")
        stock_df = stock_df.reindex(market_days).fillna({
            "open": 0, "close": 0, "volume": 0, "paused": 1.0,
            "price_change": 0, "volume_ratio": 0
        })
        data_feeds[code] = CustomData(dataname=stock_df)
    
    return data_feeds, daily_eligible, market_days, df

# -------------------------- 2. 策略定义（修复params格式错误） --------------------------
class SimpleYangJiaStrategy(bt.Strategy):
    # 核心修复：params必须是二元组的元组
    params = (
        ("daily_eligible", None),
        ("debug_df", None)  # 正确格式：(参数名, 默认值)
    )
    
    def __init__(self):
        self.trade_log = []
        self.hold_days = {}
        self.pbar = tqdm(total=len(self.datas[0]), desc="回测进度")
        self.daily_eligible_count = []
    
    def next(self):
        self.pbar.update(1)
        current_date = self.data.datetime.date(0)
        current_cash = self.broker.getcash()
        
        # 卖出逻辑
        for code in list(self.hold_days.keys()):
            data = self.getdatabyname(code)
            if not data:
                del self.hold_days[code]
                continue
            if self.getposition(data).size > 0 and self.hold_days[code] >= 2 and data.close[0] > 0.01:
                self.sell(data=data, price=data.close[0], size=self.getposition(data).size)
                self.trade_log.append(f"{current_date} 卖出 {code} | 价格：{data.close[0]:.2f} | 仓位：{self.getposition(data).size}")
                del self.hold_days[code]
        
        # 更新持仓天数
        for code in self.hold_days.keys():
            self.hold_days[code] += 1
        
        # 买入逻辑
        if len(self.positions) == 0 and current_cash > 1000:
            eligible_codes = self.p.daily_eligible.get(current_date, [])
            self.daily_eligible_count.append(f"{current_date} 符合条件股票数：{len(eligible_codes)}")
            
            if not eligible_codes:
                return
            
            valid_codes = []
            for code in eligible_codes:
                data = self.getdatabyname(code)
                if not data:
                    continue
                debug_info = f"{current_date} {code} | paused：{data.paused[0]} | open：{data.open[0]}"
                self.trade_log.append(debug_info)
                
                if data.paused[0] == 0.0 and data.open[0] > 0.01:
                    valid_codes.append(code)
            
            self.trade_log.append(f"{current_date} 有效可买股票数：{len(valid_codes)}")
            
            if len(valid_codes) == 0:
                return
            
            buy_cash_per_stock = current_cash / len(valid_codes)
            for code in valid_codes:
                data = self.getdatabyname(code)
                if buy_cash_per_stock < data.open[0] * 100:
                    continue
                size = int(buy_cash_per_stock // (data.open[0] * 100) * 100)
                if size <= 0:
                    continue
                self.buy(data=data, price=data.open[0], size=size)
                self.hold_days[code] = 1
                self.trade_log.append(f"{current_date} 买入 {code} | 价格：{data.open[0]:.2f} | 数量：{size}")
    
    def stop(self):
        self.pbar.close()
        print("\n前20条调试日志：")
        for log in self.trade_log[:20]:
            print(log)

# -------------------------- 3. 回测运行 --------------------------
def run_backtest(data_feeds, daily_eligible, debug_df):
    cerebro = bt.Cerebro(stdstats=True, runonce=False, exactbars=1)
    cerebro.broker.setcash(1000000.0)
    cerebro.broker.setcommission(commission=0.0003)
    cerebro.broker.set_slippage_perc(perc=0.0005)
    
    for code, data in data_feeds.items():
        cerebro.adddata(data, name=code)
    
    cerebro.addstrategy(SimpleYangJiaStrategy, daily_eligible=daily_eligible, debug_df=debug_df)
    
    cerebro.addanalyzer(bt.analyzers.Returns, _name="returns", tann=252)
    cerebro.addanalyzer(bt.analyzers.DrawDown, _name="drawdown")
    cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name="trades")
    
    print(f"回测开始 | 初始资金：{cerebro.broker.getvalue():.2f}元")
    results = cerebro.run()
    final_cash = cerebro.broker.getvalue()
    print(f"回测结束 | 最终资金：{final_cash:.2f}元")
    return results, final_cash

# -------------------------- 4. 结果分析 --------------------------
def analyze_results(results, final_cash):
    if not results:
        print("回测未生成结果")
        return
    strat = results[0]
    returns = strat.analyzers.returns.get_analysis()
    drawdown = strat.analyzers.drawdown.get_analysis()
    trades = strat.analyzers.trades.get_analysis()
    
    total_return = returns.get("rtot", 0) * 100 if "rtot" in returns else 0
    annual_return = returns.get("rnorm100", 0) if "rnorm100" in returns else 0
    max_dd = drawdown.get("max", {}).get("drawdown", 0) if "max" in drawdown else 0
    total_trades = trades.get("total", {}).get("total", 0) if "total" in trades else 0
    win_trades = trades.get("won", {}).get("total", 0) if "won" in trades else 0
    
    print("\n" + "="*60)
    print("回测核心指标")
    print("="*60)
    print(f"初始资金：1000000.00元")
    print(f"最终资金：{final_cash:.2f}元")
    print(f"累计收益率：{total_return:.2f}%")
    print(f"年化收益率：{annual_return:.2f}%")
    print(f"最大回撤：{max_dd:.2f}%")
    print(f"总交易次数：{total_trades}")
    if total_trades > 0:
        print(f"胜率：{win_trades/total_trades*100:.2f}%")

# -------------------------- 5. 主函数 --------------------------
if __name__ == "__main__":
    # 请替换为你的数据路径
    DATA_PATH = "D:\\workspace\\xiaoyao\\data\\widetable.parquet"  # 或 "widetable.csv"
    
    print("开始数据预处理...")
    data_feeds, daily_eligible, market_days, debug_df = preprocess_data(DATA_PATH)
    print(f"数据预处理完成 | 有效股票数：{len(data_feeds)} | 交易日数：{len(market_days)}")
    
    if len(data_feeds) == 0:
        print("无有效数据源，终止回测")
    else:
        print("\n开始回测...")
        results, final_cash = run_backtest(data_feeds, daily_eligible, debug_df)
        print("\n开始结果分析...")
        analyze_results(results, final_cash)

开始数据预处理...


生成数据源: 100%|██████████| 5285/5285 [23:02<00:00,  3.82it/s]


数据预处理完成 | 有效股票数：5285 | 交易日数：680

开始回测...
回测开始 | 初始资金：1000000.00元


回测进度: 680it [30:06,  2.66s/it]



前20条调试日志：
回测结束 | 最终资金：1000000.00元

开始结果分析...

回测核心指标
初始资金：1000000.00元
最终资金：1000000.00元
累计收益率：0.00%
年化收益率：0.00%
最大回撤：0.00%
总交易次数：0
