In [1]:
import backtrader as bt
import backtrader.analyzers as btanalyzers
import pandas as pd
import yfinance as yf
import akshare as ak
from datetime import datetime, timedelta
import os
import numpy as np
import warnings
import time

# Silence library warnings for the whole run.
warnings.filterwarnings('ignore')

# Asset universe, US-listed ETFs: ticker -> label.
intl_etfs = dict(
    QQQ='纳斯达克ETF',
    SPY='标普500ETF',
)

# Asset universe, China-listed ETFs: ticker -> label.
china_etfs = dict([
    ('159980', '有色金属etf'),
    ('513010', '恒生科技'),
    ('159892', '恒生医药'),
    ('159934', '黄金etf'),
    ('159985', '豆粕etf'),
    ('510880', '红利ETF'),
    ('516780', '稀土'),
])

# Ticker -> display name used in log output and report columns.
# The China entries come first, followed by the international ETFs,
# whose display names are the same as their labels in intl_etfs.
asset_name_map = {
    '159980': '有色金属ETF',
    '513010': '恒生科技ETF',
    '159892': '恒生医药ETF',
    '159934': '黄金ETF',
    '159985': '豆粕ETF',
    '510880': '红利ETF',
    '516780': '稀土ETF',
    **intl_etfs,
}

