In [None]:
import pandas as pd
import numpy as np
import akshare as ak  # 使用akshare替代tushare
import datetime
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def calculate_sma(data, window):
    """Return the simple moving average of *data* over *window* periods.

    Uses pandas ``rolling`` with its default ``min_periods`` (equal to
    *window*), so the first ``window - 1`` entries are NaN.
    """
    rolling_view = data.rolling(window=window)
    return rolling_view.mean()

def calculate_ema(data, window):
    """Return the exponential moving average of *data*.

    Smoothing factor is ``2 / (window + 1)`` (pandas ``span``), computed
    recursively (``adjust=False``) starting from the first observation.
    """
    smoother = data.ewm(span=window, adjust=False)
    return smoother.mean()

def calculate_rsi(data, window=14):
    """Return the Relative Strength Index (0-100) of *data*.

    Average gains and losses are smoothed with an EMA of span *window*
    (``adjust=False``), not Wilder's ``1/window`` smoothing, so values differ
    slightly from TA-Lib's RSI. The first element's missing diff counts as a
    zero move.
    """
    def _smooth(series):
        return series.ewm(span=window, adjust=False).mean()

    step = data.diff()
    avg_up = _smooth(step.where(step > 0, 0))
    avg_down = _smooth(-step.where(step < 0, 0))
    relative_strength = avg_up / avg_down
    return 100 - (100 / (1 + relative_strength))

