In [None]:
# market_state_analyzer.py (新文件)
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler

class MarketStateAnalyzer:
    """Classify the market regime as "Bull", "Bear" or "Sideways".

    Technical features are derived from the ``close`` and ``volume`` columns
    of a price history, rule-based labels are derived from the same history,
    and a random forest is fitted to map features to labels.
    """

    # Rolling-window lengths (in rows) used by the feature extractor.
    SHORT_WINDOW = 20
    MID_WINDOW = 60
    LONG_WINDOW = 120

    def __init__(self, lookback_period=120):
        """Create an untrained analyzer.

        Args:
            lookback_period: Stored on the instance.
                NOTE(review): no method of this class reads it; the feature
                windows are fixed by the ``*_WINDOW`` class constants.
        """
        self.lookback_period = lookback_period
        self.model = RandomForestClassifier(n_estimators=100, random_state=42)
        self.scaler = StandardScaler()
        self.market_states = ["Bull", "Bear", "Sideways"]

    def extract_market_features(self, market_data):
        """Build trend, volatility, volume and momentum features.

        Args:
            market_data: DataFrame with ``close`` and ``volume`` columns.

        Returns:
            DataFrame indexed like ``market_data`` but restricted to rows where
            every feature is finite. Rows inside the warm-up period of the
            longest window, and rows where a ratio divided by zero, are
            dropped.
        """
        close = market_data['close']
        volume = market_data['volume']

        short_mean = close.rolling(self.SHORT_WINDOW).mean()
        mid_mean = close.rolling(self.MID_WINDOW).mean()
        long_mean = close.rolling(self.LONG_WINDOW).mean()

        features = pd.DataFrame(index=close.index)

        # Short-, mid- and long-term trend: moving average relative to price.
        features['short_trend'] = short_mean / close - 1
        features['mid_trend'] = mid_mean / close - 1
        features['long_trend'] = long_mean / close - 1

        # Volatility: rolling standard deviation scaled by the rolling mean.
        features['volatility_20d'] = close.rolling(self.SHORT_WINDOW).std() / short_mean
        features['volatility_60d'] = close.rolling(self.MID_WINDOW).std() / mid_mean

        # Volume: change over the short window and short/mid average ratio.
        features['volume_change'] = volume.pct_change(self.SHORT_WINDOW)
        features['volume_trend'] = (
            volume.rolling(self.SHORT_WINDOW).mean()
            / volume.rolling(self.MID_WINDOW).mean()
            - 1
        )

        # Momentum: price change over the short and mid windows.
        features['momentum_20d'] = close.pct_change(self.SHORT_WINDOW)
        features['momentum_60d'] = close.pct_change(self.MID_WINDOW)

        # Other macro indicators (interest rates, growth data, ...) could be
        # added here.

        # A zero denominator yields +/-inf, which dropna() would keep and the
        # scaler/model would later reject; treat those rows as missing.
        features = features.replace([np.inf, -np.inf], np.nan)
        return features.dropna()

    def label_market_states(self, market_data, window=60):
        """Assign a rule-based market state to every row of the history.

        Rules, using the ``window``-row return and the rolling coefficient of
        variation of ``close``:

        * "Bull": return above 10% and volatility below 0.15.
        * "Bear": return below -10% or volatility above 0.3.
        * "Sideways": everything else.

        Args:
            market_data: DataFrame with a ``close`` column.
            window: Number of rows used for the return and the volatility.

        Returns:
            Object-dtype Series indexed like ``market_data``. Rows where the
            return or the volatility is undefined (not enough history) are
            left as NaN rather than being labelled.
        """
        close = market_data['close']
        returns = close.pct_change(window)
        rolling_close = close.rolling(window)
        volatility = rolling_close.std() / rolling_close.mean()

        bull_market = (returns > 0.1) & (volatility < 0.15)
        bear_market = (returns < -0.1) | (volatility > 0.3)

        # Comparisons against NaN are False, so without this mask the warm-up
        # rows would silently fall through to "Sideways".
        known = returns.notna() & volatility.notna()

        market_state = pd.Series(index=market_data.index, dtype='object')
        market_state[known] = "Sideways"
        market_state[known & bull_market] = "Bull"
        market_state[known & bear_market] = "Bear"

        return market_state

    def train(self, market_data):
        """Fit the scaler and the classifier on the given history.

        Args:
            market_data: DataFrame with ``close`` and ``volume`` columns.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If no row has both complete features and a label.
        """
        features = self.extract_market_features(market_data)
        labels = self.label_market_states(market_data).dropna()

        # Keep only rows that have both a feature vector and a label.
        common_index = features.index.intersection(labels.index)
        if len(common_index) == 0:
            raise ValueError(
                "Not enough market data to train: no row has both complete "
                "features and a market-state label."
            )
        X = features.loc[common_index]
        y = labels.loc[common_index]

        X_scaled = self.scaler.fit_transform(X)
        self.model.fit(X_scaled, y)
        return self

    def predict_market_state(self, market_data):
        """Predict the market state for the most recent row of the history.

        Args:
            market_data: DataFrame with ``close`` and ``volume`` columns.

        Returns:
            Dict with ``'state'`` (the predicted label) and ``'confidence'``
            (a dict mapping every entry of ``self.market_states`` to its
            predicted probability; states the model never saw map to 0.0).

        Raises:
            ValueError: If the history is too short to produce any feature row.
        """
        features = self.extract_market_features(market_data)
        if features.empty:
            raise ValueError(
                "Not enough market data to predict: feature extraction "
                "produced no complete rows."
            )

        # Keep the row as a DataFrame so the scaler sees the same column
        # names it was fitted with in train().
        latest_features = features.iloc[[-1]]
        latest_features_scaled = self.scaler.transform(latest_features)

        predicted_state = self.model.predict(latest_features_scaled)[0]
        state_probs = self.model.predict_proba(latest_features_scaled)[0]

        # predict_proba columns follow model.classes_, not self.market_states,
        # so pair each probability with the label the model reports for it.
        confidence = {state: 0.0 for state in self.market_states}
        for label, prob in zip(self.model.classes_, state_probs):
            confidence[str(label)] = float(prob)

        return {
            'state': predicted_state,
            'confidence': confidence
        }

    def get_market_regime_features(self, market_data):
        """Return the latest feature row as a plain dict.

        The result can be used as extra input to a downstream trading model.

        Args:
            market_data: DataFrame with ``close`` and ``volume`` columns.

        Returns:
            Dict mapping feature name to its value on the most recent
            complete row.

        Raises:
            IndexError: If the history is too short to produce any feature row.
        """
        features = self.extract_market_features(market_data)
        if features.empty:
            raise IndexError(
                "Not enough market data: feature extraction produced no "
                "complete rows."
            )
        return features.iloc[-1].to_dict()