class QuarterlyContrarianRotation(bt.Strategy):
    """Quarterly contrarian rotation across all attached data feeds.

    On the first bar seen in each rebalance month, every feed is ranked by
    its return over the last ``quarter_trading_days`` bars. The
    ``selection_count`` worst performers are each targeted at an equal share
    of the portfolio value; positions in all other feeds are closed.

    Params:
        selection_count: number of worst performers to hold.
        rebalance_months: calendar months in which a rebalance may happen.
        quarter_trading_days: look-back window (in bars) for the ranking.

    After the run, ``trade_df``, ``annual_return_df`` and
    ``quarterly_return_df`` are available as pandas DataFrames.
    """

    params = (
        ('selection_count', 4),
        ('rebalance_months', [1, 4, 7, 10]),
        ('quarter_trading_days', 63),  # roughly one quarter of trading days
    )

    def __init__(self):
        self.assets = self.datas
        self.asset_names = [d._name for d in self.datas]
        print(f"策略初始化完成，可交易资产: {[asset_name_map.get(name, name) for name in self.asset_names]}")

        self.trade_log = []           # one dict per submitted order
        self.quarterly_returns = {}   # rebalance date -> {ticker: return %}
        self.current_year = None
        self.year_start_value = None  # portfolio value at first bar of the year
        self.annual_returns = {}      # year -> year-to-date return %
        self.last_rebalance = None    # date of the most recent rebalance attempt

    def next(self):
        """Update year-to-date return and rebalance once per rebalance month."""
        dt = self.datas[0].datetime.date(0)

        # First bar overall, or first bar of a new calendar year: reset the
        # baseline used for the year-to-date return.
        if self.current_year is None or dt.year != self.current_year:
            self.current_year = dt.year
            self.year_start_value = self.broker.getvalue()
            self.annual_returns[self.current_year] = 0

        self.annual_returns[self.current_year] = (self.broker.getvalue() / self.year_start_value - 1) * 100

        if dt.month not in self.params.rebalance_months:
            return

        # Need a full look-back window (plus a small margin) on the first feed.
        if len(self.data0) < self.params.quarter_trading_days + 5:
            return

        # Rebalance only on the first bar of the month. Compare year AND
        # month so that a configuration with a single rebalance month still
        # triggers every year instead of only once.
        last = self.last_rebalance
        if last is None or (last.year, last.month) != (dt.year, dt.month):
            self.rebalance_portfolio(dt)
            self.last_rebalance = dt

    def rebalance_portfolio(self, dt):
        """Rank feeds by trailing return and rotate into the worst performers.

        Args:
            dt: date of the current bar; used for logging and as the key in
                ``quarterly_returns``.
        """
        print(f"\n=== {dt} 季度调仓 ===")

        # Trailing return of each feed over a fixed number of bars.
        quarter_returns = {}
        lookback = self.params.quarter_trading_days

        for data in self.assets:
            asset_name = data._name
            display_name = asset_name_map.get(asset_name, asset_name)

            # Skip feeds that do not yet have a full look-back window.
            if len(data) < lookback + 1:
                continue

            try:
                start_price = data.close[-lookback]
                end_price = data.close[0]

                if start_price > 0 and end_price > 0:
                    return_pct = (end_price - start_price) / start_price * 100
                    quarter_returns[asset_name] = return_pct
                    # Actual start date of the window, printed for debugging.
                    start_date = data.datetime.date(-lookback)
                    print(f"{display_name}: {return_pct:.2f}% (日期: {start_date} -> {dt})")
                else:
                    print(f"{display_name} 价格无效: start={start_price}, end={end_price}")
            except Exception as e:
                # Best effort: a feed that cannot be read is reported and
                # left out of the ranking.
                print(f"{display_name} 数据访问错误: {str(e)}")

        if len(quarter_returns) < self.params.selection_count:
            print(f"{dt} 有效资产不足({len(quarter_returns)}), 需要{self.params.selection_count}个, 跳过调仓")
            return

        # Ascending sort: the worst performers come first.
        sorted_assets = sorted(quarter_returns.items(), key=lambda x: x[1])
        selected_assets = [x[0] for x in sorted_assets[:self.params.selection_count]]
        selected_names = [asset_name_map.get(a, a) for a in selected_assets]
        print(f"选中资产: {selected_names}")

        self.quarterly_returns[dt] = quarter_returns

        total_value = self.broker.getvalue()
        # Equal weight across the selected assets (25% each with the default
        # selection_count of 4).
        target_per_asset = total_value / self.params.selection_count

        print(f"总资产: {total_value:.2f}, 目标每资产: {target_per_asset:.2f}")

        # Plan all orders first, then submit sells before buys so that the
        # cash released by the sells is available when the buys are checked.
        # Each entry: (data, display_name, action, size, verb-for-printing).
        selected = set(selected_assets)
        sell_orders = []
        buy_orders = []

        for data in self.assets:
            asset_name = data._name
            display_name = asset_name_map.get(asset_name, asset_name)
            price = data.close[0]
            current_pos = self.getposition(data).size
            held = current_pos > 0

            if asset_name not in selected:
                if held:
                    sell_orders.append((data, display_name, 'CLOSE', current_pos, '清仓'))
                continue

            if not held:
                buy_orders.append((data, display_name, 'BUY', target_per_asset / price, '买入'))
                continue

            # Already held and still selected: trade only the difference,
            # ignoring deviations of at most one currency unit.
            diff = target_per_asset - current_pos * price
            if abs(diff) <= 1:
                continue
            size = abs(diff) / price
            if diff > 0:
                buy_orders.append((data, display_name, 'BUY', size, '增持'))
            else:
                sell_orders.append((data, display_name, 'SELL', size, '减持'))

        for data, display_name, action, size, verb in sell_orders + buy_orders:
            if action == 'BUY':
                self.buy(data, size=size)
            elif action == 'SELL':
                self.sell(data, size=size)
            else:
                self.close(data)
            self.log_trade(dt, data._name, display_name, action, size, data.close[0])
            print(f"{verb} {display_name}: {size:.2f} 股")

    def log_trade(self, date, asset, display_name, action, size, price):
        """Append one order record to ``trade_log``.

        ``asset`` (the raw ticker) is accepted for call-site symmetry but
        only ``display_name`` is stored. ``price`` is the close at the time
        the order is submitted, not the execution price.
        """
        self.trade_log.append({
            'Date': date,
            'Asset': display_name,
            'Action': action,
            'Size': size,
            'Price': price,
            'Value': size * price
        })

    def stop(self):
        """Log the final liquidation and build the result DataFrames."""
        dt = self.datas[0].datetime.date(0)
        print(f"\n结束日期 {dt} 处理持仓")

        # NOTE(review): orders submitted here are presumably never executed
        # because the run is already over, so the 'CLOSE' rows logged below
        # describe a notional liquidation at the last close -- confirm.
        for data in self.assets:
            pos = self.getposition(data).size
            if pos > 0:
                try:
                    display_name = asset_name_map.get(data._name, data._name)
                    self.close(data)
                    self.log_trade(dt, data._name, display_name, 'CLOSE', pos, data.close[0])
                    print(f"清仓 {display_name}: {pos:.2f} 股")
                except Exception as e:
                    print(f"清仓 {data._name} 失败: {str(e)}")

        # Trade log as a DataFrame.
        self.trade_df = pd.DataFrame(self.trade_log)

        # Year-to-date returns, one row per calendar year seen.
        years = sorted(set(int(y) for y in self.annual_returns.keys()))
        annual_returns = [self.annual_returns.get(y, 0) for y in years]
        self.annual_return_df = pd.DataFrame({
            'Year': years,
            'Return(%)': annual_returns
        })

        # Trailing returns recorded at each rebalance, one column per asset
        # display name.
        quarter_returns = []
        for date, returns in self.quarterly_returns.items():
            row = {'Date': date}
            for asset, ret in returns.items():
                display_name = asset_name_map.get(asset, asset)
                row[display_name] = ret
            quarter_returns.append(row)
        self.quarterly_return_df = pd.DataFrame(quarter_returns)