def calculate_atr(high, low, close, window=14):
    """Return the Average True Range.

    True range is the largest of ``high - low``, ``|high - prev_close|`` and
    ``|low - prev_close|`` (NaN terms on the first bar are skipped), smoothed
    with an EMA of span *window* (``adjust=False``).
    """
    prev_close = close.shift()
    candidates = pd.concat(
        [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
        axis=1,
    )
    true_range = candidates.max(axis=1)
    return true_range.ewm(span=window, adjust=False).mean()

def calculate_adx(high, low, close, window=14):
    """
    Compute a simplified ADX (Average Directional Index) plus +DI / -DI.

    Simplified relative to Wilder's definition: every smoothing step uses an
    EMA of span *window* (``adjust=False``) instead of Wilder's ``1/window``
    smoothing. Use TA-Lib for the canonical values.

    Args:
        high, low, close: Aligned pandas Series of bar prices.
        window: Smoothing span.

    Returns:
        Tuple ``(adx, plus_di, minus_di)`` of Series aligned to the inputs,
        with DI/ADX expressed in percent. Values are NaN where a denominator
        is zero (e.g. zero ATR on the first bar, or +DI + -DI == 0).
    """
    # True Range
    prev_close = close.shift(1)
    true_range = pd.concat(
        [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
        axis=1,
    ).max(axis=1)

    # Directional Movement. Both conditions are evaluated on the *raw* moves so
    # that a tie (up_move == down_move) credits neither side, per Wilder.
    up_move = high - high.shift(1)
    down_move = low.shift(1) - low
    plus_dm = up_move.where((up_move > down_move) & (up_move > 0), 0)
    minus_dm = down_move.where((down_move > up_move) & (down_move > 0), 0)

    # Smoothed True Range and Directional Movement
    atr = true_range.ewm(span=window, adjust=False).mean()
    plus_di = (plus_dm.ewm(span=window, adjust=False).mean() / atr) * 100
    minus_di = (minus_dm.ewm(span=window, adjust=False).mean() / atr) * 100

    # DX and ADX
    dx = (plus_di - minus_di).abs() / (plus_di + minus_di) * 100
    adx = dx.ewm(span=window, adjust=False).mean()
    return adx, plus_di, minus_di

def standardize_series(series):
    """Return the Z-score of *series*: ``(x - mean) / std``.

    Uses pandas' sample standard deviation (``ddof=1``).
    """
    centre = series.mean()
    spread = series.std()
    return (series - centre) / spread

In [None]:
def get_stock_data(ts_code, start_date, end_date):
    """
    Download forward-adjusted (qfq) daily bars for one A-share via akshare.

    Args:
        ts_code: Tushare-style code such as ``'000001.SZ'``; only the part
            before the dot is passed to akshare.
        start_date, end_date: ``'YYYY-MM-DD'`` strings; dashes are stripped
            for the akshare request.

    Returns:
        DataFrame indexed by ``trade_date`` (ascending) with columns
        open/high/low/close/vol/amount, or None when no rows come back or any
        error occurs (the error is printed, not raised).
    """
    chinese_to_english = {
        '日期': 'trade_date',
        '开盘': 'open',
        '最高': 'high',
        '最低': 'low',
        '收盘': 'close',
        '成交量': 'vol',
        '成交额': 'amount'
    }
    wanted_columns = ['open', 'high', 'low', 'close', 'vol', 'amount']

    try:
        symbol = ts_code.split('.')[0]
        raw = ak.stock_zh_a_hist(
            symbol=symbol,
            period="daily",
            start_date=start_date.replace('-', ''),
            end_date=end_date.replace('-', ''),
            adjust="qfq",
        )
        if raw.empty:
            print(f"警告: 未获取到股票 {ts_code} 的数据")
            return None

        bars = raw.rename(columns=chinese_to_english)
        bars['trade_date'] = pd.to_datetime(bars['trade_date'])
        bars = bars.set_index('trade_date').sort_index()
        return bars[wanted_columns]

    except Exception as e:
        print(f"获取股票 {ts_code} 数据时出错: {e}")
        return None

def get_index_data(index_code, start_date, end_date):
    """
    Download daily closes for a stock index via akshare.

    Args:
        index_code: Code with exchange suffix, e.g. ``'000300.SH'``. A ``.SH``
            suffix (any case) maps to akshare's ``sh`` prefix; anything else
            maps to ``sz``.
        start_date, end_date: Inclusive ``'YYYY-MM-DD'`` bounds. The akshare
            call takes no date range, so the range is applied locally.

    Returns:
        DataFrame indexed by ``trade_date`` (ascending) with a single
        ``close`` column, or None when nothing is available in the range or an
        error occurs (the error is printed, not raised).
    """
    try:
        code_only = index_code.split('.')[0]
        prefix = 'sh' if index_code.upper().endswith('.SH') else 'sz'

        df = ak.stock_zh_index_daily(symbol=f"{prefix}{code_only}")
        if df is None or df.empty:
            print(f"警告: 未获取到指数 {index_code} 的数据")
            return None

        df = df.rename(columns={'date': 'trade_date'})
        df['trade_date'] = pd.to_datetime(df['trade_date'])
        df = df.set_index('trade_date').sort_index()

        df = df.loc[start_date:end_date]
        if df.empty:
            # Callers intersect this index with stock dates; an empty frame
            # would silently wipe out every trading day, so report "no data".
            print(f"警告: 未获取到指数 {index_code} 的数据")
            return None

        return df[['close']]

    except Exception as e:
        print(f"获取指数 {index_code} 数据时出错: {e}")
        return None

In [None]:
class AdaptiveMultiFactorStrategy:
    """Regime-adaptive multi-factor A-share backtester.

    Daily loop (see ``run_backtest``):
      1. classify the market regime from CSI 300 ADX / DI;
      2. apply ATR-based stop-loss / take-profit to open positions;
      3. on the first trading day of each rebalance period, score the universe
         with regime-dependent weights on cross-sectionally standardized
         factors and hold the top ``num_holdings`` names.

    Args:
        stock_codes: Stock codes such as ``'000001.SZ'``.
        start_date, end_date: ``'YYYY-MM-DD'`` strings bounding the download.
        initial_cash: Starting cash.
        random_seed: Seed for the placeholder quality/value factors. ``None``
            (default) keeps them non-deterministic; pass an int for
            reproducible backtests.
    """

    # Factors combined by _calculate_composite_score.
    FACTOR_NAMES = ('momentum', 'reversion', 'quality', 'value')

    def __init__(self, stock_codes, start_date, end_date, initial_cash=1_000_000, random_seed=None):
        self.stock_codes = stock_codes
        self.start_date = start_date
        self.end_date = end_date
        self.initial_cash = initial_cash
        self.cash = initial_cash
        self.positions = {code: 0 for code in stock_codes}  # {ts_code: shares}
        self.entry_prices = {}  # {ts_code: average entry price of the open position}
        self.portfolio_value_history = []
        self.trade_log = []
        self.holdings_value = {code: 0 for code in stock_codes}  # reserved; not updated by the backtest
        self.stock_data = {}
        self.index_data = None
        self.all_dates = pd.DatetimeIndex([])
        self.market_regime_history = []
        self.current_market_regime = 'UNKNOWN'

        # RNG for placeholder factors, plus lazily filled indicator caches.
        self._rng = np.random.default_rng(random_seed)
        self._index_adx = None
        self._atr_cache = {}
        self._rsi_cache = {}

        # Strategy parameters
        self.adx_window = 14
        self.adx_trend_threshold = 25  # ADX above this value counts as a trending market
        self.rsi_window = 14
        self.momentum_window = 60  # momentum look-back, in trading bars
        self.rebalance_freq = 'W'  # 'D' daily, 'W' first trading day of each week, 'M' first trading day of each month
        self.num_holdings = 5  # number of stocks to hold
        self.stop_loss_atr_multiplier = 2.0  # stop-loss distance in ATRs
        self.take_profit_atr_multiplier = 3.0  # take-profit distance in ATRs
        self.trailing_stop_atr_multiplier = 1.0  # NOTE: trailing stop is not implemented; currently unused

        # Factor weights per market regime. 'reversal' is not currently produced by
        # _calculate_market_regime; unknown regimes fall back to 'sideways'.
        self.weights = {
            'trend': {'momentum': 0.5, 'quality': 0.3, 'value': 0.2, 'reversion': 0},
            'sideways': {'reversion': 0.5, 'quality': 0.3, 'value': 0.2, 'momentum': 0},
            'reversal': {'reversion': 0.6, 'quality': 0.2, 'value': 0.2, 'momentum': 0}
        }

        # Trading costs (simplified)
        self.commission_rate = 0.0003  # 3 bp per side
        self.stamp_duty_rate = 0.001  # 10 bp, sells only
        self.slippage_rate = 0.0001  # 1 bp

    # ------------------------------------------------------------------ data

    def _reset_state(self):
        """Clear per-run state so run_backtest can be called more than once."""
        self.cash = self.initial_cash
        self.entry_prices = {}
        self.portfolio_value_history = []
        self.trade_log = []
        self.market_regime_history = []
        self.current_market_regime = 'UNKNOWN'
        self.stock_data = {}

    def _load_data(self):
        """Load stock and index data and align everything on common trading dates.

        Stocks that fail to load are dropped from ``stock_codes``. Missing or
        empty index data leaves ``index_data`` as None (regime 'UNKNOWN')
        instead of emptying the date range.

        Raises:
            RuntimeError: if no stock data could be loaded.
        """
        print("Loading stock and index data using akshare...")

        valid_stocks = []
        for code in self.stock_codes:
            df = get_stock_data(code, self.start_date, self.end_date)
            if df is not None and not df.empty:
                self.stock_data[code] = df
                valid_stocks.append(code)
                print(f"Loaded data for {code}")
            else:
                print(f"Failed to load data for {code}")

        if not valid_stocks:
            raise RuntimeError("没有成功加载任何股票数据")

        self.stock_codes = valid_stocks  # keep only stocks that actually loaded
        self.positions = {code: 0 for code in valid_stocks}
        self.holdings_value = {code: 0 for code in valid_stocks}
        self.entry_prices = {}

        # Dates on which every stock traded.
        all_dates = self.stock_data[valid_stocks[0]].index
        for code in valid_stocks[1:]:
            all_dates = all_dates.intersection(self.stock_data[code].index)

        self.index_data = get_index_data('000300.SH', self.start_date, self.end_date)
        if self.index_data is not None and not self.index_data.empty:
            all_dates = all_dates.intersection(self.index_data.index)
        else:
            self.index_data = None

        self.all_dates = all_dates.sort_values()

        for code in self.stock_codes:
            self.stock_data[code] = self.stock_data[code].loc[self.all_dates]
        if self.index_data is not None:
            self.index_data = self.index_data.loc[self.all_dates]

        # Data changed: drop any indicators cached from a previous run.
        self._index_adx = None
        self._atr_cache = {}
        self._rsi_cache = {}

        print(f"数据加载完成，共{len(self.all_dates)}个交易日，{len(self.stock_codes)}只股票")

    # ------------------------------------------------------------ indicators
    # All indicator helpers are causal (shift / rolling / ewm with adjust=False),
    # so the value at date t computed once over the full history equals the value
    # computed on data truncated at t. Caching them turns the per-day O(n)
    # recomputation into a single O(n) pass.

    def _index_indicators(self):
        """Return cached index ADX / +DI / -DI as a DataFrame, or None without index data."""
        if self._index_adx is None and self.index_data is not None:
            index_close = self.index_data['close']
            # The index frame only carries close, so it stands in for high and low.
            adx, plus_di, minus_di = calculate_adx(index_close, index_close, index_close, self.adx_window)
            self._index_adx = pd.DataFrame({'adx': adx, 'plus_di': plus_di, 'minus_di': minus_di})
        return self._index_adx

    def _atr_series(self, code):
        """Return the cached ATR series for `code` over its aligned history."""
        if code not in self._atr_cache:
            df = self.stock_data[code]
            self._atr_cache[code] = calculate_atr(df['high'], df['low'], df['close'], self.adx_window)
        return self._atr_cache[code]

    def _rsi_series(self, code):
        """Return the cached RSI series for `code` over its aligned history."""
        if code not in self._rsi_cache:
            self._rsi_cache[code] = calculate_rsi(self.stock_data[code]['close'], self.rsi_window)
        return self._rsi_cache[code]

    # --------------------------------------------------------------- signals

    def _calculate_market_regime(self, current_date):
        """
        Classify the market regime on `current_date` from index ADX / DI.

        Returns 'UNKNOWN' without index data or with fewer than
        ``2 * adx_window`` bars; 'trend' when ADX exceeds the threshold and
        +DI differs from -DI (up- and down-trends share the 'trend' weights);
        otherwise 'sideways'. Volatility is not yet considered.
        """
        if self.index_data is None or len(self.index_data.loc[:current_date]) < self.adx_window * 2:
            return 'UNKNOWN'

        latest = self._index_indicators().loc[:current_date].iloc[-1]
        directional = (latest['plus_di'] > latest['minus_di']) or (latest['minus_di'] > latest['plus_di'])
        if latest['adx'] > self.adx_trend_threshold and directional:
            return 'trend'
        return 'sideways'  # Default to sideways if not a strong trend (NaN ADX included)

    def _calculate_factors(self, code, current_date):
        """Compute raw (unstandardized) factors for one stock as of `current_date`.

        Returns None when there are not enough bars. Factors:
          - momentum: stock return over exactly ``momentum_window`` bars minus the
            index return over the same bars (NaN without index data);
          - reversion: ``100 - RSI`` (higher = more oversold);
          - quality / value: random placeholders until fundamental data is wired in.
        """
        closes = self.stock_data[code]['close'].loc[:current_date]
        if len(closes) < max(self.momentum_window, self.rsi_window, self.adx_window) + 1:
            return None  # Not enough data

        # Position -(window + 1) lies exactly `momentum_window` bars before the last bar.
        lookback = self.momentum_window + 1
        momentum = closes.iloc[-1] / closes.iloc[-lookback] - 1

        # Index return over the same trading bars (all data share one aligned date index),
        # instead of a calendar-day offset that may land on a non-trading day.
        index_return = np.nan
        if self.index_data is not None:
            index_closes = self.index_data['close'].loc[:current_date]
            if len(index_closes) >= lookback and index_closes.index[-1] == current_date:
                index_return = index_closes.iloc[-1] / index_closes.iloc[-lookback] - 1
        relative_momentum = momentum - index_return  # NaN if either side is NaN

        rsi = self._rsi_series(code).loc[:current_date].iloc[-1]
        reversion = 100 - rsi  # stays NaN when RSI is NaN

        return {
            'momentum': relative_momentum,
            'reversion': reversion,
            'quality': self._rng.random(),  # placeholder, needs fundamental data
            'value': self._rng.random(),  # placeholder, needs fundamental data
        }

    def _standardize_factors(self, raw_factors):
        """Z-score each factor across the stocks scored on one date.

        Raw factors live on very different scales (momentum ~ +/-0.x,
        reversion 0..100, placeholders 0..1); without standardization the
        regime weights would be dominated by whichever factor has the largest
        units. NaN inputs and factors with zero cross-sectional spread (or a
        single stock) map to 0, i.e. a neutral contribution.

        Args:
            raw_factors: ``{code: {factor_name: value}}``.

        Returns:
            Dict with the same shape holding z-scores.
        """
        if not raw_factors:
            return {}
        frame = pd.DataFrame.from_dict(raw_factors, orient='index').astype(float)
        spread = frame.std()
        zscores = ((frame - frame.mean()) / spread.where(spread > 0)).fillna(0.0)
        return {code: row.to_dict() for code, row in zscores.iterrows()}

    def _calculate_composite_score(self, factors, market_regime):
        """Weighted sum of factors using the weights for `market_regime`.

        Unknown regimes use the 'sideways' weights; NaN factors contribute
        nothing. Returns ``-inf`` when `factors` is None.
        """
        if factors is None:
            return -np.inf  # Return a very low score for invalid data

        weights = self.weights.get(market_regime, self.weights['sideways'])
        score = 0
        for name in self.FACTOR_NAMES:
            value = factors.get(name, np.nan)
            if not np.isnan(value):
                score += weights.get(name, 0) * value
        return score

    # --------------------------------------------------------------- trading

    def _execute_trade(self, date, code, trade_type, shares, price):
        """Fill a 'BUY' or 'SELL' order at `price` and log it.

        Buys pay commission + slippage; sells additionally pay stamp duty.
        Orders that cannot be filled (non-positive size, insufficient cash,
        selling more than held) are skipped without logging.

        Returns:
            bool: True if the order was filled.
        """
        if shares <= 0:
            return False

        gross = shares * price
        commission = gross * self.commission_rate
        slippage = gross * self.slippage_rate  # simulated slippage

        if trade_type == 'BUY':
            total_cost = gross + commission + slippage
            if self.cash < total_cost:
                return False
            held = self.positions.get(code, 0)
            self.cash -= total_cost
            self.positions[code] = held + shares
            # Average entry price of the open position, used by the stop logic.
            prior_cost = self.entry_prices.get(code, 0.0) * held
            self.entry_prices[code] = (prior_cost + gross) / self.positions[code]
            self.trade_log.append({
                'date': date, 'code': code, 'type': 'BUY',
                'shares': shares, 'price': price, 'cost': total_cost,
                'cash_left': self.cash
            })
            return True

        if trade_type == 'SELL':
            if self.positions.get(code, 0) < shares:
                return False
            stamp_duty = gross * self.stamp_duty_rate
            total_revenue = gross - commission - slippage - stamp_duty
            self.cash += total_revenue
            self.positions[code] -= shares
            if self.positions[code] == 0:
                self.entry_prices.pop(code, None)
            self.trade_log.append({
                'date': date, 'code': code, 'type': 'SELL',
                'shares': shares, 'price': price, 'revenue': total_revenue,
                'cash_left': self.cash
            })
            return True

        return False

    def _lot_shares(self, budget, price, lot_size=100):
        """Largest whole-lot share count whose cost *including* buy fees fits in `budget`."""
        if not (price > 0):  # also rejects NaN prices
            return 0
        cost_per_lot = price * lot_size * (1 + self.commission_rate + self.slippage_rate)
        return int(budget // cost_per_lot) * lot_size

    def _get_current_portfolio_value(self, date):
        """Cash plus the close-price value of all open positions on `date`."""
        current_value = self.cash
        for code, shares in self.positions.items():
            if shares > 0 and date in self.stock_data[code].index:
                current_value += shares * self.stock_data[code].loc[date, 'close']
        return current_value

    def _apply_stop_loss_take_profit(self, date):
        """
        Close positions whose close breaches ATR-based stop-loss / take-profit levels.

        Levels are ``entry -/+ multiplier * ATR(date)``, where ``entry`` is the
        average entry price of the currently open position (``entry_prices``),
        not an older, already-closed trade. Positions with fewer than
        ``adx_window`` bars of history, or a NaN ATR, are left untouched.
        """
        for code, shares in list(self.positions.items()):  # copy: selling mutates positions
            if shares <= 0 or date not in self.stock_data[code].index:
                continue
            entry_price = self.entry_prices.get(code)
            if entry_price is None:
                continue
            history = self.stock_data[code].loc[:date]
            if len(history) < self.adx_window:  # Need enough data for ATR
                continue

            current_close = history['close'].iloc[-1]
            current_atr = self._atr_series(code).loc[:date].iloc[-1]

            stop_loss_price = entry_price - self.stop_loss_atr_multiplier * current_atr
            take_profit_price = entry_price + self.take_profit_atr_multiplier * current_atr
            if current_close < stop_loss_price or current_close > take_profit_price:
                self._execute_trade(date, code, 'SELL', shares, current_close)

    def _is_rebalance_day(self, date, prev_date):
        """True on the first trading day of each rebalance period.

        'D' rebalances every day; 'W' / 'M' on the first trading day of each ISO
        week / calendar month (so holidays on Mondays or the 1st do not skip a
        period); the first backtest day counts as a new period. Any other
        frequency never rebalances.
        """
        freq = self.rebalance_freq
        if freq == 'D':
            return True
        if freq not in ('W', 'M'):
            return False
        if prev_date is None:
            return True
        if freq == 'W':
            return tuple(date.isocalendar())[:2] != tuple(prev_date.isocalendar())[:2]
        return (date.year, date.month) != (prev_date.year, prev_date.month)

    def _rebalance(self, date):
        """Rank the universe, sell names that dropped out, and buy new top-N names.

        Available cash is split equally across the names actually being bought;
        positions already held keep their current size.
        """
        raw_factors = {}
        for code in self.stock_codes:
            factors = self._calculate_factors(code, date)
            if factors:
                raw_factors[code] = factors
        if not raw_factors:
            return

        standardized = self._standardize_factors(raw_factors)
        scores = {
            code: self._calculate_composite_score(factors, self.current_market_regime)
            for code, factors in standardized.items()
        }
        valid_scores = {k: v for k, v in scores.items() if v != -np.inf}
        if not valid_scores:
            return

        ranked = sorted(valid_scores.items(), key=lambda item: item[1], reverse=True)
        top_n_stocks = [code for code, _ in ranked[:self.num_holdings]]

        for code, shares in list(self.positions.items()):
            if shares > 0 and code not in top_n_stocks:
                self._execute_trade(date, code, 'SELL', shares, self.stock_data[code].loc[date, 'close'])

        to_buy = [code for code in top_n_stocks if self.positions[code] == 0]
        if not to_buy:
            return
        budget = self.cash / len(to_buy)
        for code in to_buy:
            price = self.stock_data[code].loc[date, 'close']
            shares_to_buy = self._lot_shares(budget, price)  # whole lots of 100, fees included
            if shares_to_buy > 0:
                self._execute_trade(date, code, 'BUY', shares_to_buy, price)

    def run_backtest(self):
        """Load data, run the daily loop and return the equity curve.

        Returns:
            DataFrame indexed by date with a 'value' column; empty if no common
            trading dates are available.
        """
        self._reset_state()
        self._load_data()

        if len(self.all_dates) == 0:
            print("错误: 没有可用的交易日数据")
            return pd.DataFrame()

        prev_date = None
        for date in self.all_dates:
            # 1. Update market regime
            self.current_market_regime = self._calculate_market_regime(date)
            self.market_regime_history.append({'date': date, 'regime': self.current_market_regime})

            # 2. Stop-loss / take-profit (checked daily)
            self._apply_stop_loss_take_profit(date)

            # 3. Rebalance on the first trading day of each period
            if self._is_rebalance_day(date, prev_date):
                self._rebalance(date)

            # Record daily total equity
            self.portfolio_value_history.append({
                'date': date,
                'value': self._get_current_portfolio_value(date)
            })
            prev_date = date

        return pd.DataFrame(self.portfolio_value_history).set_index('date')

In [None]:
def evaluate_strategy(portfolio_df, initial_cash, benchmark_df,
                      trading_days_per_year=252, risk_free_rate=0.02):
    """Compute performance metrics for a backtest equity curve versus a benchmark.

    Args:
        portfolio_df: DataFrame indexed by date with a ``'value'`` column,
            assumed to hold one row per trading day.
        initial_cash: Starting capital; total return is measured against it.
        benchmark_df: DataFrame with a ``'close'`` column, or None/empty, in
            which case a flat (zero-return) benchmark is used.
        trading_days_per_year: Periods per year used both to compound the
            annualized return and to scale volatility.
        risk_free_rate: Annual risk-free rate used in the Sharpe ratio.

    Returns:
        ``(metrics, drawdown)``: a dict of formatted metric strings and the
        strategy drawdown Series (fraction below the running peak). Returns
        ``({}, empty Series)`` for an empty ``portfolio_df``. Input frames are
        not modified.
    """
    if portfolio_df is None or portfolio_df.empty:
        return {}, pd.Series(dtype=float)

    values = portfolio_df['value']
    if benchmark_df is not None and not benchmark_df.empty:
        bench_close = benchmark_df['close']
    else:
        bench_close = pd.Series(float(initial_cash), index=portfolio_df.index)

    def _annualize(total, n_rows):
        # Compound over the number of return periods so that the exponent
        # uses the same time unit (trading days) as the volatility scaling.
        n_periods = n_rows - 1
        if n_periods <= 0:
            return 0
        return (1 + total) ** (trading_days_per_year / n_periods) - 1

    def _drawdown(series):
        peak = series.expanding(min_periods=1).max()
        return (series - peak) / peak

    # Total return
    total_return = (values.iloc[-1] / initial_cash) - 1
    benchmark_total_return = (bench_close.iloc[-1] / bench_close.iloc[0]) - 1

    # Annualized return
    annualized_return = _annualize(total_return, len(values))
    benchmark_annualized_return = _annualize(benchmark_total_return, len(bench_close))

    # Annualized volatility
    year_scale = np.sqrt(trading_days_per_year)
    volatility = values.pct_change().std() * year_scale
    benchmark_volatility = bench_close.pct_change().std() * year_scale

    # Maximum drawdown
    drawdown = _drawdown(values)
    max_drawdown = drawdown.min()
    benchmark_max_drawdown = _drawdown(bench_close).min()

    # Sharpe ratio
    sharpe_ratio = (annualized_return - risk_free_rate) / volatility if volatility != 0 else np.nan
    benchmark_sharpe_ratio = (
        (benchmark_annualized_return - risk_free_rate) / benchmark_volatility
        if benchmark_volatility != 0 else np.nan
    )

    metrics = {
        'Total Return': f"{total_return:.2%}",
        'Annualized Return': f"{annualized_return:.2%}",
        'Volatility (Annualized)': f"{volatility:.2%}",
        'Max Drawdown': f"{max_drawdown:.2%}",
        'Sharpe Ratio': f"{sharpe_ratio:.2f}",
        'Benchmark Total Return': f"{benchmark_total_return:.2%}",
        'Benchmark Annualized Return': f"{benchmark_annualized_return:.2%}",
        'Benchmark Volatility (Annualized)': f"{benchmark_volatility:.2%}",
        'Benchmark Max Drawdown': f"{benchmark_max_drawdown:.2%}",
        'Benchmark Sharpe Ratio': f"{benchmark_sharpe_ratio:.2f}",
    }
    return metrics, drawdown

In [None]:
if __name__ == "__main__":
    # Universe of real A-share codes (Tushare-style exchange suffixes).
    stock_codes = ['000001.SZ', '600000.SH', '600036.SH', '000002.SZ', '601318.SH',
                   '600519.SH', '000858.SZ', '000333.SZ', '002415.SZ', '300750.SZ']

    start_date = '2020-01-01'
    end_date = '2023-12-31'
    initial_cash = 1_000_000  # 1 million initial capital

    strategy = AdaptiveMultiFactorStrategy(stock_codes, start_date, end_date, initial_cash)
    portfolio_value_df = strategy.run_backtest()

    if portfolio_value_df.empty:
        print("回测失败，没有生成有效数据")
    else:
        # Reuse the CSI 300 data the backtest already downloaded; only re-fetch
        # when the backtest ran without it.
        benchmark_data = strategy.index_data
        if benchmark_data is None:
            benchmark_data = get_index_data('000300.SH', start_date, end_date)

        # Align to the strategy dates without raising on missing days:
        # gaps are forward-filled and any leading gap is dropped.
        if benchmark_data is not None and not benchmark_data.empty:
            benchmark_data = benchmark_data.reindex(portfolio_value_df.index).ffill().dropna()
        has_benchmark = benchmark_data is not None and not benchmark_data.empty
        if not has_benchmark:
            benchmark_data = None  # evaluate_strategy substitutes a flat benchmark

        # Evaluate the strategy
        metrics, drawdown_history = evaluate_strategy(portfolio_value_df, initial_cash, benchmark_data)

        print("\n--- 策略评估结果 ---")
        for key, value in metrics.items():
            print(f"{key}: {value}")

        # Equity curve and drawdown
        plt.figure(figsize=(14, 8))

        # Equity curve
        plt.subplot(2, 1, 1)
        plt.plot(portfolio_value_df.index, portfolio_value_df['value'], label='Strategy Portfolio Value', color='blue')
        if has_benchmark:
            # Rescale the index so it starts at the strategy's initial capital.
            benchmark_normalized = benchmark_data['close'] / benchmark_data['close'].iloc[0] * initial_cash
            plt.plot(benchmark_normalized.index, benchmark_normalized, label='Benchmark (HS300) Normalized', color='orange', linestyle='--')
        plt.title('Strategy vs. Benchmark Portfolio Value')
        plt.xlabel('Date')
        plt.ylabel('Portfolio Value')
        plt.legend()
        plt.grid(True)

        # Drawdown, converted from fractions to percent to match the axis label.
        plt.subplot(2, 1, 2)
        plt.fill_between(drawdown_history.index, drawdown_history * 100, 0, color='red', alpha=0.3, label='Strategy Drawdown')
        if has_benchmark:
            benchmark_peak = benchmark_data['close'].expanding(min_periods=1).max()
            benchmark_drawdown = (benchmark_data['close'] - benchmark_peak) / benchmark_peak
            plt.fill_between(benchmark_data.index, benchmark_drawdown * 100, 0, color='grey', alpha=0.3, label='Benchmark Drawdown')
        plt.title('Portfolio Drawdown')
        plt.xlabel('Date')
        plt.ylabel('Drawdown (%)')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

        print("\n--- 市场状态历史 ---")
        market_regime_df = pd.DataFrame(strategy.market_regime_history).set_index('date')
        print(market_regime_df['regime'].value_counts())

        # Market regime over time, one categorical code per regime label.
        plt.figure(figsize=(14, 4))
        market_regime_df['regime_code'] = market_regime_df['regime'].astype('category').cat.codes
        plt.plot(market_regime_df.index, market_regime_df['regime_code'], marker='.', linestyle='None', alpha=0.5)
        plt.yticks(ticks=market_regime_df['regime_code'].unique(), labels=market_regime_df['regime'].unique())
        plt.title('Market Regime Over Time')
        plt.xlabel('Date')
        plt.ylabel('Market Regime')
        plt.grid(True)
        plt.tight_layout()
        plt.show()