In [36]:
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
from datetime import datetime, timedelta
import logging
from pathlib import Path
import json

In [37]:
@dataclass
class PortfolioConfig:
    """投資組合配置參數"""
    # 基礎路徑設定
    BASE_DIR: str = "D:/Min/Python/Project/FA_Data"
    META_DATA_DIR: str = "meta_data"
    PORTFOLIO_DIR: str = "portfolio"
    LOG_DIR: str = "logs"
    
    @staticmethod
    def _default_portfolio_params():
        return {
            'single_stock_max_weight': 0.2,  # 單一股票最大權重
            'single_stock_min_weight': 0.02,  # 單一股票最小權重
            'single_industry_max_weight': 0.35,  # 單一產業最大權重
            'min_cash_ratio': 0.05,  # 最低現金比例
            'risk_free_rate': 0.015,  # 無風險利率
            'rebalance_threshold': 0.1,  # 再平衡閾值
            'stop_loss_ratio': -0.1,  # 停損比例
            'take_profit_ratio': 0.2,  # 獲利了結比例
        }
    
    @staticmethod
    def _default_risk_params():
        return {
            'max_portfolio_volatility': 0.25,  # 最大組合波動率
            'max_drawdown_threshold': -0.15,  # 最大回撤限制
            'var_confidence_level': 0.95,  # VaR信心水準
            'position_limit': 20,  # 最大持股數量
        }
    
    @staticmethod
    def _default_signal_threshold():
        return {
            'buy': 0.7,  # 買入信號閾值
            'sell': 0.3,  # 賣出信號閾值
        }
    
    @staticmethod
    def _default_trade_cost():
        return {
            'commission_rate': 0.001425,  # 手續費率
            'tax_rate': 0.003,  # 證交稅率
            'slippage': 0.001,  # 滑價率
        }
    
    PORTFOLIO_PARAMS: Dict = field(default_factory=_default_portfolio_params)
    RISK_PARAMS: Dict = field(default_factory=_default_risk_params)
    SIGNAL_THRESHOLD: Dict = field(default_factory=_default_signal_threshold)
    TRADE_COST: Dict = field(default_factory=_default_trade_cost)

In [38]:
class Position:
    """持倉部位類"""
    def __init__(self, 
                 stock_id: str, 
                 shares: int, 
                 avg_cost: float, 
                 industry: str,
                 entry_time: datetime = None):
        self.stock_id = stock_id
        self.shares = shares
        self.avg_cost = avg_cost
        self.industry = industry
        self.current_price = avg_cost
        self.entry_time = entry_time or datetime.now()
        self.last_update = datetime.now()
        self.trade_history: List[Dict] = []
    
    @property
    def market_value(self) -> float:
        """計算市值"""
        return self.shares * self.current_price
    
    @property
    def returns(self) -> float:
        """計算報酬率"""
        return self.current_price / self.avg_cost - 1
    
    def update_price(self, new_price: float):
        """更新價格"""
        self.current_price = new_price
        self.last_update = datetime.now()
    
    def add_trade_record(self, trade_type: str, shares: int, price: float):
        """添加交易記錄"""
        self.trade_history.append({
            'time': datetime.now(),
            'type': trade_type,
            'shares': shares,
            'price': price
        })

In [39]:
class PortfolioManager:
    """投資組合管理核心類"""
