In [None]:
class TradeSimulator:
    """Trade simulator used to evaluate model predictions in a trading scenario.

    The simulator walks through ``predictions`` and ``actual_returns`` in
    lockstep and runs a long-only, single-position strategy:

    * a prediction of ``1`` while flat opens a position worth
      ``position_size`` of the current equity;
    * while a position is open it is marked to market with every period's
      actual return (the entry period itself earns nothing);
    * the position is closed when the prediction is ``0`` or when the
      period's actual return breaches the stop-loss / take-profit threshold.

    A proportional transaction cost is charged on both entry and exit.
    """

    def __init__(self,
                 predictions: np.ndarray,
                 actual_returns: np.ndarray,
                 initial_capital: float = 1000000,
                 transaction_cost: float = 0.001425,
                 position_size: float = 0.1,
                 stop_loss: float = 0.02,
                 take_profit: float = 0.05):
        """Store the simulation inputs.

        Args:
            predictions: Sequence of signals; ``1`` means buy, ``0`` means sell.
            actual_returns: Per-period returns, aligned with ``predictions``.
                If the lengths differ, the extra tail of the longer sequence
                is ignored (``zip`` semantics).
            initial_capital: Starting equity.
            transaction_cost: Fee rate charged on the traded amount, applied
                on both entry and exit.
            position_size: Fraction of current equity committed per trade.
            stop_loss: A period return below ``-stop_loss`` closes the position.
            take_profit: A period return above ``take_profit`` closes the
                position.
        """
        self.predictions = predictions
        self.actual_returns = actual_returns
        self.initial_capital = initial_capital
        self.transaction_cost = transaction_cost
        self.position_size = position_size  # fraction of equity used per trade
        self.stop_loss = stop_loss
        self.take_profit = take_profit

    def evaluate(self) -> Dict[str, float]:
        """Run the simulation and return summary statistics.

        Returns:
            A dict with:
                ``total_return``: final equity relative to initial capital.
                ``total_trades``: number of executed orders (entries + exits).
                ``win_rate``: share of completed round trips that ended with
                    more equity than they started with, fees included.
                ``sharpe_ratio``: annualised Sharpe ratio of per-period
                    equity returns.
                ``max_drawdown``: largest peak-to-trough drop of the equity
                    curve, in units of initial capital.
        """
        # `capital` is total equity (cash plus the open position's value), so
        # opening a position moves money between the two without changing it.
        capital = self.initial_capital
        holding = 0.0  # marked-to-market value of the open position
        in_position = False
        trades: List[Dict] = []
        equity_curve: List[float] = []    # cumulative return vs. initial capital
        period_returns: List[float] = []  # step-over-step equity return

        for i, (pred, actual) in enumerate(zip(self.predictions, self.actual_returns)):
            equity_before = capital

            if in_position:
                # Accrue this period's return on the open position.
                pnl = holding * actual
                holding += pnl
                capital += pnl

                hit_limit = actual < -self.stop_loss or actual > self.take_profit
                if pred == 0 or hit_limit:
                    capital -= holding * self.transaction_cost
                    trades.append({
                        'type': 'sell',
                        'index': i,
                        'amount': float(holding),
                        'capital': float(capital),
                    })
                    holding = 0.0
                    in_position = False
            elif pred == 1:
                holding = min(capital * self.position_size, capital)
                # Equity is recorded before the entry fee so the round trip's
                # win/loss verdict accounts for both fees.
                trades.append({
                    'type': 'buy',
                    'index': i,
                    'amount': float(holding),
                    'capital': float(capital),
                })
                capital -= holding * self.transaction_cost
                in_position = True

            equity_curve.append(
                float((capital - self.initial_capital) / self.initial_capital)
            )
            period_returns.append(
                float((capital - equity_before) / equity_before)
                if equity_before != 0 else 0.0
            )

        return {
            'total_return': float(
                (capital - self.initial_capital) / self.initial_capital
            ),
            'total_trades': len(trades),
            'win_rate': self._calculate_win_rate(trades),
            'sharpe_ratio': self._calculate_sharpe_ratio(period_returns),
            'max_drawdown': self._calculate_max_drawdown(equity_curve)
        }

    def _calculate_win_rate(self, trades: List[Dict]) -> float:
        """Compute the win rate over completed round trips.

        ``trades`` is read as consecutive (entry, exit) pairs; a pair is a win
        when the exit's ``'capital'`` exceeds the entry's. A trailing unpaired
        entry (position still open) is ignored.
        """
        round_trips = len(trades) // 2
        if round_trips == 0:
            return 0.0

        wins = sum(
            1 for k in range(round_trips)
            if trades[2 * k + 1]['capital'] > trades[2 * k]['capital']
        )
        return wins / round_trips

    def _calculate_sharpe_ratio(self, returns: List[float]) -> float:
        """Compute the annualised Sharpe ratio (252 periods per year).

        Returns ``0.0`` for an empty series or one with no variation, where
        the ratio is undefined.
        """
        if len(returns) == 0:
            return 0.0

        returns_array = np.asarray(returns, dtype=float)
        volatility = np.std(returns_array)
        # A flat series would otherwise divide by zero and yield NaN/inf.
        if not np.isfinite(volatility) or volatility < 1e-12:
            return 0.0
        return float(np.mean(returns_array) / volatility * np.sqrt(252))

    def _calculate_max_drawdown(self, returns: List[float]) -> float:
        """Compute the maximum drawdown of a cumulative-return series.

        ``returns`` holds cumulative returns relative to initial capital; the
        result is the largest gap between the running peak and the current
        value, in the same units (fraction of initial capital).
        """
        if len(returns) == 0:
            return 0.0

        cumulative = np.asarray(returns, dtype=float) + 1
        running_max = np.maximum.accumulate(cumulative)
        drawdown = running_max - cumulative
        return float(np.max(drawdown))