def download_intl_etf(ticker, start_date, end_date):
    """Download daily OHLCV bars for an international ETF via yfinance.

    Args:
        ticker: symbol passed to ``yf.download``.
        start_date: first date to keep, as ``'YYYY-MM-DD'``.
        end_date: last date to keep, as ``'YYYY-MM-DD'``.

    Returns:
        A DataFrame indexed by ``Date`` with columns ``open``, ``high``,
        ``low``, ``close``, ``volume`` restricted to the requested range, or
        ``None`` when the download fails or yields no usable rows.
    """
    print(f"正在下载国际ETF {ticker} 数据 ({start_date} 至 {end_date})...")
    try:
        # Request one extra week before start_date so forward-filling has
        # earlier rows to draw from; the buffer is sliced away below.
        adjusted_start = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=7)).strftime('%Y-%m-%d')
        data = yf.download(ticker, start=adjusted_start, end=end_date, progress=False)

        if data is None or data.empty:
            print(f"{ticker} 下载数据为空")
            return None

        # yfinance may return (field, ticker) MultiIndex columns even for a
        # single ticker; keep only the field level so the columns are flat.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)

        # Normalise column names.
        data = data.rename(columns={
            'Open': 'open',
            'High': 'high',
            'Low': 'low',
            'Close': 'close',
            'Volume': 'volume'
        })

        # All price columns are mandatory.
        required_cols = ['open', 'high', 'low', 'close']
        if not all(col in data.columns for col in required_cols):
            print(f"{ticker} 缺少必要列")
            return None

        # Volume is optional; default it to zero when absent.
        if 'volume' not in data.columns:
            data['volume'] = 0

        # Select and order the columns, then forward-fill gaps and drop rows
        # that are still incomplete (i.e. leading rows with nothing to fill
        # from). DataFrame.ffill replaces the deprecated fillna(method=...).
        data = data[['open', 'high', 'low', 'close', 'volume']]
        data = data.ffill().dropna()

        # Ensure a datetime index named 'Date' regardless of its source name.
        data.index = pd.to_datetime(data.index)
        data.index.name = 'Date'

        # Restrict to the requested range.
        start_dt = pd.to_datetime(start_date)
        end_dt = pd.to_datetime(end_date)
        data = data.loc[start_dt:end_dt]

        # Report an empty range explicitly instead of failing on index[0].
        if data.empty:
            print(f"{ticker} 在指定时间范围内无数据")
            return None

        print(f"{ticker} 数据下载成功，{len(data)} 条记录，时间范围: {data.index[0].date()} 至 {data.index[-1].date()}")
        return data
    except Exception as e:
        # Top-level boundary for this download: report and signal failure.
        print(f"下载国际ETF {ticker} 失败: {str(e)}")
        return None