### 1. 初始化與配置相關    
    def __init__(self, config: PortfolioConfig = None, initial_cash: float = 1000000):
        self.config = config or PortfolioConfig()
        self.setup_logging()
        self.positions: Dict[str, Position] = {}
        self.cash: float = initial_cash  # 設定初始資金
        self.industry_weights: Dict[str, float] = {}
        self.history: List[Dict] = []
        self.performance: Dict = {}
        self._initialize_directories()
        # 添加模型預測相關配置
        self.model_predictions = {}
        self.prediction_history = []
        self.signal_confidence_threshold = 0.8  # 預測信心閾值
        
        # 記錄初始狀態
        if initial_cash > 0:
            self._record_transaction(
                stock_id='',
                action='initialize',
                shares=0,
                price=0.0,
                reason='initial_funding'
            )
    
    def setup_logging(self):
        """設置日誌系統"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(f"{self.config.BASE_DIR}/{self.config.LOG_DIR}/portfolio.log", 
                                  encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger('PortfolioManager')
    
    def _initialize_directories(self):
        """初始化必要目錄"""
        for dir_name in [self.config.META_DATA_DIR, self.config.PORTFOLIO_DIR, self.config.LOG_DIR]:
            directory = Path(f"{self.config.BASE_DIR}/{dir_name}")
            directory.mkdir(parents=True, exist_ok=True)
### 2. 核心功能方法        
    def add_position(self, stock_id: str, shares: int, price: float, industry: str) -> bool:
        """新增持倉"""
        try:
            # 先驗證輸入
            if not self._validate_input(stock_id, shares, price, industry):
                return False
            cost = shares * price * (1 + self.config.TRADE_COST['commission_rate'])
            
            # 檢查現金是否足夠
            if cost > self.cash:
                self.logger.warning(f"現金不足,無法購買 {stock_id}")
                return False
            
            # 檢查持倉限制
            if not self._check_position_limits(stock_id, shares * price, industry):
                return False
            
            # 建立新部位
            position = Position(stock_id, shares, price, industry)
            self.positions[stock_id] = position
            
            # 扣除現金
            self.cash -= cost
            
            # 更新產業權重
            self._update_industry_weights()
            
            # 記錄交易
            self._record_transaction(stock_id, 'buy', shares, price, 'new_position')
            
            return True
        except Exception as e:
            self.logger.error(f"新增持倉失敗: {str(e)}")
            return False
            
    def remove_position(self, stock_id: str) -> bool:
        """移除持倉"""
        try:
            if stock_id in self.positions:
                position = self.positions[stock_id]
                # 計算賣出所得(扣除手續費和交易稅)
                proceeds = position.market_value * (
                    1 - self.config.TRADE_COST['commission_rate'] - 
                    self.config.TRADE_COST['tax_rate']
                )
                
                # 記錄交易
                self._record_transaction(
                    stock_id, 
                    'sell', 
                    position.shares, 
                    position.current_price,
                    'close_position'
                )
                
                # 更新現金和持倉
                self.cash += proceeds
                del self.positions[stock_id]
                
                # 更新產業權重
                self._update_industry_weights()
                
                return True
            return False
        except Exception as e:
            self.logger.error(f"移除持倉失敗: {str(e)}")
            return False
    
    def update_position(self, stock_id: str, current_price: float) -> None:
        """更新持倉價格"""
        if stock_id in self.positions:
            self.positions[stock_id].update_price(current_price)
            self._update_industry_weights()

    def update_portfolio_by_prediction(self, predictions: Dict[str, float]):
        """改進的預測更新邏輯"""
        self.logger.info("收到新的預測信號")
        
        for stock_id, pred_data in predictions.items():
            # 評估預測信號
            is_valid, reason = self.evaluate_prediction_signal(stock_id, pred_data)
            if not is_valid:
                self.logger.warning(f"{stock_id} 預測信號被拒絕: {reason}")
                continue
                
            prediction = pred_data['probability']
            confidence = pred_data['confidence']
            
            if prediction > self.config.SIGNAL_THRESHOLD['buy']:
                self.logger.info(f"{stock_id} 收到買入信號 (信心度: {confidence:.2f})")
                self._process_buy_signal(stock_id, prediction, confidence)
                
            elif prediction < self.config.SIGNAL_THRESHOLD['sell']:
                self.logger.info(f"{stock_id} 收到賣出信號 (信心度: {confidence:.2f})")
                if stock_id in self.positions:
                    self._process_sell_signal(stock_id, prediction, confidence)
### 3. 預測與信號處理
    def update_model_predictions(self, predictions: Dict[str, Dict]):
        """接收並更新來自04的模型預測結果
        predictions格式:
        {
            'stock_id': {
                'probability': float,  # 預測概率
                'confidence': float,   # 預測信心分數
                'features': Dict,      # 用於預測的特徵
                'model_version': str   # 模型版本
            }
        }
        """
        self.model_predictions = predictions
        self.prediction_history.append({
            'timestamp': datetime.now(),
            'predictions': predictions
        })
        
    def evaluate_prediction_signal(self, stock_id: str, prediction: Dict) -> Tuple[bool, str]:
        """評估預測信號的可信度"""
        if prediction['confidence'] < self.signal_confidence_threshold:
            return False, 'low_confidence'
            
        # 檢查特徵穩定性
        if not self._validate_features(prediction['features']):
            return False, 'unstable_features'
            
        # 檢查與歷史預測的一致性
        if not self._check_prediction_consistency(stock_id, prediction):
            return False, 'inconsistent_prediction'
            
        return True, 'valid_signal'
    
    def _process_buy_signal(self, stock_id: str, prediction: float, confidence: float):
        """處理買入信號"""
        # 檢查是否已有持倉
        if stock_id in self.positions:
            position = self.positions[stock_id]
            current_weight = position.market_value / self.total_portfolio_value
            
            # 根據信心度調整目標權重
            target_weight = min(
                current_weight * (1 + confidence),
                self.config.PORTFOLIO_PARAMS['single_stock_max_weight']
            )
            
            # 計算需要買入的股數
            price = position.current_price
            target_value = self.total_portfolio_value * target_weight
            additional_shares = int((target_value - position.market_value) / price)
            
            if additional_shares > 0:
                self.add_position(stock_id, additional_shares, price)
                
        else:
            # 新建倉位
            # 根據信心度決定初始權重
            initial_weight = confidence * self.config.PORTFOLIO_PARAMS['single_stock_max_weight']
            # 具體買入邏輯由06實現...

    def _process_sell_signal(self, stock_id: str):
        """處理賣出信號"""
        if stock_id in self.positions:
            self.remove_position(stock_id)
    
    def _validate_features(self, features: Dict) -> bool:
        """驗證特徵的穩定性"""
        required_features = [
            'trend_strength', 'volatility', 'volume_ratio',
            'rsi', 'macd', 'technical_score'
        ]
        
        # 檢查必要特徵是否存在
        if not all(feat in features for feat in required_features):
            return False
            
        # 檢查特徵值是否在合理範圍內
        if not self._check_feature_ranges(features):
            return False
            
        return True

    def _check_feature_ranges(self, features: Dict) -> bool:
        """檢查特徵值是否在合理範圍內"""
        try:
            # 檢查技術指標範圍
            if not 0 <= features.get('rsi', 0) <= 100:
                return False
                
            # 檢查波動率指標
            if features.get('volatility', 0) < 0:
                return False
                
            # 檢查量比指標
            if features.get('volume_ratio', 0) <= 0:
                return False
                
            # 檢查趨勢強度
            if abs(features.get('trend_strength', 0)) > 5:  # 假設正常範圍在 ±5 內
                return False
                
            return True
            
        except Exception as e:
            self.logger.error(f"特徵範圍檢查失敗: {str(e)}")
            return False

    def _check_prediction_consistency(self, stock_id: str, current_pred: Dict) -> bool:
        """檢查預測與歷史預測的一致性"""
        if not self.prediction_history:
            return True
            
        # 獲取最近的預測記錄
        recent_predictions = [
            ph['predictions'].get(stock_id, {})
            for ph in self.prediction_history[-3:]  # 檢查最近3次預測
        ]
        
        # 計算預測趨勢的穩定性
        if recent_predictions:
            probabilities = [p.get('probability', 0) for p in recent_predictions]
            prob_std = np.std(probabilities)
            
            # 如果預測波動太大,認為不穩定
            if prob_std > 0.2:  # 可配置的閾值
                return False
        
        return True
### 4. 投資組合維護與再平衡
    def get_rebalance_signals(self) -> List[Dict]:
        """檢查是否需要再平衡"""
        signals = []
        
        # 檢查是否有資產
        if self.total_portfolio_value <= 0:
            self.logger.warning("投資組合為空，無需再平衡")
            return signals
        
        # 檢查現金比例
        cash_ratio = self.cash / self.total_portfolio_value
        if cash_ratio < self.config.PORTFOLIO_PARAMS['cash_reserve_ratio']:
            signals.append({
                'type': 'cash',
                'action': 'raise_cash',
                'current_ratio': cash_ratio,
                'target_ratio': self.config.PORTFOLIO_PARAMS['cash_reserve_ratio']
            })
        
        # 檢查個股權重
        for stock_id, position in self.positions.items():
            weight = position.market_value / self.total_portfolio_value
            
            # 檢查是否超過最大權重
            if weight > self.config.PORTFOLIO_PARAMS['max_stock_weight']:
                signals.append({
                    'type': 'position',
                    'stock_id': stock_id,
                    'action': 'reduce',
                    'reason': 'exceed_max_weight',
                    'current_weight': weight,
                    'target_weight': self.config.PORTFOLIO_PARAMS['max_stock_weight']
                })
            
            # 檢查是否低於最小權重
            elif weight < self.config.PORTFOLIO_PARAMS['min_stock_weight']:
                signals.append({
                    'type': 'position',
                    'stock_id': stock_id,
                    'action': 'increase',
                    'reason': 'below_min_weight',
                    'current_weight': weight,
                    'target_weight': self.config.PORTFOLIO_PARAMS['min_stock_weight']
                })
            
            # 檢查停損點
            if position.returns < self.config.PORTFOLIO_PARAMS['stop_loss_threshold']:
                signals.append({
                    'type': 'risk',
                    'stock_id': stock_id,
                    'action': 'sell',
                    'reason': 'stop_loss',
                    'current_return': position.returns
                })
            
            # 檢查獲利點
            if position.returns > self.config.PORTFOLIO_PARAMS['take_profit_threshold']:
                signals.append({
                    'type': 'risk',
                    'stock_id': stock_id,
                    'action': 'sell',
                    'reason': 'take_profit',
                    'current_return': position.returns
                })
        
        # 檢查產業權重
        for industry, weight in self.industry_weights.items():
            if weight > self.config.PORTFOLIO_PARAMS['max_industry_weight']:
                signals.append({
                    'type': 'industry',
                    'industry': industry,
                    'action': 'reduce',
                    'reason': 'exceed_industry_weight',
                    'current_weight': weight,
                    'target_weight': self.config.PORTFOLIO_PARAMS['max_industry_weight']
                })
        
        return signals

    def execute_rebalance(self, signals: List[Dict]) -> bool:
        """執行再平衡操作"""
        try:
            for signal in signals:
                if signal['type'] == 'position':
                    if signal['action'] == 'reduce':
                        self._reduce_position(signal['stock_id'], signal['target_weight'])
                    elif signal['action'] == 'increase':
                        self._increase_position(signal['stock_id'], signal['target_weight'])
                
                elif signal['type'] == 'risk':
                    if signal['action'] == 'sell':
                        self.remove_position(signal['stock_id'])
                
                elif signal['type'] == 'industry':
                    if signal['action'] == 'reduce':
                        self._reduce_industry_weight(signal['industry'], signal['target_weight'])
                
                elif signal['type'] == 'cash':
                    if signal['action'] == 'raise_cash':
                        self._adjust_cash_level(signal['target_ratio'])
                
                self._record_transaction(
                    signal.get('stock_id', ''),
                    signal['action'],
                    0,  # shares will be calculated in specific methods
                    0.0,  # price will be calculated in specific methods
                    f"rebalance_{signal['reason']}"
                )
            
            return True
        
        except Exception as e:
            self.logger.error(f"執行再平衡時發生錯誤: {str(e)}")
            return False

    def _reduce_industry_weight(self, industry: str, target_weight: float):
        """降低產業權重"""
        current_weight = self.industry_weights.get(industry, 0)
        if current_weight <= target_weight:
            return
        
        # 計算需要減少的價值
        total_value = self.total_portfolio_value
        current_industry_value = current_weight * total_value
        target_industry_value = target_weight * total_value
        reduce_value = current_industry_value - target_industry_value
        
        # 取得該產業的所有持倉
        industry_positions = [
            (stock_id, position) 
            for stock_id, position in self.positions.items() 
            if position.industry == industry
        ]
        
        # 按照持倉市值比例減少
        for stock_id, position in industry_positions:
            reduce_ratio = position.market_value / current_industry_value
            position_reduce_value = reduce_value * reduce_ratio
            new_shares = position.shares * (1 - position_reduce_value / position.market_value)
            position.shares = int(new_shares)
            
            # 如果剩餘股數太少，直接清倉
            if position.shares < 100:  # 最小交易單位
                self.remove_position(stock_id)

    def _increase_position(self, stock_id: str, target_weight: float) -> bool:
        """增加持倉權重"""
        try:
            if stock_id not in self.positions:
                return False
                
            position = self.positions[stock_id]
            current_weight = position.market_value / self.total_portfolio_value
            
            if current_weight >= target_weight:
                return True
                
            additional_value = (target_weight - current_weight) * self.total_portfolio_value
            shares_to_buy = int(additional_value / position.current_price)
            
            if shares_to_buy > 0:
                return self.add_position(
                    stock_id=stock_id,
                    shares=shares_to_buy,
                    price=position.current_price,
                    industry=position.industry
                )
                
            return True
            
        except Exception as e:
            self.logger.error(f"增加持倉失敗: {str(e)}")
            return False

    def _adjust_cash_level(self, target_ratio: float):
        """調整現金水準"""
        current_ratio = self.cash / self.total_portfolio_value
        if current_ratio >= target_ratio:
            return
        
        target_cash = self.total_portfolio_value * target_ratio
        cash_needed = target_cash - self.cash
        
        # 按照持倉比例賣出股票籌措現金
        total_position_value = self.total_portfolio_value - self.cash
        for position in self.positions.values():
            sell_ratio = cash_needed / total_position_value
            shares_to_sell = int(position.shares * sell_ratio)
            if shares_to_sell > 0:
                proceeds = shares_to_sell * position.current_price * (
                    1 - self.config.TRADE_COST['commission_rate'] - 
                    self.config.TRADE_COST['tax_rate']
                )
                position.shares -= shares_to_sell
                self.cash += proceeds

    def _check_position_limits(self, stock_id: str, market_value: float, industry: str) -> bool:
        """檢查持倉限制"""
        total_value = self.total_portfolio_value + market_value
        
        # 檢查最大持股數量
        if len(self.positions) >= self.config.RISK_PARAMS['position_limit']:
            self.logger.warning("超過最大持股數量限制")
            return False
        
        # 檢查單一股票權重
        if market_value / total_value > self.config.PORTFOLIO_PARAMS['max_stock_weight']:
            self.logger.warning(f"{stock_id} 超過單一股票權重限制")
            return False
        
        # 檢查產業權重
        industry_value = sum(
            p.market_value for p in self.positions.values() 
            if p.industry == industry
        )
        if (industry_value + market_value) / total_value > self.config.PORTFOLIO_PARAMS['max_industry_weight']:
            self.logger.warning(f"{industry} 產業超過權重限制")
            return False
        
        return True

    def _update_industry_weights(self):
        """更新產業權重"""
        total_value = self.total_portfolio_value
        self.industry_weights = {}
        
        for position in self.positions.values():
            industry = position.industry
            if industry not in self.industry_weights:
                self.industry_weights[industry] = 0
            self.industry_weights[industry] += position.market_value / total_value

### 5. 指標計算與風險評估
    def calculate_portfolio_metrics(self) -> Dict:
        """計算投資組合指標"""
        try:
            metrics = {
                'total_value': self.total_portfolio_value,
                'cash_ratio': self.cash / self.total_portfolio_value,
                'total_return': self._calculate_total_return(),
                'daily_return': self._calculate_daily_return(),
                'volatility': self._calculate_portfolio_volatility(),
                'sharpe_ratio': self._calculate_sharpe_ratio(),
                'max_drawdown': self._calculate_max_drawdown(),
                'value_at_risk': self._calculate_var(),
                'industry_weights': self.industry_weights,
                'position_concentration': self._calculate_position_concentration()
            }
            self.performance = metrics
            return metrics
        except Exception as e:
            self.logger.error(f"計算投資組合指標時發生錯誤: {str(e)}")
            return {}

    def _calculate_total_return(self) -> float:
        """計算總報酬率"""
        if not self.history:
            return 0.0
        initial_value = self.history[0]['portfolio_value']
        return (self.total_portfolio_value - initial_value) / initial_value

    def _calculate_daily_return(self) -> float:
        """計算日報酬率"""
        if len(self.history) < 2:
            return 0.0
        prev_value = self.history[-1]['portfolio_value']
        curr_value = self.total_portfolio_value
        return (curr_value - prev_value) / prev_value if prev_value > 0 else 0.0
    
    def _calculate_portfolio_volatility(self) -> float:
        """計算投資組合波動率"""
        if not self.history:
            return 0.0
        returns = pd.Series([record['daily_return'] for record in self.history])
        return returns.std() * np.sqrt(252)  # 年化波動率
    
    def _calculate_sharpe_ratio(self) -> float:
        """計算夏普比率"""
        if not self.history:
            return 0.0
        returns = pd.Series([record['daily_return'] for record in self.history])
        excess_returns = returns - self.config.PORTFOLIO_PARAMS['risk_free_rate'] / 252
        return np.sqrt(252) * excess_returns.mean() / returns.std() if returns.std() > 0 else 0.0
    
    def _calculate_max_drawdown(self) -> float:
        """計算最大回撤"""
        if not self.history:
            return 0.0
        values = pd.Series([record['portfolio_value'] for record in self.history])
        cummax = values.cummax()
        drawdown = (values - cummax) / cummax
        return drawdown.min()
    
    def _calculate_var(self, 
                      confidence_level: float = None) -> float:
        """計算Value at Risk"""
        if not self.history:
            return 0.0
        confidence_level = confidence_level or self.config.RISK_PARAMS['var_confidence_level']
        returns = pd.Series([record['daily_return'] for record in self.history])
        return np.percentile(returns, (1 - confidence_level) * 100)
    
    def _calculate_position_concentration(self) -> float:
        """計算持倉集中度(赫芬達爾指數)"""
        if not self.positions:
            return 0.0
        weights = [p.market_value / self.total_portfolio_value 
                  for p in self.positions.values()]
        return sum(w * w for w in weights)

### 6. 狀態管理與存儲
    def get_portfolio_status(self) -> Dict:
        """提供當前投資組合狀態供分析使用"""
        return {
            'positions': {
                stock_id: {
                    'shares': pos.shares,
                    'avg_cost': pos.avg_cost,
                    'current_price': pos.current_price,
                    'market_value': pos.market_value,
                    'returns': pos.returns,
                    'industry': pos.industry,
                    'entry_time': pos.entry_time.strftime('%Y-%m-%d %H:%M:%S'),
                    'last_update': pos.last_update.strftime('%Y-%m-%d %H:%M:%S'),
                }
                for stock_id, pos in self.positions.items()
            },
            'cash': self.cash,
            'total_value': self.total_portfolio_value,
            'performance': self.performance,
            'industry_weights': self.industry_weights,
        }
    
    def save_portfolio_state(self):
        """保存投資組合狀態"""
        try:
            state = {
                'positions': {
                    stock_id: {
                        'shares': pos.shares,
                        'avg_cost': pos.avg_cost,
                        'current_price': pos.current_price,
                        'industry': pos.industry,
                        'entry_time': pos.entry_time.strftime('%Y-%m-%d %H:%M:%S'),
                        'last_update': pos.last_update.strftime('%Y-%m-%d %H:%M:%S'),
                        'trade_history': pos.trade_history
                    }
                    for stock_id, pos in self.positions.items()
                },
                'cash': self.cash,
                'industry_weights': self.industry_weights,
                'history': self.history,
                'performance': self.performance,
                'last_update': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }
            
            file_path = Path(f"{self.config.BASE_DIR}/{self.config.PORTFOLIO_DIR}/portfolio_state.json")
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(state, f, ensure_ascii=False, indent=2)
            
            self.logger.info("投資組合狀態已保存")
            return True
            
        except Exception as e:
            self.logger.error(f"保存投資組合狀態時發生錯誤: {str(e)}")
            return False

    def load_portfolio_state(self) -> bool:
        """載入投資組合狀態"""
        try:
            file_path = Path(f"{self.config.BASE_DIR}/{self.config.PORTFOLIO_DIR}/portfolio_state.json")
            if not file_path.exists():
                self.logger.warning("找不到投資組合狀態文件")
                # 如果文件不存在，設置初始資金_validate_input
                self.cash = 1000000  # 預設初始資金
                self._record_transaction('', 'initialize', 0, 0.0, 'initial_funding')
                return False
            
            with open(file_path, 'r', encoding='utf-8') as f:
                state = json.load(f)
            
            # 重建持倉
            self.positions = {}
            for stock_id, pos_data in state['positions'].items():
                position = Position(
                    stock_id=stock_id,
                    shares=pos_data['shares'],
                    avg_cost=pos_data['avg_cost'],
                    industry=pos_data['industry'],
                    entry_time=datetime.strptime(pos_data['entry_time'], '%Y-%m-%d %H:%M:%S')
                )
                position.current_price = pos_data['current_price']
                position.last_update = datetime.strptime(pos_data['last_update'], '%Y-%m-%d %H:%M:%S')
                position.trade_history = pos_data['trade_history']
                self.positions[stock_id] = position
            
            self.cash = state['cash']
            self.industry_weights = state['industry_weights']
            self.history = state['history']
            self.performance = state['performance']
            
            # 確保至少有現金
            if self.total_portfolio_value <= 0:
                self.logger.warning("載入的投資組合沒有資產，重設初始資金")
                self.cash = 1000000  # 預設初始資金
                self._record_transaction('', 'initialize', 0, 0.0, 'reset_funding')
            
            self.logger.info("投資組合狀態已載入")
            return True
            
        except Exception as e:
            self.logger.error(f"載入投資組合狀態時發生錯誤: {str(e)}")
            # 發生錯誤時也確保有初始資金
            self.cash = 1000000
            self._record_transaction('', 'initialize', 0, 0.0, 'error_funding')
            return False

    def _record_transaction(self, 
                          stock_id: str, 
                          action: str, 
                          shares: int, 
                          price: float, 
                          reason: str):
        """記錄交易"""
        transaction = {
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'stock_id': stock_id,
            'action': action,
            'shares': shares,
            'price': price,
            'value': shares * price,
            'reason': reason,
            'portfolio_value': self.total_portfolio_value,
            'daily_return': self._calculate_daily_return()
        }
        self.history.append(transaction)
    
    @property
    def total_portfolio_value(self) -> float:
        """計算總資產"""
        return sum(position.market_value for position in self.positions.values()) + self.cash

    def _validate_input(self, stock_id: str, shares: int, price: float, industry: str) -> bool:
        """驗證輸入參數的合法性"""
        try:
            if not isinstance(stock_id, str) or not stock_id:
                self.logger.error("無效的股票代碼")
                return False
                
            if not isinstance(shares, int) or shares <= 0:
                self.logger.error("無效的股數")
                return False
                
            if not isinstance(price, (int, float)) or price <= 0:
                self.logger.error("無效的價格")
                return False
                
            if not isinstance(industry, str) or not industry:
                self.logger.error("無效的產業類別")
                return False
                
            return True
            
        except Exception as e:
            self.logger.error(f"輸入驗證失敗: {str(e)}")
            return False

In [40]:
def test_portfolio():
    """測試投資組合管理系統"""
    # 1. 基礎功能測試
    config = PortfolioConfig()
    portfolio = PortfolioManager(config)
    
    # 2. 初始化檢查
    status = portfolio.get_portfolio_status()
    print(f"初始總資產: {status['total_value']:.2f}")
    print(f"現金比例: {status['cash']/status['total_value']:.2%}")
    
    # 3. 指標計算測試
    metrics = portfolio.calculate_portfolio_metrics()
    print("\n投資組合指標:")
    for key, value in metrics.items():
        if isinstance(value, (int, float)):
            print(f"{key}: {value:.2f}")
    
    # 4. 狀態保存和載入測試
    portfolio.save_portfolio_state()
    load_success = portfolio.load_portfolio_state()
    print(f"\n狀態載入: {'成功' if load_success else '失敗'}")
    
    return portfolio

In [41]:
if __name__ == "__main__":
    portfolio = test_portfolio()

2024-11-07 15:59:47,086 - PortfolioManager - INFO - 投資組合狀態已保存
2024-11-07 15:59:47,097 - PortfolioManager - INFO - 投資組合狀態已載入


初始總資產: 1000000.00
現金比例: 100.00%

投資組合指標:
total_value: 1000000.00
cash_ratio: 1.00
total_return: 0.00
daily_return: 0.00
volatility: nan
sharpe_ratio: 0.00
max_drawdown: 0.00
value_at_risk: 0.00
position_concentration: 0.00

狀態載入: 成功
