# In [None]:
import pandas as pd
import numpy as np
import os
from datetime import datetime
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.dates as mdates
from typing import List, Dict, Optional, Tuple

# Global display / plotting configuration.
pd.options.display.max_columns = None
plt.rcParams.update({
    # CJK-capable fonts so Chinese chart titles and labels render.
    "font.family": ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"],
    # Draw the minus sign as ASCII '-' so it displays with the fonts above.
    "axes.unicode_minus": False,
})

# Columns that must exist in the raw parquet file (checked by load_base_data);
# every other column used later is derived from these.
BASE_REQUIRED_COLS = [
    # identifiers and trading date
    "stock_code", "stock_name", "date",
    # daily prices
    "open", "close", "high", "low",
    # liquidity and industry classification
    "volume", "turnover_ratio", "sw_l1_industry_name",
    # price moving averages read from the raw data
    "ma5", "ma10", "ma20",
    # daily limit-up / limit-down prices
    "high_limit", "low_limit",
]

class NoFutureStockSelector:
    """Daily stock selector that only uses information available at the close of day T.

    Pipeline (each step returns ``True`` on success, ``False`` on failure):
    ``load_base_data`` -> ``calculate_technical_indicators`` ->
    ``clean_processed_data`` -> ``run_daily_selection`` ->
    ``generate_selection_summary``.

    Indicators are computed per stock from the current and earlier rows only
    (rolling windows, EWMs, diffs and shifts), so a selection for day T never
    reads rows dated after T.
    """

    def __init__(self, data_path: str,
                 output_dir: str = "selected_stocks_no_future",
                 result_dir: str = "selection_results"):
        """Store paths and create the output directories.

        Args:
            data_path: Parquet file with one row per stock per trading day.
            output_dir: Directory for per-day and combined selection CSVs.
            result_dir: Directory for the summary workbook and chart.
        """
        self.data_path = data_path
        self.output_dir = output_dir
        self.result_dir = result_dir

        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.result_dir, exist_ok=True)

        self.raw_data = None          # DataFrame read from data_path
        self.processed_data = None    # raw columns plus derived indicators
        self.trading_dates = []       # sorted list of pd.Timestamp after cleaning
        self.all_selected = None      # concatenation of all daily selections

    def load_base_data(self) -> bool:
        """Read the parquet file and verify that all BASE_REQUIRED_COLS exist.

        Returns:
            True on success; False (after printing the reason) on any error.
        """
        try:
            print("加载原始数据...")
            self.raw_data = pd.read_parquet(self.data_path)
            print(f"原始数据规模：{len(self.raw_data)} 行 × {len(self.raw_data.columns)} 列")

            missing_cols = [col for col in BASE_REQUIRED_COLS if col not in self.raw_data.columns]
            if missing_cols:
                raise ValueError(f"❌ 缺失基础必需字段：{missing_cols}")

            return True
        except Exception as e:
            print(f"数据加载失败：{str(e)}")
            return False

    def calculate_technical_indicators(self) -> bool:
        """Compute per-stock indicators on the full history.

        Adds volume MAs, RSI(6/12), KDJ(9,3,3), MACD(12,26,9) and the
        short-term trend features used by the T-day scoring: ``ma5_prev``
        (previous row's MA5), ``short_term_return`` (close T-2 -> T) and
        ``ma5_slope`` (one-day relative MA5 change). These must be computed
        here, because the T-day slices used later hold a single row per stock.

        Returns:
            True on success, False if data is missing or a step fails.
        """
        if self.raw_data is None:
            print("请先加载基础数据")
            return False

        try:
            df = self.raw_data[BASE_REQUIRED_COLS].copy()
            # Parse dates before sorting so the order is chronological even if
            # the raw column holds non-ISO strings; a fresh unique index keeps
            # the index-aligned joins and transforms below safe.
            df["date"] = pd.to_datetime(df["date"])
            df = df.sort_values(["stock_code", "date"]).reset_index(drop=True)

            # 1. Volume moving averages
            print("计算成交量均线指标...")
            volume_by_stock = df.groupby('stock_code')['volume']
            df['vol_ma5'] = volume_by_stock.transform(
                lambda x: x.rolling(window=5, min_periods=5).mean()
            )
            df['vol_ma10'] = volume_by_stock.transform(
                lambda x: x.rolling(window=10, min_periods=10).mean()
            )

            # 2. RSI (simple-moving-average variant)
            print("计算RSI指标...")
            def calculate_rsi(series, window):
                delta = series.diff(1)
                gain = delta.where(delta > 0, 0)
                loss = -delta.where(delta < 0, 0)

                avg_gain = gain.rolling(window=window, min_periods=window).mean()
                avg_loss = loss.rolling(window=window, min_periods=window).mean()

                # Avoid division by zero when the window had no down moves.
                avg_loss = avg_loss.replace(0, 1e-8)
                rs = avg_gain / avg_loss
                return 100 - (100 / (1 + rs))

            close_by_stock = df.groupby('stock_code')['close']
            df['rsi_6'] = close_by_stock.transform(lambda x: calculate_rsi(x, 6))
            df['rsi_12'] = close_by_stock.transform(lambda x: calculate_rsi(x, 12))

            # 3. KDJ
            print("计算KDJ指标...")
            def calculate_kdj(group):
                n, m1, m2 = 9, 3, 3

                low_min = group['low'].rolling(window=n, min_periods=n).min()
                high_max = group['high'].rolling(window=n, min_periods=n).max()

                rsv = (group['close'] - low_min) / (high_max - low_min) * 100
                # high == low over the window divides by zero; treat as missing.
                rsv = rsv.replace([np.inf, -np.inf], np.nan)

                # K and D are exponential smoothings with alpha = 1/3; the
                # warm-up rows are pinned to 50.
                k = rsv.ewm(alpha=1/m1, adjust=False).mean()
                k.iloc[:n-1] = 50

                d = k.ewm(alpha=1/m2, adjust=False).mean()
                d.iloc[:n-1] = 50

                j = 3 * k - 2 * d

                return pd.DataFrame({'kdj_k': k, 'kdj_d': d, 'kdj_j': j}, index=group.index)

            # Only pass the columns KDJ needs (avoids applying on the grouping column).
            kdj_results = df.groupby('stock_code', group_keys=False)[['high', 'low', 'close']].apply(calculate_kdj)
            df = df.join(kdj_results)

            # 4. MACD (12, 26, 9), all EMAs with adjust=False
            print("计算MACD指标...")
            close_by_stock = df.groupby('stock_code')['close']
            fast_ema = close_by_stock.transform(lambda x: x.ewm(span=12, adjust=False).mean())
            slow_ema = close_by_stock.transform(lambda x: x.ewm(span=26, adjust=False).mean())
            df['macd_line'] = fast_ema - slow_ema
            df['signal_line'] = df.groupby('stock_code')['macd_line'].transform(
                lambda x: x.ewm(span=9, adjust=False).mean()
            )
            df['macd_hist'] = df['macd_line'] - df['signal_line']

            # Standardised MACD aliases
            df['macd'] = df['macd_hist']
            df['macd_diff'] = df['macd_line']
            df['macd_dea'] = df['signal_line']

            # 5. Short-term trend features (current and past rows of each stock only)
            ma5 = pd.to_numeric(df['ma5'], errors='coerce')
            df['ma5_prev'] = ma5.groupby(df['stock_code']).shift(1)
            # Cumulative return from the close of T-2 to the close of T.
            df['short_term_return'] = df['close'] / df.groupby('stock_code')['close'].shift(2) - 1
            # Relative change of MA5 versus the previous trading day.
            df['ma5_slope'] = (ma5 - df['ma5_prev']) / df['ma5_prev']

            self.processed_data = df
            print("技术指标计算完成")
            return True
        except Exception as e:
            print(f"指标计算失败：{str(e)}")
            return False

    def clean_processed_data(self) -> bool:
        """Coerce numeric columns, forward-fill gaps per stock and drop incomplete rows.

        Also builds ``self.trading_dates`` as a sorted list of ``pd.Timestamp``.
        ``Series.unique()`` can return ``numpy.datetime64`` values, which have
        no ``strftime``, so the dates go through a DatetimeIndex first.

        Returns:
            True on success, False if nothing usable remains or a step fails.
        """
        if self.processed_data is None:
            print("请先计算技术指标")
            return False

        try:
            all_required_cols = BASE_REQUIRED_COLS + [
                'vol_ma5', 'vol_ma10', 'rsi_6', 'rsi_12',
                'kdj_k', 'kdj_d', 'kdj_j', 'macd', 'macd_diff', 'macd_dea',
                'macd_line', 'signal_line', 'macd_hist'
            ]
            text_cols = {"stock_code", "stock_name", "date", "sw_l1_industry_name"}
            numeric_cols = [col for col in all_required_cols if col not in text_cols]

            data = self.processed_data
            for col in numeric_cols:
                data[col] = pd.to_numeric(data[col], errors="coerce")

            data = data.sort_values(["stock_code", "date"])

            # Forward-fill within each stock only: gaps take past values, never future ones.
            data[numeric_cols] = data.groupby("stock_code")[numeric_cols].ffill()
            data["sw_l1_industry_name"] = data.groupby("stock_code")["sw_l1_industry_name"].ffill()

            initial_count = len(data)
            data = data.dropna(subset=all_required_cols)
            self.processed_data = data
            print(f"数据清洗完成，保留 {len(data)} 行（删除 {initial_count - len(data)} 行无效数据）")

            if data.empty:
                print("数据清洗后无有效数据")
                return False

            self.trading_dates = list(pd.DatetimeIndex(data["date"].unique()).sort_values())
            print(f"有效交易日范围：{self.trading_dates[0].strftime('%Y-%m-%d')} 至 {self.trading_dates[-1].strftime('%Y-%m-%d')}，共 {len(self.trading_dates)} 天")

            return True
        except Exception as e:
            print(f"数据清洗失败：{str(e)}")
            return False

    def calculate_t_day_indicators(self, date_data: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of a single-day slice with T-day flags and ratios added.

        ``ma5_up`` compares MA5 with the previous day's MA5 stored in
        ``ma5_prev``, which is precomputed on the full history. A one-day slice
        has one row per stock, so shifting inside it would always give NaN.
        When ``ma5_prev`` is absent, the old in-slice comparison is used.
        """
        df = date_data.copy()

        df["t_day_return"] = (df["close"] - df["open"]) / df["open"]
        df["above_ma5"] = df["close"] > df["ma5"]
        if "ma5_prev" in df.columns:
            df["ma5_up"] = df["ma5"] > df["ma5_prev"]
        else:
            df["ma5_up"] = df.groupby("stock_code")["ma5"].transform(
                lambda x: x > x.shift(1)
            )
        df["volume_expand"] = df["volume"] > 1.2 * df["vol_ma5"]
        # Within 0.5% of the limit price counts as hitting the limit.
        df["is_limit_up"] = df["close"] >= df["high_limit"] * 0.995
        df["is_limit_down"] = df["close"] <= df["low_limit"] * 1.005
        df["rsi_normal"] = (df["rsi_6"] > 30) & (df["rsi_6"] < 70)
        df["kdj_gold_cross"] = (df["kdj_j"] > df["kdj_k"]) & (df["kdj_j"] > df["kdj_d"])

        # Intraday amplitude on T, relative to the open.
        df["daily_amplitude"] = (df["high"] - df["low"]) / df["open"]

        return df

    def _get_t_day_data(self, date) -> pd.DataFrame:
        """Return the rows dated ``date`` with T-day indicators (empty frame if none)."""
        date_data = self.processed_data[self.processed_data["date"] == date]
        if date_data.empty:
            return date_data.copy()
        return self.calculate_t_day_indicators(date_data)

    def judge_market_sentiment(self, date: datetime) -> Optional[Dict]:
        """Score market breadth, activity and trend on ``date`` (0-100 points).

        Returns:
            Dict with sentiment ("strong" >= 70, "neutral" >= 40, else "weak"),
            total score, up ratio in percent and limit-up count; None if the
            date has no rows.
        """
        date_data = self._get_t_day_data(date)
        if date_data.empty:
            return None

        up_ratio = (date_data["t_day_return"] > 0).mean()
        limit_up_count = date_data["is_limit_up"].sum()
        limit_down_count = date_data["is_limit_down"].sum()

        # Breadth (up to 40 points)
        breadth_score = 0
        if up_ratio >= 0.6:
            breadth_score += 20
        elif up_ratio >= 0.5:
            breadth_score += 10

        if limit_up_count >= 50:
            breadth_score += 15
        elif limit_up_count >= 20:
            breadth_score += 5

        if limit_down_count <= 5:
            breadth_score += 5
        elif limit_down_count <= 10:
            breadth_score += 2

        # Activity (up to 30 points)
        avg_turnover = date_data["turnover_ratio"].mean()
        volume_expand_ratio = date_data["volume_expand"].mean()

        activity_score = 0
        if avg_turnover >= 3.0:
            activity_score += 15
        elif avg_turnover >= 2.0:
            activity_score += 5

        if volume_expand_ratio >= 0.4:
            activity_score += 15
        elif volume_expand_ratio >= 0.3:
            activity_score += 5

        # Trend (up to 30 points)
        above_ma5_ratio = date_data["above_ma5"].mean()
        ma5_up_ratio = date_data["ma5_up"].mean()

        trend_score = 0
        if above_ma5_ratio >= 0.6:
            trend_score += 15
        elif above_ma5_ratio >= 0.5:
            trend_score += 5

        if ma5_up_ratio >= 0.6:
            trend_score += 15
        elif ma5_up_ratio >= 0.5:
            trend_score += 5

        total_score = breadth_score + activity_score + trend_score
        if total_score >= 70:
            sentiment = "strong"
        elif total_score >= 40:
            sentiment = "neutral"
        else:
            sentiment = "weak"

        return {
            "date": date,
            "sentiment": sentiment,
            "total_score": total_score,
            "up_ratio": round(up_ratio * 100, 1),
            "limit_up_count": limit_up_count
        }

    def select_strong_industries(self, date: datetime, top_n: int = 3) -> List[str]:
        """Rank industries (with at least 5 stocks on ``date``) by a 0-100 strength score.

        The return component is scaled by the largest *absolute* industry
        average return. This keeps the scaling monotonic even when every
        industry fell on T; dividing by a negative maximum would invert the
        ranking.

        Returns:
            Up to ``top_n`` industry names, strongest first.
        """
        date_data = self._get_t_day_data(date)
        if date_data.empty:
            return []

        industry_metrics = date_data.groupby("sw_l1_industry_name").agg({
            "t_day_return": ["mean", lambda x: (x > 0).mean()],
            "turnover_ratio": "mean",
            "above_ma5": "mean"
        }).reset_index()

        industry_metrics.columns = [
            "industry", "avg_return", "up_ratio", "avg_turnover", "above_ma5_ratio"
        ]

        return_scale = industry_metrics["avg_return"].abs().max()
        industry_metrics["norm_return"] = industry_metrics["avg_return"] / (return_scale + 1e-8)

        industry_metrics["strength_score"] = (
            industry_metrics["norm_return"] * 40 +
            industry_metrics["up_ratio"] * 20 +
            industry_metrics["avg_turnover"].clip(0, 5)/5 * 20 +
            industry_metrics["above_ma5_ratio"] * 20
        )

        industry_sizes = date_data["sw_l1_industry_name"].value_counts()
        valid_industries = industry_sizes[industry_sizes >= 5].index
        filtered_industries = industry_metrics[industry_metrics["industry"].isin(valid_industries)]

        top = filtered_industries.nlargest(top_n, "strength_score")
        top_industries = top["industry"].tolist()

        print(f"\nT日强行业TOP{len(top_industries)}：")
        ranked = zip(top["industry"], top["strength_score"], top["avg_return"])
        for idx, (industry, score, avg_ret) in enumerate(ranked, 1):
            print(f"  {idx}. {industry}：强度{score:.1f}分 | 平均收益{avg_ret*100:.2f}%")

        return top_industries

    def select_stocks_for_date(self, date: datetime, stocks_per_industry: int = 2,
                               max_total: int = 6) -> Optional[pd.DataFrame]:
        """Select stocks for day T using only data available at T's close.

        Steps:

        1. Skip "weak" market days.
        2. Pick the strongest industries.
        3. Drop high-risk names.
        4. Score the rest: trend 30, strength 25, capital 25, risk 20 points.
        5. Keep the best ``stocks_per_industry`` non-limit-up names per
           industry, and at most ``max_total`` names overall.

        Args:
            date: Trading day T.
            stocks_per_industry: Maximum picks per strong industry.
            max_total: Maximum picks for the day.

        Returns:
            DataFrame with Chinese column headers, or None if nothing qualifies.
        """
        sentiment_info = self.judge_market_sentiment(date)
        if not sentiment_info or sentiment_info["sentiment"] == "weak":
            print(f"T日({date.strftime('%Y-%m-%d')}) 情绪：{sentiment_info['sentiment'] if sentiment_info else '无数据'}，不选股")
            return None

        print(f"\nT日({date.strftime('%Y-%m-%d')}) 情绪：{sentiment_info['sentiment'].upper()}（得分：{sentiment_info['total_score']}）")
        print(f"  上涨个股占比：{sentiment_info['up_ratio']}% | 涨停数：{sentiment_info['limit_up_count']}")

        strong_industries = self.select_strong_industries(date)
        if not strong_industries:
            print("  无符合条件的强行业，不选股")
            return None

        date_data = self._get_t_day_data(date)
        candidates = date_data[date_data["sw_l1_industry_name"].isin(strong_industries)]

        # Risk filter
        candidates = candidates[
            (candidates["turnover_ratio"] >= 0.8) &    # keep liquid names
            (candidates["t_day_return"] < 0.09) &      # avoid names already up sharply on T
            (candidates["rsi_6"] < 65) &               # avoid overbought names
            (candidates["daily_amplitude"] < 0.08)     # amplitude below 8%
        ].copy()

        if candidates.empty:
            print("  经过风险过滤后无符合条件的个股")
            return None

        # Short-term trend features come from calculate_technical_indicators;
        # fall back to NaN (never scores) if processed_data lacks them.
        for col in ("short_term_return", "ma5_slope"):
            if col not in candidates.columns:
                candidates[col] = np.nan

        # 1. Trend score (30)
        candidates["trend_score"] = (
            8 * candidates["above_ma5"].astype(int)
            + 7 * candidates["ma5_up"].astype(int)
            + 8 * (candidates["short_term_return"] > 0.03).astype(int)  # T-2 -> T return above 3%
            + 7 * (candidates["ma5_slope"] > 0.01).astype(int)          # MA5 rose more than 1% on T
        )

        # 2. Strength score (25): top-20% intraday return, moderate gain, not limit-down
        top_return_threshold = candidates["t_day_return"].quantile(0.8)
        candidates["strength_score"] = (
            10 * (candidates["t_day_return"] >= top_return_threshold).astype(int)
            + 5 * (candidates["t_day_return"] < 0.07).astype(int)
            + 10 * (~candidates["is_limit_down"]).astype(int)
        )

        # 3. Capital score (25)
        candidates["capital_score"] = (
            15 * candidates["volume_expand"].astype(int)
            + 10 * candidates["turnover_ratio"].between(1, 8).astype(int)
        )

        # 4. Risk score (20)
        candidates["risk_score"] = (
            10 * candidates["rsi_normal"].astype(int)
            + 5 * (candidates["daily_amplitude"] < 0.05).astype(int)    # low volatility bonus
            + 5 * candidates["kdj_gold_cross"].astype(int)
        )

        candidates["total_score"] = (
            candidates["trend_score"] +
            candidates["strength_score"] +
            candidates["capital_score"] +
            candidates["risk_score"]
        )

        # Best non-limit-up names per industry
        selected_stocks = []
        for industry in strong_industries:
            industry_stocks = candidates[
                (candidates["sw_l1_industry_name"] == industry) &
                (~candidates["is_limit_up"])
            ]

            if industry_stocks.empty:
                print(f"  {industry} 无符合条件个股（已排除涨停股）")
                continue

            selected_stocks.append(industry_stocks.nlargest(stocks_per_industry, "total_score"))

        if not selected_stocks:
            print("  无符合条件的个股（已排除涨停股）")
            return None

        final_selected = pd.concat(selected_stocks)
        if len(final_selected) > max_total:
            final_selected = final_selected.nlargest(max_total, "total_score")

        final_selected["selection_date"] = date
        final_selected["market_sentiment"] = sentiment_info["sentiment"]

        result_df = final_selected[
            ["selection_date", "stock_code", "stock_name", "sw_l1_industry_name",
             "t_day_return", "turnover_ratio", "daily_amplitude",
             "trend_score", "strength_score", "capital_score", "risk_score",
             "total_score", "market_sentiment"]
        ].copy()

        result_df.columns = [
            "选股日期", "股票代码", "股票名称", "所属行业",
            "T日收益率(%)", "换手率(%)", "当日振幅",
            "趋势得分", "强度得分", "资金得分", "风险得分",
            "综合评分", "市场情绪"
        ]

        result_df["T日收益率(%)"] = (result_df["T日收益率(%)"] * 100).round(2)
        result_df["换手率(%)"] = result_df["换手率(%)"].round(2)
        # Amplitude is reported in percent.
        result_df["当日振幅"] = (result_df["当日振幅"] * 100).round(2)

        return result_df

    def run_daily_selection(self, start_date: Optional[str] = None,
                           end_date: Optional[str] = None,
                           stocks_per_industry: int = 2) -> bool:
        """Run select_stocks_for_date over the (inclusive) date window.

        Each non-empty daily result is written to ``selected_YYYYMMDD.csv``.
        All results are concatenated into ``self.all_selected`` and
        ``all_selected_stocks.csv``. Errors on one day are printed and that
        day is skipped.

        Returns:
            True if at least one stock was selected.
        """
        if not self.trading_dates:
            print("请先完成数据清洗")
            return False

        selected_dates = self.trading_dates
        if start_date:
            start_date = pd.to_datetime(start_date)
            selected_dates = [d for d in selected_dates if d >= start_date]

        if end_date:
            end_date = pd.to_datetime(end_date)
            selected_dates = [d for d in selected_dates if d <= end_date]

        if not selected_dates:
            print("无符合条件的交易日")
            return False

        print(f"\n开始按日选股，日期范围：{selected_dates[0].strftime('%Y-%m-%d')} 至 {selected_dates[-1].strftime('%Y-%m-%d')}，共 {len(selected_dates)} 个交易日")

        all_results = []

        for date in tqdm(selected_dates, desc="选股进度"):
            try:
                daily_result = self.select_stocks_for_date(date, stocks_per_industry)

                if daily_result is not None and not daily_result.empty:
                    date_str = date.strftime("%Y%m%d")
                    save_path = os.path.join(self.output_dir, f"selected_{date_str}.csv")
                    daily_result.to_csv(save_path, index=False, encoding="utf-8-sig")
                    all_results.append(daily_result)

            except Exception as e:
                print(f"{date.strftime('%Y-%m-%d')} 选股出错：{str(e)}")
                continue

        if not all_results:
            print("\n未生成任何选股结果")
            return False

        self.all_selected = pd.concat(all_results, ignore_index=True)
        print(f"\n选股完成，共推荐 {len(self.all_selected)} 只个股")

        self.all_selected.to_csv(
            os.path.join(self.output_dir, "all_selected_stocks.csv"),
            index=False, encoding="utf-8-sig"
        )
        return True

    def generate_selection_summary(self):
        """Print and save summary tables and a chart of all selections.

        Returns:
            False if there is nothing to summarise, else True.
        """
        if self.all_selected is None or self.all_selected.empty:
            print("无选股结果可汇总")
            return False

        sentiment_summary = self.all_selected.groupby("市场情绪").agg({
            "股票代码": "count",
            "综合评分": "mean"
        }).reset_index()
        sentiment_summary.columns = ["市场情绪", "推荐个股数", "平均综合评分"]
        sentiment_summary["平均综合评分"] = sentiment_summary["平均综合评分"].round(1)

        industry_counts = self.all_selected["所属行业"].value_counts().head(10).reset_index()
        industry_counts.columns = ["所属行业", "推荐次数"]

        # include_lowest so a score of exactly 0 lands in the first bin instead of NaN.
        score_distribution = pd.cut(
            self.all_selected["综合评分"],
            bins=[0, 60, 70, 80, 90, 100],
            labels=["60分以下", "60-70分", "70-80分", "80-90分", "90分以上"],
            include_lowest=True
        ).value_counts().sort_index().reset_index()
        score_distribution.columns = ["评分区间", "个股数量"]

        print("\n===== 选股结果汇总统计 =====")
        print("\n按市场情绪分类：")
        print(sentiment_summary.to_string(index=False))

        print("\n推荐次数前10的行业：")
        print(industry_counts.to_string(index=False))

        print("\n综合评分分布：")
        print(score_distribution.to_string(index=False))

        self.plot_selection_summary(sentiment_summary, industry_counts, score_distribution)

        with pd.ExcelWriter(os.path.join(self.result_dir, "selection_summary.xlsx")) as writer:
            sentiment_summary.to_excel(writer, sheet_name="按情绪统计", index=False)
            industry_counts.to_excel(writer, sheet_name="行业分布", index=False)
            score_distribution.to_excel(writer, sheet_name="评分分布", index=False)
            self.all_selected.to_excel(writer, sheet_name="所有推荐个股", index=False)

        return True

    def plot_selection_summary(self, sentiment_summary, industry_counts, score_distribution):
        """Save a three-panel chart (sentiment counts, top industries, score mix)."""
        plt.figure(figsize=(12, 10))

        plt.subplot(2, 2, 1)
        # Colour each bar by its sentiment label. A positional list would change
        # meaning depending on which sentiments happen to be present.
        sentiment_colors = {"strong": "red", "neutral": "yellow", "weak": "green"}
        bar_colors = [sentiment_colors.get(s, "gray") for s in sentiment_summary["市场情绪"]]
        plt.bar(sentiment_summary["市场情绪"], sentiment_summary["推荐个股数"], color=bar_colors)
        plt.title("不同市场情绪下的推荐个股数量")
        plt.ylabel("个股数量")

        plt.subplot(2, 2, 2)
        plt.barh(industry_counts["所属行业"], industry_counts["推荐次数"], color='steelblue')
        plt.title("行业推荐次数TOP10")
        plt.xlabel("推荐次数")
        plt.gca().invert_yaxis()

        plt.subplot(2, 1, 2)
        plt.pie(score_distribution["个股数量"], labels=score_distribution["评分区间"], autopct='%1.1f%%')
        plt.title("综合评分分布占比")

        plt.tight_layout()
        plt.savefig(os.path.join(self.result_dir, "selection_summary.png"), dpi=300)
        plt.close()

        print(f"\n汇总图表已保存至 {self.result_dir} 目录")


class StrategyEvaluator(NoFutureStockSelector):
    """Back-test of the selections.

    The strategy buys half of the position at the T+1 open and half at the
    T+2 open, then sells everything at the T+3 close.
    """

    def __init__(self, data_path: str,
                 output_dir: str = "selected_stocks_no_future",
                 result_dir: str = "selection_results",
                 performance_dir: str = "strategy_performance"):
        """Set up the selector and create ``performance_dir`` for back-test outputs."""
        super().__init__(data_path, output_dir, result_dir)
        self.performance_dir = performance_dir
        os.makedirs(self.performance_dir, exist_ok=True)

        self.trade_records = None      # one row per simulated trade
        self.daily_performance = None  # per-selection-date aggregates
        self.overall_stats = None      # dict of headline statistics
        self.industry_stats = None     # per-industry aggregates
        self.price_data = None         # open/close indexed by (stock_code, date)

    def prepare_price_data(self) -> bool:
        """Build ``self.price_data``: open/close prices indexed by (stock_code, date).

        Duplicate (stock, date) rows keep the last occurrence, so price lookups
        return scalars.
        """
        if self.raw_data is None:
            print("价格数据准备失败：请先加载基础数据")
            return False
        try:
            prices = self.raw_data[["stock_code", "date", "open", "close"]].copy()
            prices["date"] = pd.to_datetime(prices["date"])
            for col in ("open", "close"):
                prices[col] = pd.to_numeric(prices[col], errors="coerce")
            prices = prices.drop_duplicates(["stock_code", "date"], keep="last")
            self.price_data = prices.sort_values(["stock_code", "date"]).set_index(["stock_code", "date"])
            return True
        except Exception as e:
            print(f"价格数据准备失败：{str(e)}")
            return False

    def get_future_price(self, stock_code: str, base_date: datetime, days: int, price_type: str = "open") -> Optional[float]:
        """Return ``price_type`` on the ``days``-th trading row of the stock after ``base_date``.

        Offsets count the stock's own rows, so suspended days are skipped.

        Returns:
            The price as float, or None if the stock, the base date or the
            target row is missing, or if the price is NaN.
        """
        try:
            stock_dates = self.price_data.loc[stock_code].index.sort_values()
            matches = np.flatnonzero(stock_dates == base_date)
            if matches.size == 0:
                return None

            target_pos = matches[0] + days
            if target_pos < 0 or target_pos >= len(stock_dates):
                return None

            price = self.price_data.loc[(stock_code, stock_dates[target_pos]), price_type]
        except (IndexError, KeyError):
            return None

        if pd.isna(price):
            return None
        return float(price)

    def calculate_strategy_returns(self) -> bool:
        """Simulate every selection and compute per-trade and per-day returns.

        For each selected stock, a nominal 10,000 is scaled by a risk
        coefficient ``max(1 - amplitude / 0.1, 0.5)``. Half of it buys at the
        T+1 open and half at the T+2 open; everything is sold at the T+3 close.
        Trades with missing or non-positive prices are skipped.

        Writes ``trade_records.csv`` and ``daily_performance.csv`` and
        computes the overall statistics.

        Returns:
            True if at least one trade was simulated.
        """
        if self.all_selected is None or self.all_selected.empty:
            print("无选股结果可计算收益")
            return False

        if self.price_data is None and not self.prepare_price_data():
            print("无法准备价格数据，无法计算收益")
            return False

        self.all_selected["选股日期"] = pd.to_datetime(self.all_selected["选股日期"])
        base_invest = 10000  # nominal capital per selected stock before risk scaling
        trade_records = []

        selection_dates = self.all_selected["选股日期"].unique()
        for select_date in tqdm(selection_dates, desc="计算策略收益"):
            daily_stocks = self.all_selected[self.all_selected["选股日期"] == select_date]

            for _, stock in daily_stocks.iterrows():
                stock_code = stock["股票代码"]
                trade_date = stock["选股日期"]
                daily_amplitude = stock["当日振幅"] / 100  # stored in percent

                t1_open = self.get_future_price(stock_code, trade_date, 1, "open")
                t2_open = self.get_future_price(stock_code, trade_date, 2, "open")
                t3_close = self.get_future_price(stock_code, trade_date, 3, "close")

                # Skip missing prices, and non-positive ones that would divide by zero.
                if any(p is None or p <= 0 for p in (t1_open, t2_open, t3_close)):
                    continue

                # Larger T-day amplitude -> smaller position, floored at 50%.
                risk_coefficient = max(1 - (daily_amplitude / 0.1), 0.5)
                adjust_invest = base_invest * risk_coefficient

                # Two equal tranches
                t1_shares = (adjust_invest * 0.5) / t1_open
                t2_shares = (adjust_invest * 0.5) / t2_open
                total_shares = t1_shares + t2_shares

                sell_amount = total_shares * t3_close
                total_return = sell_amount - adjust_invest
                return_ratio = total_return / adjust_invest

                trade_records.append({
                    "选股日期": trade_date,
                    "股票代码": stock_code,
                    "股票名称": stock["股票名称"],
                    "所属行业": stock["所属行业"],
                    "T+1开盘价": t1_open,
                    "T+2开盘价": t2_open,
                    "T+3收盘价": t3_close,
                    "风险系数": round(risk_coefficient, 2),
                    "总投入": round(adjust_invest, 2),
                    "总卖出": round(sell_amount, 2),
                    "收益金额": round(total_return, 2),
                    "收益率(%)": round(return_ratio * 100, 2),
                    "T+1买入股数": round(t1_shares, 2),
                    "T+2买入股数": round(t2_shares, 2),
                    "总股数": round(total_shares, 2)
                })

        if not trade_records:
            print("无有效交易记录可计算")
            return False

        self.trade_records = pd.DataFrame(trade_records)
        self.trade_records.to_csv(
            os.path.join(self.performance_dir, "trade_records.csv"),
            index=False, encoding="utf-8-sig"
        )

        # Per-selection-date performance
        self.daily_performance = self.trade_records.groupby("选股日期").agg({
            "收益率(%)": ["mean", "count", lambda x: (x > 0).mean() * 100]
        }).reset_index()
        self.daily_performance.columns = [
            "选股日期", "平均收益率(%)", "交易股票数", "盈利胜率(%)"
        ]
        self.daily_performance.to_csv(
            os.path.join(self.performance_dir, "daily_performance.csv"),
            index=False, encoding="utf-8-sig"
        )

        self.calculate_strategy_stats()

        return True

    def calculate_strategy_stats(self) -> None:
        """Compute headline and per-industry statistics and save them to Excel.

        The cumulative return compounds the per-day *average* return. This is
        the same series that plot_strategy_performance draws, so the headline
        number matches the end of the curve. Compounding every trade in
        sequence would treat concurrent same-day trades as if each reinvested
        the whole capital. NOTE(review): this still ignores the overlap of the
        3-day holding periods between consecutive days.

        The Sharpe ratio here is per trade, with zero risk-free rate and no
        annualisation.
        """
        if self.trade_records is None or self.trade_records.empty:
            return

        returns = self.trade_records["收益率(%)"]
        total_trades = len(returns)
        profitable_trades = int((returns > 0).sum())
        loss_trades = total_trades - profitable_trades

        daily_mean = self.trade_records.groupby("选股日期")["收益率(%)"].mean().sort_index()
        cumulative_pct = ((1 + daily_mean / 100).prod() - 1) * 100

        self.overall_stats = {
            "总交易次数": total_trades,
            "盈利次数": profitable_trades,
            "亏损次数": loss_trades,
            "总胜率(%)": round(profitable_trades / total_trades * 100, 2),
            "平均收益率(%)": round(returns.mean(), 2),
            "中位数收益率(%)": round(returns.median(), 2),
            "最大盈利(%)": round(returns.max(), 2),
            "最大亏损(%)": round(returns.min(), 2),
            "累计收益率(%)": round(cumulative_pct, 2),
            "收益标准差": round(returns.std(), 2),
            "夏普比率": round(
                (returns.mean() / 100) /
                (returns.std() / 100 + 1e-8), 2
            )
        }

        # Per-industry statistics
        industry_stats = self.trade_records.groupby("所属行业").agg({
            "收益率(%)": ["mean", "count", lambda x: (x > 0).mean() * 100],
            "股票代码": "nunique"
        }).reset_index()
        industry_stats.columns = [
            "行业", "平均收益率(%)", "交易次数", "胜率(%)", "涉及股票数"
        ]
        self.industry_stats = industry_stats.sort_values("平均收益率(%)", ascending=False)

        with pd.ExcelWriter(os.path.join(self.performance_dir, "strategy_stats.xlsx")) as writer:
            pd.DataFrame([self.overall_stats]).to_excel(writer, sheet_name="整体表现", index=False)
            self.industry_stats.to_excel(writer, sheet_name="行业表现", index=False)
            self.trade_records.sort_values("收益率(%)", ascending=False).head(10).to_excel(writer, sheet_name="最佳个股", index=False)
            self.trade_records.sort_values("收益率(%)", ascending=True).head(10).to_excel(writer, sheet_name="最差个股", index=False)

    def plot_strategy_performance(self) -> None:
        """Save daily returns, cumulative return and return histogram to a PNG."""
        if self.daily_performance is None or self.trade_records is None or self.overall_stats is None:
            print("无绩效数据可可视化")
            return

        plt.figure(figsize=(15, 12))

        plt.subplot(3, 1, 1)
        plt.plot(
            self.daily_performance["选股日期"],
            self.daily_performance["平均收益率(%)"],
            'b-', alpha=0.7
        )
        plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
        plt.title("每日平均收益率走势")
        plt.ylabel("收益率(%)")
        plt.grid(True, alpha=0.3)
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
        plt.gcf().autofmt_xdate()

        plt.subplot(3, 1, 2)
        sorted_dates = self.daily_performance.sort_values("选股日期")
        cumulative_return = (1 + sorted_dates["平均收益率(%)"]/100).cumprod() - 1
        plt.plot(
            sorted_dates["选股日期"],
            cumulative_return * 100,
            'g-', alpha=0.7
        )
        plt.title(f"策略累计收益率：{self.overall_stats['累计收益率(%)']}%")
        plt.ylabel("累计收益率(%)")
        plt.grid(True, alpha=0.3)
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
        plt.gcf().autofmt_xdate()

        plt.subplot(3, 1, 3)
        plt.hist(
            self.trade_records["收益率(%)"],
            bins=30,
            alpha=0.7,
            color='orange'
        )
        plt.axvline(x=0, color='r', linestyle='--', alpha=0.3)
        plt.title(f"个股收益率分布（平均：{self.overall_stats['平均收益率(%)']}%）")
        plt.xlabel("收益率(%)")
        plt.ylabel("交易次数")
        plt.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_dir, "strategy_performance.png"), dpi=300)
        plt.close()

        print(f"\n策略表现图表已保存至 {self.performance_dir} 目录")

    def print_strategy_evaluation(self) -> None:
        """Print the headline statistics and the three best industries."""
        if self.overall_stats is None:
            print("无策略评价数据")
            return

        print("\n===== 策略表现评价 =====")
        print(f"交易策略：T+1日和T+2日开盘价各买入50%，T+3日收盘价卖出")
        print(f"评价周期：{self.trade_records['选股日期'].min().strftime('%Y-%m-%d')} 至 {self.trade_records['选股日期'].max().strftime('%Y-%m-%d')}")
        print("\n----- 核心指标 -----")
        print(f"总交易次数：{self.overall_stats['总交易次数']}")
        print(f"胜率：{self.overall_stats['总胜率(%)']}%（盈利{self.overall_stats['盈利次数']}次 / 亏损{self.overall_stats['亏损次数']}次）")
        print(f"平均单次收益率：{self.overall_stats['平均收益率(%)']}%")
        print(f"累计收益率：{self.overall_stats['累计收益率(%)']}%")
        print(f"最大盈利：{self.overall_stats['最大盈利(%)']}%")
        print(f"最大亏损：{self.overall_stats['最大亏损(%)']}%")
        print(f"风险调整后收益（夏普比率）：{self.overall_stats['夏普比率']}")

        # Use the in-memory table; only fall back to the saved workbook if absent.
        industry_stats = self.industry_stats
        if industry_stats is None:
            industry_stats = pd.read_excel(
                os.path.join(self.performance_dir, "strategy_stats.xlsx"),
                sheet_name="行业表现"
            )
        print("\n----- 表现最佳的3个行业 -----")
        print(industry_stats.head(3).to_string(index=False))


if __name__ == "__main__":
    # Input file and output locations (point DATA_PATH at the real parquet file).
    DATA_PATH = r"D:\workspace\xiaoyao\data\factortable.parquet"
    OUTPUT_DIR = "selected_stocks_optimized"            # per-day selection CSVs
    RESULT_DIR = "selection_summary_optimized"          # selection summary outputs
    PERFORMANCE_DIR = "strategy_performance_optimized"  # back-test outputs

    # Inclusive selection window, kept equal to the earlier back-test for comparison.
    START_DATE = "2025-08-01"
    END_DATE = "2025-09-29"

    evaluator = StrategyEvaluator(DATA_PATH, OUTPUT_DIR, RESULT_DIR, PERFORMANCE_DIR)

    # Every preparation step returns True on success; `and` stops at the first
    # failure, so later steps only run when all earlier ones succeeded.
    selection_done = (
        evaluator.load_base_data()
        and evaluator.calculate_technical_indicators()
        and evaluator.clean_processed_data()
        and evaluator.run_daily_selection(start_date=START_DATE, end_date=END_DATE)
    )

    if selection_done:
        evaluator.generate_selection_summary()
        if evaluator.calculate_strategy_returns():
            evaluator.plot_strategy_performance()
            evaluator.print_strategy_evaluation()

    print("\n所有选股与策略评价任务完成")