def download_china_etf(ticker, start_date, end_date, retry=3):
    """Download daily OHLCV bars for a China-listed ETF via akshare.

    Args:
        ticker: fund code passed to ``ak.fund_etf_hist_em``.
        start_date: first date to keep, as ``'YYYY-MM-DD'``.
        end_date: last date to keep, as ``'YYYY-MM-DD'``.
        retry: maximum number of attempts; the wait between attempts grows
            by 5 seconds each time.

    Returns:
        A DataFrame indexed by ``Date`` with columns ``open``, ``high``,
        ``low``, ``close``, ``volume`` restricted to the requested range, or
        ``None`` when no usable data could be obtained.
    """
    print(f"正在下载中国ETF {ticker} 数据 ({start_date} 至 {end_date})...")
    for attempt in range(retry):
        try:
            # Request a one-week buffer on both sides; the result is sliced
            # back to the requested range below.
            adjusted_start = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=7)).strftime('%Y%m%d')
            adjusted_end = (datetime.strptime(end_date, '%Y-%m-%d') + timedelta(days=7)).strftime('%Y%m%d')

            df = ak.fund_etf_hist_em(symbol=ticker, period="daily",
                                    start_date=adjusted_start,
                                    end_date=adjusted_end)

            if df.empty:
                print(f"{ticker} 下载数据为空")
                return None

            # Normalise the Chinese column headers.
            df = df.rename(columns={
                '日期': 'date',
                '开盘': 'open',
                '收盘': 'close',
                '最高': 'high',
                '最低': 'low',
                '成交量': 'volume'
            })

            # Datetime index named 'Date'.
            df['date'] = pd.to_datetime(df['date'])
            df.set_index('date', inplace=True)
            df.index.name = 'Date'

            # Restrict to the requested range. Copy so the column
            # assignments below act on an independent frame, not on a view
            # of the downloaded data.
            start_dt = pd.to_datetime(start_date)
            end_dt = pd.to_datetime(end_date)
            df = df.loc[start_dt:end_dt].copy()

            if df.empty:
                print(f"{ticker} 在指定时间范围内无数据")
                return None

            # Coerce to numeric; unparsable values become NaN.
            for col in ['open', 'high', 'low', 'close']:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            df['volume'] = pd.to_numeric(df['volume'], errors='coerce')

            # Rows without complete prices are unusable.
            df = df.dropna(subset=['open', 'high', 'low', 'close'])

            # Nothing left after cleaning: retrying would not help, so stop
            # here rather than failing on index[0] and sleeping.
            if df.empty:
                print(f"{ticker} 在指定时间范围内无数据")
                return None

            # Forward-fill what remains (only volume can still be missing).
            # DataFrame.ffill replaces the deprecated fillna(method=...).
            df = df.ffill()

            # Select and order the columns.
            df = df[['open', 'high', 'low', 'close', 'volume']]

            print(f"{ticker} 数据下载成功，{len(df)} 条记录，时间范围: {df.index[0].date()} 至 {df.index[-1].date()}")
            return df
        except Exception as e:
            if attempt < retry - 1:
                wait_time = (attempt + 1) * 5
                print(f"下载 {ticker} 失败，{wait_time}秒后重试... ({str(e)})")
                time.sleep(wait_time)
            else:
                print(f"下载中国ETF {ticker} 失败: {str(e)}")
                return None

    # Reached only when retry <= 0 and no attempt was made.
    return None

def run_backtest():
    """Download all ETF data, run the rotation backtest and report results.

    Prints a summary (annualised return, Sharpe ratio, maximum drawdown and
    excess return over SPY), writes ``trade_details.csv``,
    ``annual_returns.csv`` and ``quarterly_returns.csv`` to the working
    directory, and finally attempts to plot the run. Returns ``None``.
    """
    # Backtest window and starting capital.
    start_date = '2021-07-01'
    end_date = '2025-08-10'
    initial_cash = 10000

    # International ETFs.
    intl_data = {}
    for ticker in intl_etfs:
        data = download_intl_etf(ticker, start_date, end_date)
        if data is not None:
            intl_data[ticker] = data

    # China ETFs.
    china_data = {}
    for ticker in china_etfs:
        # 159892 is requested from a later start date than the other funds.
        # NOTE(review): presumably because it listed after start_date --
        # confirm.
        if ticker == '159892':
            adjusted_start_date = '2021-10-01'
        else:
            adjusted_start_date = start_date

        data = download_china_etf(ticker, adjusted_start_date, end_date)
        if data is not None:
            china_data[ticker] = data

    # Merge both universes.
    data_dict = {**intl_data, **china_data}

    if not data_dict:
        print("\n错误: 没有成功下载数据")
        return

    print("\n=== 初始化回测引擎 ===")
    cerebro = bt.Cerebro()
    cerebro.broker.set_cash(initial_cash)
    cerebro.broker.setcommission(commission=0.001)

    print("\n=== 添加数据 ===")
    for ticker, data in data_dict.items():
        try:
            # Force flat, canonical column names.
            data.columns = ['open', 'high', 'low', 'close', 'volume']

            # Skip feeds that are too short.
            if len(data) < 100:
                print(f"{ticker} 数据长度不足 ({len(data)}), 跳过")
                continue

            data_feed = bt.feeds.PandasData(
                dataname=data,
                fromdate=datetime.strptime(start_date, '%Y-%m-%d'),
                todate=datetime.strptime(end_date, '%Y-%m-%d'),
                datetime=None,
                open=0,
                high=1,
                low=2,
                close=3,
                volume=4,
                openinterest=-1
            )
            cerebro.adddata(data_feed, name=ticker)
            print(f"成功添加 {ticker} 数据 ({len(data)} 条记录)")
        except Exception as e:
            print(f"添加 {ticker} 数据失败: {str(e)}")

    if not cerebro.datas:
        print("\n错误: 没有有效数据")
        return

    print("\n=== 添加策略和分析器 ===")
    cerebro.addstrategy(QuarterlyContrarianRotation)
    cerebro.addanalyzer(btanalyzers.SharpeRatio, _name='sharpe', riskfreerate=0.0, annualize=True)
    cerebro.addanalyzer(btanalyzers.DrawDown, _name='drawdown')
    cerebro.addanalyzer(btanalyzers.TimeReturn, _name='time_return', timeframe=bt.TimeFrame.Years)

    print('\n=== 开始回测 ===')
    print(f'初始资金: {cerebro.broker.getvalue():.2f}')

    try:
        results = cerebro.run()
    except Exception as e:
        print(f"\n回测错误: {str(e)}")
        return

    if not results:
        print("\n回测未产生结果")
        return

    strat = results[0]
    final_value = cerebro.broker.getvalue()
    print(f'\n最终资金: {final_value:.2f}')

    # Annualised return over the calendar length of the requested window.
    start_dt = datetime.strptime(start_date, '%Y-%m-%d')
    end_dt = datetime.strptime(end_date, '%Y-%m-%d')
    total_days = (end_dt - start_dt).days
    years = total_days / 365.25
    total_return = (final_value / initial_cash) - 1
    annualized_return = (1 + total_return) ** (1 / years) - 1 if years > 0 else 0

    # "Alpha" here is the annualised return minus SPY's annualised
    # buy-and-hold return over the downloaded range.
    if 'SPY' in data_dict:
        spy_start = data_dict['SPY'].iloc[0]['close']
        spy_end = data_dict['SPY'].iloc[-1]['close']
        spy_return = (spy_end - spy_start) / spy_start if spy_start > 0 else 0
        spy_annualized = (1 + spy_return) ** (1 / years) - 1 if years > 0 else 0
        alpha = annualized_return - spy_annualized
    else:
        alpha = 0

    # Analyzer results. The Sharpe analyzer may store None when it cannot
    # compute a ratio; that case is reported as N/A below instead of
    # crashing on the float format.
    sharpe_ratio = strat.analyzers.sharpe.get_analysis().get('sharperatio', 0)
    drawdown = strat.analyzers.drawdown.get_analysis()
    max_drawdown = drawdown.get('max', {}).get('drawdown', 0) / 100

    print("\n=== 回测结果 ===")
    print(f"总天数: {total_days}")
    print(f"年化收益率: {annualized_return * 100:.2f}%")
    if sharpe_ratio is None:
        print("夏普比率: N/A")
    else:
        print(f"夏普比率: {sharpe_ratio:.2f}")
    print(f"最大回撤: {max_drawdown * 100:.2f}%")
    print(f"Alpha: {alpha * 100:.2f}%")

    print("\n=== 交易明细 ===")
    if hasattr(strat, 'trade_df'):
        print(strat.trade_df.to_string())
        strat.trade_df.to_csv('trade_details.csv', index=False)
        print("交易明细已保存到 trade_details.csv")
    else:
        print("无交易记录")

    print("\n=== 年度回报 ===")
    if hasattr(strat, 'annual_return_df'):
        print(strat.annual_return_df.to_string())
        strat.annual_return_df.to_csv('annual_returns.csv', index=False)
        print("年度回报已保存到 annual_returns.csv")
    else:
        print("无年度回报数据")

    print("\n=== 季度资产回报 ===")
    if hasattr(strat, 'quarterly_return_df'):
        print(strat.quarterly_return_df.to_string())
        strat.quarterly_return_df.to_csv('quarterly_returns.csv', index=False)
        print("季度回报已保存到 quarterly_returns.csv")
    else:
        print("无季度回报数据")

    print("\n生成回测图表...")
    try:
        cerebro.plot(style='line', iplot=False)
    except Exception as e:
        # Plotting is optional (e.g. no display backend); report and go on.
        print(f"绘图失败: {str(e)}")

# Run the backtest only when this file is executed as a script.
if __name__ == '__main__':
    run_backtest()

正在下载国际ETF QQQ 数据 (2021-07-01 至 2025-08-10)...
YF.download() has changed argument auto_adjust default to True
QQQ 数据下载成功，1031 条记录，时间范围: 2021-07-01 至 2025-08-08
正在下载国际ETF SPY 数据 (2021-07-01 至 2025-08-10)...
SPY 数据下载成功，1031 条记录，时间范围: 2021-07-01 至 2025-08-08
正在下载中国ETF 159980 数据 (2021-07-01 至 2025-08-10)...


  0%|          | 0/12 [00:00<?, ?it/s]

159980 数据下载成功，997 条记录，时间范围: 2021-07-01 至 2025-08-08
正在下载中国ETF 513010 数据 (2021-07-01 至 2025-08-10)...
513010 数据下载成功，997 条记录，时间范围: 2021-07-01 至 2025-08-08
正在下载中国ETF 159892 数据 (2021-10-01 至 2025-08-10)...
159892 数据下载成功，926 条记录，时间范围: 2021-10-19 至 2025-08-08
正在下载中国ETF 159934 数据 (2021-07-01 至 2025-08-10)...
159934 数据下载成功，997 条记录，时间范围: 2021-07-01 至 2025-08-08
正在下载中国ETF 159985 数据 (2021-07-01 至 2025-08-10)...
159985 数据下载成功，997 条记录，时间范围: 2021-07-01 至 2025-08-08
正在下载中国ETF 510880 数据 (2021-07-01 至 2025-08-10)...
510880 数据下载成功，997 条记录，时间范围: 2021-07-01 至 2025-08-08
正在下载中国ETF 516780 数据 (2021-07-01 至 2025-08-10)...
516780 数据下载成功，997 条记录，时间范围: 2021-07-01 至 2025-08-08

=== 初始化回测引擎 ===

=== 添加数据 ===
成功添加 QQQ 数据 (1031 条记录)
成功添加 SPY 数据 (1031 条记录)
成功添加 159980 数据 (997 条记录)
成功添加 513010 数据 (997 条记录)
成功添加 159892 数据 (926 条记录)
成功添加 159934 数据 (997 条记录)
成功添加 159985 数据 (997 条记录)
成功添加 510880 数据 (997 条记录)
成功添加 516780 数据 (997 条记录)

=== 添加策略和分析器 ===

=== 开始回测 ===
初始资金: 10000.00
策略初始化完成，可交易资产: ['纳斯达克ETF', '标普500ETF', '有色金