# In [1]:
# Essential Imports
import requests
import pandas as pd
import numpy as np
import time
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import mplfinance as mpf
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import ta
import logging
import gc
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Tuple
from sklearn.metrics import mean_squared_error, mean_absolute_error
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import os
from collections import defaultdict
import matplotlib.gridspec as gridspec
from ta.volatility import BollingerBands
from ta.trend import MACD
import warnings
from dotenv import load_dotenv
from scipy.stats import percentileofscore
import optuna
from functools import partial


# Suppress mplfinance warnings for too much data.
# Only UserWarnings raised from the mplfinance module are ignored; warnings
# from other libraries still surface.
warnings.filterwarnings("ignore", category=UserWarning, module="mplfinance")

# Configure logging: INFO and above are appended to a file instead of the console.
# NOTE(review): logging.basicConfig is a no-op when the root logger already has
# handlers (e.g. when this notebook cell is re-run); pass force=True if the
# configuration must be reapplied.
logging.basicConfig(
    level=logging.INFO,
    filename='checklstm2_trading_model.log',
    filemode='a',
    format='%(asctime)s:%(levelname)s:%(message)s'
)
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility (NumPy global RNG, PyTorch CPU RNG, and
# the CUDA RNG when a GPU is available).
# NOTE(review): torch.cuda.manual_seed seeds only the current device; use
# torch.cuda.manual_seed_all for multi-GPU runs. Seeding alone does not make
# cuDNN kernels deterministic.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Load environment variables from a .env file into os.environ.
# Must run before the os.getenv call below.
load_dotenv()

# API Configuration: fail fast at import time if the Polygon key is missing
# (None or empty string).
API_KEY = os.getenv("POLYGON_API_KEY")
if not API_KEY:
    raise ValueError("Missing Polygon API KEY. Please check your .env file.")


@dataclass
class TrainingHistory:
    """Tracks training metrics during model training.

    ``loss_history`` and ``validation_loss_history`` receive one entry per
    ``update`` call. The accuracy histories receive an entry only when the
    corresponding value is supplied, so they may be shorter than the loss
    histories.
    """
    loss_history: List[float] = field(default_factory=list)
    validation_loss_history: List[float] = field(default_factory=list)
    direction_accuracy_history: List[float] = field(default_factory=list)
    long_accuracy_history: List[float] = field(default_factory=list)
    short_accuracy_history: List[float] = field(default_factory=list)

    def update(self, epoch: int, loss: float, val_loss: float,
               direction_accuracy: Optional[float] = None,
               long_accuracy: Optional[float] = None,
               short_accuracy: Optional[float] = None) -> None:
        """Append one epoch's metrics to the history.

        Args:
            epoch: Epoch index. Accepted for caller convenience but not
                stored; entries are positional.
            loss: Training loss for the epoch.
            val_loss: Validation loss for the epoch.
            direction_accuracy: Overall direction accuracy, if computed.
            long_accuracy: Long-side accuracy, if computed.
            short_accuracy: Short-side accuracy, if computed.
        """
        self.loss_history.append(loss)
        self.validation_loss_history.append(val_loss)
        # Optional metrics are recorded only when supplied; an explicit 0.0 is
        # still recorded (the check is against None, not truthiness).
        optional_metrics = (
            (direction_accuracy, self.direction_accuracy_history),
            (long_accuracy, self.long_accuracy_history),
            (short_accuracy, self.short_accuracy_history),
        )
        for value, history in optional_metrics:
            if value is not None:
                history.append(value)

class ModelEvaluator:
    """Evaluates price predictions and summarizes simulated trades.

    Besides computing metrics, the evaluator keeps a bounded per-ticker
    history of signed prediction errors (``y_pred - y_true``) that
    ``get_adaptive_correction`` turns into a bias correction.
    """

    def __init__(self, prediction_threshold: float = 0.001):
        self.prediction_threshold = prediction_threshold
        self.mse_scores = []
        self.mae_scores = []
        self.seasonal_errors = []
        # Signed prediction errors per ticker, used for adaptive correction.
        self.error_history = defaultdict(list)
        # Only the most recent errors per ticker are retained.
        self.max_history_size = 100

    def calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray,
                          current_prices: np.ndarray = None,
                          ticker: str = None) -> Dict[str, float]:
        """Calculate error and direction metrics for a series of predictions.

        Args:
            y_true: Actual prices; any shape, flattened to 1-D.
            y_pred: Predicted prices; flattened to 1-D, same length as y_true.
            current_prices: Optional reference prices for the directional-bias
                metrics: a scalar, or one value per prediction in any shape
                that flattens to that length.
            ticker: If given, the signed errors are appended to this ticker's
                error history and 'mean_error' / 'error_std' are reported.

        Returns:
            Dict of metrics. Direction metrics compare the sign of consecutive
            changes in y_pred with those in y_true and are in percent.

        Raises:
            ValueError: On length mismatch, empty input, NaN or inf values, or
                a current_prices array of incompatible length.
        """
        metrics = {}

        # Flatten so column vectors such as (n, 1) model outputs work too.
        y_true = np.asarray(y_true, dtype=float).ravel()
        y_pred = np.asarray(y_pred, dtype=float).ravel()

        if len(y_true) != len(y_pred):
            raise ValueError("y_true and y_pred must have the same length")
        if len(y_true) == 0:
            raise ValueError("y_true and y_pred must not be empty")

        # Handle invalid values
        if np.isnan(y_true).any() or np.isnan(y_pred).any():
            raise ValueError("Inputs contain NaN values")
        if np.isinf(y_true).any() or np.isinf(y_pred).any():
            raise ValueError("Inputs contain infinite values")

        if current_prices is not None:
            # Flatten as well: an (n, 1) array would otherwise broadcast against
            # the flattened (n,) predictions into an (n, n) matrix.
            current_prices = np.asarray(current_prices, dtype=float).ravel()
            if current_prices.size not in (1, len(y_pred)):
                raise ValueError("current_prices must be a scalar or match the length of y_pred")

        # Track errors for this ticker if provided
        if ticker:
            errors = y_pred - y_true
            self.error_history[ticker].extend(errors.tolist())
            # Keep only the most recent max_history_size errors.
            if len(self.error_history[ticker]) > self.max_history_size:
                self.error_history[ticker] = self.error_history[ticker][-self.max_history_size:]

            metrics['mean_error'] = np.mean(errors)
            metrics['error_std'] = np.std(errors)

        try:
            # Sign of each consecutive move (+1 up, -1 down, 0 flat); these
            # arrays line up with y[1:].
            direction_true = np.sign(np.diff(y_true))
            direction_pred = np.sign(np.diff(y_pred))

            # Direction Accuracy (overall). A single point has no move to
            # compare, which would otherwise yield NaN from an empty mean.
            if len(direction_true) > 0:
                metrics['direction_accuracy'] = np.mean(direction_true == direction_pred) * 100
            else:
                metrics['direction_accuracy'] = 0.0

            # Direction Accuracy (separated by up/down): share of true up
            # (down) moves that were also predicted as up (down).
            up_indices = np.where(direction_true > 0)[0]
            down_indices = np.where(direction_true < 0)[0]

            if len(up_indices) > 0:
                metrics['up_direction_accuracy'] = np.mean(direction_pred[up_indices] > 0) * 100
            else:
                metrics['up_direction_accuracy'] = 0

            if len(down_indices) > 0:
                metrics['down_direction_accuracy'] = np.mean(direction_pred[down_indices] < 0) * 100
            else:
                metrics['down_direction_accuracy'] = 0

            # Directional Bias (higher means bias toward predicting up moves)
            if current_prices is not None:
                pred_direction_vs_current = np.sign(y_pred - current_prices)
                up_pred_pct = np.mean(pred_direction_vs_current > 0) * 100
                down_pred_pct = np.mean(pred_direction_vs_current < 0) * 100
                metrics['up_prediction_pct'] = up_pred_pct
                metrics['down_prediction_pct'] = down_pred_pct
                metrics['direction_bias'] = up_pred_pct - down_pred_pct

            # Prediction range vs. actual range (ratio < 1 means the
            # predictions span a narrower band than the actuals).
            true_range = np.max(y_true) - np.min(y_true)
            pred_range = np.max(y_pred) - np.min(y_pred)
            metrics['true_price_range'] = true_range
            metrics['pred_price_range'] = pred_range
            metrics['range_ratio'] = pred_range / true_range if true_range > 0 else 0

            # Standard error metrics
            metrics['mse'] = mean_squared_error(y_true, y_pred)
            metrics['mae'] = mean_absolute_error(y_true, y_pred)
            metrics['rmse'] = np.sqrt(metrics['mse'])

            # Magnitude error on the bars that ended a true up / down move
            # (indices refer to y[1:], matching the direction arrays).
            if len(up_indices) > 0:
                metrics['up_move_mae'] = mean_absolute_error(
                    y_true[1:][up_indices], y_pred[1:][up_indices])
            if len(down_indices) > 0:
                metrics['down_move_mae'] = mean_absolute_error(
                    y_true[1:][down_indices], y_pred[1:][down_indices])

        except Exception as e:
            logger.error(f"Error in calculate_metrics: {e}")
            raise

        return metrics

    def get_adaptive_correction(self, ticker: str) -> float:
        """Return half the mean of the last 10 recorded errors for ``ticker``.

        Returns 0.0 unless more than 10 errors have been recorded. The 0.5
        factor applies only a partial correction to avoid overcompensation.
        """
        if ticker in self.error_history and len(self.error_history[ticker]) > 10:
            recent_errors = self.error_history[ticker][-10:]
            return np.mean(recent_errors) * 0.5
        return 0.0

    def calculate_trade_performance_metrics(self, trades: List[Dict],
                                            initial_capital: float = 10000.0,
                                            include_transaction_costs: bool = True) -> Dict[str, float]:
        """Summarize a list of closed trades.

        Each trade dict needs 'position' ('enter_long', anything else is
        treated as short) and 'exit_time' (an object with ``.date()``). P&L is
        taken from 'profit' when present, in which case it should already
        include transaction costs. Otherwise P&L is computed from
        'entry_price'/'exit_price', minus 'transaction_costs' when
        ``include_transaction_costs`` is True. An optional 'duration' is
        averaged into 'average_trade_duration'.

        Args:
            trades: Closed trades to summarize.
            initial_capital: Starting equity for the drawdown curve.
            include_transaction_costs: Subtract 'transaction_costs' from
                P&L that is derived from prices.

        Returns:
            Dict with the same keys whether or not ``trades`` is empty.
            'sharpe_ratio' is annualized with sqrt(252) from P&L summed per
            exit day. It is 0.0 with fewer than two trading days or zero
            dispersion. 'profit_factor' is gross profit / gross loss. It is
            ``inf`` when there are winners but no losers, and 0.0 when there
            are neither.
        """
        if not trades:
            # Same key set as the non-empty case so callers can index any
            # metric without special-casing "no trades".
            return {
                'total_trades': 0,
                'total_profit': 0.0,
                'win_rate': 0.0,
                'loss_rate': 0.0,
                'long_trades': 0,
                'short_trades': 0,
                'long_win_rate': 0.0,
                'short_win_rate': 0.0,
                'average_profit_per_trade': 0.0,
                'average_trade_duration': 0.0,
                'maximum_drawdown': 0.0,
                'profit_factor': 0.0,
                'sharpe_ratio': 0.0
            }

        equity_curve = [initial_capital]
        daily_returns = defaultdict(float)
        long_profits = []
        short_profits = []
        trade_durations = []

        for trade in trades:
            if 'profit' in trade:
                profit = trade['profit']
            else:
                if trade['position'] == 'enter_long':
                    profit = trade['exit_price'] - trade['entry_price']
                else:  # enter_short
                    profit = trade['entry_price'] - trade['exit_price']

                if include_transaction_costs and 'transaction_costs' in trade:
                    profit -= trade['transaction_costs']

            equity_curve.append(equity_curve[-1] + profit)

            # P&L is aggregated per exit date for the Sharpe ratio.
            daily_returns[trade['exit_time'].date()] += profit

            if trade['position'] == 'enter_long':
                long_profits.append(profit)
            else:
                short_profits.append(profit)

            if 'duration' in trade:
                trade_durations.append(trade['duration'])

        profits = long_profits + short_profits
        total_profit = sum(profits)
        winning_trades = [p for p in profits if p > 0]
        losing_trades = [p for p in profits if p < 0]

        win_rate = len(winning_trades) / len(profits) * 100
        loss_rate = len(losing_trades) / len(profits) * 100

        long_win_rate = len([p for p in long_profits if p > 0]) / len(long_profits) * 100 if long_profits else 0
        short_win_rate = len([p for p in short_profits if p > 0]) / len(short_profits) * 100 if short_profits else 0

        # Drawdown relative to the running equity peak; a cumulative max keeps
        # this O(n) instead of re-scanning the curve prefix for every point.
        equity = np.asarray(equity_curve, dtype=float)
        drawdowns = equity - np.maximum.accumulate(equity)
        max_drawdown = float(abs(drawdowns.min()))

        # Sharpe ratio needs at least two days of P&L; with a single day the
        # former 1e-6 stand-in std made the ratio explode.
        daily_returns_list = list(daily_returns.values())
        sharpe_ratio = 0.0
        if len(daily_returns_list) > 1:
            std_daily_return = float(np.std(daily_returns_list))
            if std_daily_return > 0:
                avg_daily_return = float(np.mean(daily_returns_list))
                sharpe_ratio = avg_daily_return / std_daily_return * np.sqrt(252)

        gross_profit = sum(winning_trades)
        gross_loss = abs(sum(losing_trades))
        if gross_loss > 0:
            profit_factor = gross_profit / gross_loss
        elif gross_profit > 0:
            profit_factor = float('inf')  # winners but no losers
        else:
            profit_factor = 0.0  # neither winners nor losers

        return {
            'total_trades': len(trades),
            'total_profit': round(total_profit, 2),
            'win_rate': round(win_rate, 2),
            'loss_rate': round(loss_rate, 2),
            'long_trades': len(long_profits),
            'short_trades': len(short_profits),
            'long_win_rate': round(long_win_rate, 2),
            'short_win_rate': round(short_win_rate, 2),
            'average_profit_per_trade': round(np.mean(profits), 2),
            'average_trade_duration': round(np.mean(trade_durations), 2) if trade_durations else 0,
            'maximum_drawdown': round(max_drawdown, 2),
            'profit_factor': round(profit_factor, 2),
            'sharpe_ratio': round(sharpe_ratio, 2)
        }

class AdaptiveSignalGenerator:
    def __init__(self, ticker: str):
        self.ticker = ticker
        self.stock_profiles = {
            'MSFT': {
                'base_entry_threshold': 0.0008,  # Increased from 0.0006
                'short_entry_threshold_factor': 0.9,  # Slightly favor shorts (was 1.0)
                'base_exit_threshold': 0.0006,  # Increased from 0.0005
                'base_stop_loss': 0.0012,
                'atr_multiplier': 1.1,  # Increased from 1.0
                'min_hold_time': 2,
                'max_daily_trades': 250,
                'position_size': 0.2,
                'volatility_threshold': 0.7,
                'volume_threshold': 0.6,
                'trend_threshold': 0.015
            },
            'GOOGL': {
                'base_entry_threshold': 0.0010,  # Increased from 0.0008
                'short_entry_threshold_factor': 0.9,  # Slightly favor shorts (was 1.0)
                'base_exit_threshold': 0.0007,  # Increased from 0.0006
                'base_stop_loss': 0.0018,
                'atr_multiplier': 1.2,  # Increased from 1.1
                'min_hold_time': 2,
                'max_daily_trades': 280,
                'position_size': 0.2,
                'volatility_threshold': 0.8,
                'volume_threshold': 0.5,
                'trend_threshold': 0.018
            },
            'TSLA': {
                'base_entry_threshold': 0.0030,  # Increased from 0.0025
                'short_entry_threshold_factor': 0.9,  # Slightly favor shorts (was 1.0)
                'base_exit_threshold': 0.0023,  # Increased from 0.002
                'base_stop_loss': 0.0035,
                'atr_multiplier': 1.8,
                'min_hold_time': 1,
                'max_daily_trades': 180,
                'position_size': 0.1,
                'volatility_threshold': 1.3,
                'volume_threshold': 0.7,
                'trend_threshold': 0.025
            },
            'NVDA': {
                'base_entry_threshold': 0.0025,  # Increased from 0.002
                'short_entry_threshold_factor': 0.9,  # Slightly favor shorts (was 1.0)
                'base_exit_threshold': 0.0018,  # Increased from 0.0015
                'base_stop_loss': 0.003,
                'atr_multiplier': 1.6,  # Increased from 1.5
                'min_hold_time': 1,
                'max_daily_trades': 220,
                'position_size': 0.12,
                'volatility_threshold': 1.1,
                'volume_threshold': 0.6,
                'trend_threshold': 0.02
            },
            'TQQQ': {
                'base_entry_threshold': 0.0045,  # Increased from 0.004
                'short_entry_threshold_factor': 0.85,  # More strongly favor shorts (was 1.0)
                'base_exit_threshold': 0.0035,  # Increased from 0.003
                'base_stop_loss': 0.006,
                'atr_multiplier': 2.5,
                'min_hold_time': 3,
                'max_daily_trades': 120,
                'position_size': 0.06,
                'volatility_threshold': 2.0,
                'volume_threshold': 0.8,
                'trend_threshold': 0.04
            },
            'SQQQ': {
                'base_entry_threshold': 0.0045,  # Increased from 0.004
                'short_entry_threshold_factor': 0.85,  # More strongly favor shorts (was 1.0)
                'base_exit_threshold': 0.0035,  # Increased from 0.003
                'base_stop_loss': 0.006,
                'atr_multiplier': 2.5,
                'min_hold_time': 2,
                'max_daily_trades': 100,
                'position_size': 0.06,
                'volatility_threshold': 2.0,
                'volume_threshold': 0.8,
                'trend_threshold': 0.04
            },
            'QLD': {
                'base_entry_threshold': 0.0030,  # Increased from 0.0025
                'short_entry_threshold_factor': 0.9,  # Slightly favor shorts (was 1.0)
                'base_exit_threshold': 0.0022,  # Increased from 0.0018
                'base_stop_loss': 0.0038,  # Increased from 0.0035
                'atr_multiplier': 1.9,  # Increased from 1.8
                'min_hold_time': 1,
                'max_daily_trades': 180,
                'position_size': 0.10,
                'volatility_threshold': 1.5,
                'volume_threshold': 0.7,
                'trend_threshold': 0.03
            },
            'PSQ': {
                'base_entry_threshold': 0.0030,  # Increased from 0.0025
                'short_entry_threshold_factor': 0.9,  # Slightly favor shorts (was 1.0) 
                'base_exit_threshold': 0.0020,  # Increased from 0.0018
                'base_stop_loss': 0.0038,  # Increased from 0.0035
                'atr_multiplier': 1.7,  # Increased from 1.6
                'min_hold_time': 1,
                'max_daily_trades': 180,
                'position_size': 0.15,
                'volatility_threshold': 1.2,
                'volume_threshold': 0.6,
                'trend_threshold': 0.025
            },
            'QQQ': {
                'base_entry_threshold': 0.0004,  # Increased from 0.0003
                'short_entry_threshold_factor': 0.9,  # Slightly favor shorts (was 1.0)
                'base_exit_threshold': 0.0003,  # Increased from 0.0002 
                'base_stop_loss': 0.0012,  # Increased from 0.001
                'atr_multiplier': 1.0,  # Increased from 0.9
                'min_hold_time': 1,
                'max_daily_trades': 250,
                'position_size': 0.25,
                'volatility_threshold': 0.3,
                'volume_threshold': 0.0,
                'trend_threshold': 0.006
            }
        }
        
        # Default values for any ticker not explicitly listed
        self.default_profile = {
            'base_entry_threshold': 0.0018,  # Increased from 0.0015
            'short_entry_threshold_factor': 0.9,  # Favor shorts (was 1.0)
            'base_exit_threshold': 0.0012,  # Increased from 0.001
            'base_stop_loss': 0.0028,  # Increased from 0.0025
            'atr_multiplier': 1.4,  # Increased from 1.3
            'min_hold_time': 2,
            'max_daily_trades': 180,
            'position_size': 0.15,
            'volatility_threshold': 1.0,
            'volume_threshold': 1.0,
            'trend_threshold': 0.02
        }
        
        # Use stock-specific profile if available, otherwise use default
        self.params = self.stock_profiles.get(ticker, self.default_profile)
        self._initialize_parameters()
        
        # For storing prediction confidence levels - equalized for better balance
        self.direction_confidence = {'up': 0.75, 'down': 0.75}
        
    def _initialize_parameters(self):
        """Initialize trading parameters from stock profile."""
        self.base_entry_threshold = self.params['base_entry_threshold']
        # Use full value for short entry threshold (removed the 0.7 reduction factor)
        self.short_entry_threshold_factor = self.params['short_entry_threshold_factor']
        self.base_exit_threshold = self.params['base_exit_threshold']
        self.base_stop_loss = self.params['base_stop_loss']
        self.atr_multiplier = self.params['atr_multiplier']
        self.min_hold_time = self.params['min_hold_time']
        self.volatility_threshold = self.params['volatility_threshold']
        
        self.position = None
        self.entry_price = None
        self.entry_time = None
        self.current_atr = None
        self.daily_trades = 0
        self.last_trade_date = None
        
    def calculate_atr(self, market_data: pd.DataFrame) -> float:
        """Calculate Average True Range."""
        high = market_data['high'].values
        low = market_data['low'].values
        close = pd.Series(market_data['close'].values)
        
        tr1 = pd.Series(high - low)
        tr2 = pd.Series(abs(high - close.shift()))
        tr3 = pd.Series(abs(low - close.shift()))
        
        tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
        atr = tr.rolling(window=14).mean().iloc[-1]
        
        return atr
        
    def calculate_thresholds(self, current_volatility: float, 
                            market_regime: str = 'neutral',
                            prediction_confidence: Dict[str, float] = None) -> Dict[str, float]:
        """Calculate adaptive thresholds based on current volatility, market regime, and prediction confidence."""
        volatility_factor = current_volatility * self.atr_multiplier
        
        # Base thresholds - make them equal for long and short entries
        long_entry_threshold = self.base_entry_threshold * (1 + volatility_factor)
        # NEW: Make short entry threshold slightly lower to counter bullish bias
        short_entry_threshold = self.base_entry_threshold * (1 + volatility_factor) * 0.9
        exit_threshold = self.base_exit_threshold * (1 + volatility_factor)
        stop_loss = self.base_stop_loss * (1 + volatility_factor)
        
        # Adjust based on prediction confidence if available
        if prediction_confidence:
            # Scale thresholds inversely by confidence
            long_entry_threshold *= (1.2 - prediction_confidence.get('up', 0.5))
            short_entry_threshold *= (1.2 - prediction_confidence.get('down', 0.5))
        
        # NEW: Adjust based on market regime - more balanced approach
        if market_regime == 'high_volatility':
            long_entry_threshold *= 1.3   # More conservative for longs in high volatility
            short_entry_threshold *= 0.7   # More aggressive for shorts in high volatility
            stop_loss *= 1.1              # Wider stop loss in high volatility
        elif market_regime == 'trending_up':
            long_entry_threshold *= 1.0    # Neutral for longs in uptrend
            short_entry_threshold *= 0.8   # More aggressive for shorts in uptrend
        elif market_regime == 'trending_down':
            long_entry_threshold *= 1.3    # More conservative for longs in downtrend
            short_entry_threshold *= 0.6   # Much more aggressive for shorts in downtrend
        
        return {
            'long_entry': long_entry_threshold,
            'short_entry': short_entry_threshold,
            'exit': exit_threshold,
            'stop_loss': stop_loss
        }
        
    def detect_market_regime(self, market_data: pd.DataFrame) -> str:
        """Detect current market regime based on price action and indicators."""
        if len(market_data) < 20:
            return 'neutral'
            
        # Calculate volatility regime
        returns = market_data['close'].pct_change().dropna()
        current_volatility = returns.iloc[-20:].std() * np.sqrt(252)  # Annualized
        avg_volatility = returns.iloc[-60:].std() * np.sqrt(252)
        
        # Calculate trend
        ma5 = market_data['close'].rolling(5).mean().iloc[-1]
        ma20 = market_data['close'].rolling(20).mean().iloc[-1]
        trend_strength = (ma5 - ma20) / ma20
        
        # Determine regime
        if current_volatility > avg_volatility * 1.3:
            regime = 'high_volatility'
        elif trend_strength > 0.02:
            regime = 'trending_up'
        elif trend_strength < -0.02:
            regime = 'trending_down'
        else:
            regime = 'neutral'
            
        return regime
        
    def check_trend_condition(self, market_data: pd.DataFrame) -> Dict[str, bool]:
        """Check if trend conditions are favorable, return separate conditions for long and short."""
        if 'MA5' not in market_data.columns or 'SMA_20' not in market_data.columns:
            # Calculate if not present
            market_data['MA5'] = market_data['close'].rolling(5).mean()
            market_data['SMA_20'] = market_data['close'].rolling(20).mean()
        
        ma5 = market_data['MA5'].iloc[-1]
        ma20 = market_data['SMA_20'].iloc[-1]
        
        trend_strength = abs((ma5 - ma20) / ma20)
        trend_direction = np.sign(ma5 - ma20)
        
        # Different trend conditions for long and short
        if self.ticker in ['TSLA', 'NVDA', 'TQQQ', 'SQQQ']:
            trend_threshold = self.params['trend_threshold'] * 1.5
        else:
            trend_threshold = self.params['trend_threshold']
            
        # More permissive trend check for shorts
        long_trend_ok = trend_strength <= trend_threshold
        short_trend_ok = True  # Allow shorts regardless of trend strength
        
        return {'long': long_trend_ok, 'short': short_trend_ok}
            
    def check_volume_condition(self, market_data: pd.DataFrame) -> bool:
        """Check if volume conditions are favorable."""
        current_volume = market_data['volume'].iloc[-1]
        avg_volume = market_data['volume'].rolling(20).mean().iloc[-1]
        volume_ratio = current_volume / avg_volume
        
        # If volume_threshold is 0, disable volume check (always return True)
        if self.params['volume_threshold'] <= 0:
            return True
            
        return volume_ratio >= self.params['volume_threshold']
        
    def update_direction_confidence(self, metrics: Dict[str, float]):
        """Update direction confidence based on recent model metrics."""
        if 'up_direction_accuracy' in metrics and 'down_direction_accuracy' in metrics:
            # Scale between 0.5 and 1.0
            self.direction_confidence['up'] = 0.5 + (metrics['up_direction_accuracy'] / 200)
            self.direction_confidence['down'] = 0.5 + (metrics['down_direction_accuracy'] / 200)
    
    def _create_signal(self, ticker: str, timestamp: pd.Timestamp, price: float) -> Dict:
        """Create a base signal dictionary."""
        return {
            'ticker': ticker,
            'timestamp': timestamp,
            'price': price,
            'action': None,
            'position': None,
            'position_size': self.params['position_size'],
            'thresholds': {
                'long_entry': self.base_entry_threshold,
                'short_entry': self.base_entry_threshold * self.short_entry_threshold_factor,
                'exit': self.base_exit_threshold,
                'stop_loss': self.base_stop_loss
            },
            'market_conditions': {
                'atr': self.current_atr,
                'daily_trades': self.daily_trades
            }
        }
            
    def generate_signals(self, ticker: str, current_price: float, 
                        predicted_price: float, timestamp: pd.Timestamp,
                        market_data: pd.DataFrame,
                        order_book_data: Dict = None,
                        prediction_confidence: Dict[str, float] = None) -> Dict:
        """Generate trading signals with balanced thresholds for long/short."""
        # Reset daily trades if new day
        current_date = timestamp.date()
        if self.last_trade_date != current_date:
            self.daily_trades = 0
            self.last_trade_date = current_date
            
        # Check trade frequency limit
        if self.daily_trades >= self.params['max_daily_trades']:
            logger.info(f"Trade rejected due to exceeding max daily trades ({self.daily_trades}/{self.params['max_daily_trades']})")
            return self._create_signal(ticker, timestamp, current_price)
            
        # Check minimum hold time
        if self.position and self.entry_time:
            hold_time = (timestamp - self.entry_time).total_seconds() / 60
            if hold_time < self.params['min_hold_time']:
                logger.info(f"Trade rejected due to minimum hold time restriction ({hold_time:.1f} min < {self.params['min_hold_time']} min)")
                return self._create_signal(ticker, timestamp, current_price)
        
        # Variables for trend and volume checks - initialize outside the if blocks
        trend_conditions = None
        volume_ok = False
        
        try:
            # Calculate ATR and market regime
            self.current_atr = self.calculate_atr(market_data)
            market_regime = self.detect_market_regime(market_data)
            
            # Safe logging for ATR
            atr_value = 0.0 if self.current_atr is None else self.current_atr
            logger.info(f"Current market regime: {market_regime}, ATR: {atr_value:.6f}")
            
            # Calculate recent prediction accuracy to adjust bias correction
            price_changes = market_data['close'].pct_change().dropna()
            mean_price_change = price_changes[-20:].mean() if len(price_changes) >= 20 else 0
            logger.info(f"Recent mean price change: {mean_price_change*100:.4f}%")
            
            # Apply reduced downward bias correction
            if market_regime in ['trending_up', 'high_volatility']:
                downward_bias_correction = current_price * 0.0025  # Reduced from 0.5% to 0.25%
            else:
                downward_bias_correction = current_price * 0.001   # Reduced from 0.2% to 0.1%
            
            # Apply additional adaptive correction based on recent price movement
            if mean_price_change > 0:
                # If prices have been rising, apply reduced correction
                adaptive_factor = min(0.0015, mean_price_change)  # Reduced from 0.3% to 0.15%
                downward_bias_correction += current_price * adaptive_factor
            
            logger.info(f"Applied downward bias correction: {downward_bias_correction:.4f} ({(downward_bias_correction/current_price)*100:.4f}%)")
            
            # Use two different predictions for different directions
            original_prediction = predicted_price
            corrected_prediction = predicted_price - downward_bias_correction
            
            logger.info(f"Original prediction: {original_prediction:.4f}, Corrected: {corrected_prediction:.4f}")
            
            # Use provided prediction confidence or default to internal values
            confidence = prediction_confidence or self.direction_confidence
            
            # Safe confidence values
            up_conf = confidence.get('up', 0.5) if confidence else 0.5
            down_conf = confidence.get('down', 0.5) if confidence else 0.5
            logger.info(f"Direction confidence - Up: {up_conf:.2f}, Down: {down_conf:.2f}")
            
            # Get adaptive thresholds based on current conditions
            thresholds = self.calculate_thresholds(
                atr_value, 
                market_regime,
                confidence
            )
            
            # Safe logging for thresholds
            if thresholds:
                logger.info(f"Thresholds - Long entry: {thresholds.get('long_entry', 0)*100:.4f}%, " + 
                        f"Short entry: {thresholds.get('short_entry', 0)*100:.4f}%")
            else:
                logger.info("Thresholds calculation failed, using defaults")
                thresholds = {
                    'long_entry': 0.001,
                    'short_entry': 0.001,
                    'exit': 0.001,
                    'stop_loss': 0.002
                }
            
            # Calculate price change for LONG - using corrected prediction
            long_price_change_pct = (corrected_prediction - current_price) / current_price
            
            # Calculate price change for SHORT - using original prediction
            short_price_change_pct = (original_prediction - current_price) / current_price
            
            logger.info(f"Price change % - Long: {long_price_change_pct*100:.4f}%, Short: {short_price_change_pct*100:.4f}%")
            
            # Create base signal
            signal = self._create_signal(ticker, timestamp, current_price)
            signal['market_regime'] = market_regime
            signal['thresholds'] = thresholds
            
            # Pre-calculate trend and volume conditions once
            trend_conditions = self.check_trend_condition(market_data)
            volume_ok = self.check_volume_condition(market_data)
            
            # Safe logging for trend conditions
            if trend_conditions:
                logger.info(f"Trend conditions - Long: {trend_conditions.get('long', False)}, Short: {trend_conditions.get('short', False)}")
            else:
                logger.info("Trend conditions calculation failed")
                trend_conditions = {'long': False, 'short': False}
                
            logger.info(f"Volume condition: {volume_ok}")
            
            # Handle entry signals with balanced thresholds and detailed diagnostics
            if self.position is None:
                # Check and log long entry conditions
                meets_long_threshold = long_price_change_pct > thresholds['long_entry']
                meets_long_trend = trend_conditions.get('long', False)
                
                if not meets_long_threshold:
                    logger.info(f"LONG signal rejected - Price change {long_price_change_pct*100:.4f}% below threshold {thresholds['long_entry']*100:.4f}%")
                elif not meets_long_trend:
                    logger.info(f"LONG signal rejected - Failed trend condition")
                elif not volume_ok:
                    logger.info(f"LONG signal rejected - Failed volume condition")
                    
                # Check and log short entry conditions
                meets_short_threshold = short_price_change_pct < -thresholds['short_entry']
                meets_short_trend = trend_conditions.get('short', False)
                
                if not meets_short_threshold:
                    logger.info(f"SHORT signal rejected - Price change {short_price_change_pct*100:.4f}% not below threshold -{thresholds['short_entry']*100:.4f}%")
                elif not meets_short_trend:
                    logger.info(f"SHORT signal rejected - Failed trend condition")
                elif not volume_ok:
                    logger.info(f"SHORT signal rejected - Failed volume condition")
                
                # Long entry
                if meets_long_threshold and meets_long_trend and volume_ok:
                    self.position = 'long'
                    self.entry_price = current_price
                    self.entry_time = timestamp
                    signal['action'] = 'enter_long'
                    signal['position'] = 'long'
                    signal['position_size'] = self.params['position_size'] * confidence.get('up', 0.5) * 2
                    self.daily_trades += 1
                    
                    # Safe logging with try/except
                    try:
                        logger.info(f"LONG SIGNAL: {ticker} at {timestamp}, price={current_price:.2f}, " +
                            f"predicted_change={long_price_change_pct*100:.2f}%, " +
                            f"threshold={thresholds['long_entry']*100:.2f}%")
                    except (TypeError, ValueError) as e:
                        logger.info(f"LONG SIGNAL: {ticker} at {timestamp}, price={current_price}, " +
                            f"predicted_change={long_price_change_pct*100}%, " +
                            f"threshold={thresholds['long_entry']*100}%")
                
                # Short entry - using original prediction without bias correction
                elif meets_short_threshold and meets_short_trend and volume_ok:
                    self.position = 'short'
                    self.entry_price = current_price
                    self.entry_time = timestamp
                    signal['action'] = 'enter_short'
                    signal['position'] = 'short'
                    signal['position_size'] = self.params['position_size'] * confidence.get('down', 0.5) * 2
                    self.daily_trades += 1
                    
                    # Safe logging with try/except
                    try:
                        logger.info(f"SHORT SIGNAL: {ticker} at {timestamp}, price={current_price:.2f}, " +
                            f"predicted_change={short_price_change_pct*100:.2f}%, " +
                            f"threshold={thresholds['short_entry']*100:.2f}%")
                    except (TypeError, ValueError) as e:
                        logger.info(f"SHORT SIGNAL: {ticker} at {timestamp}, price={current_price}, " +
                            f"predicted_change={short_price_change_pct*100}%, " +
                            f"threshold={thresholds['short_entry']*100}%")
                    
                # Enhanced criteria for short entries based on technical indicators
                elif ('RSI' in market_data.columns and 'MACD_Histogram' in market_data.columns):
                    rsi_value = market_data['RSI'].iloc[-1]
                    macd_hist = market_data['MACD_Histogram'].iloc[-1]
                    
                    # FIXED: Proper way to handle conditional formatting for logging
                    rsi_str = f"{rsi_value:.2f}" if pd.notna(rsi_value) else "N/A"
                    macd_str = f"{macd_hist:.6f}" if pd.notna(macd_hist) else "N/A"
                    logger.info(f"RSI: {rsi_str}, MACD Histogram: {macd_str}")
                    
                    if (pd.notna(rsi_value) and pd.notna(macd_hist) and 
                        rsi_value > 65 and macd_hist < 0 and volume_ok and trend_conditions.get('short', False)):
                        # RSI approaching overbought and MACD bearish - potential short opportunity
                        self.position = 'short'
                        self.entry_price = current_price
                        self.entry_time = timestamp
                        signal['action'] = 'enter_short'
                        signal['position'] = 'short'
                        signal['position_size'] = self.params['position_size'] * 1.0
                        signal['forced_signal'] = True
                        self.daily_trades += 1
                        
                        # Safe logging with try/except
                        try:
                            logger.info(f"FORCED SHORT SIGNAL: {ticker} at {timestamp}, price={current_price:.2f}, " +
                                f"RSI={rsi_value:.1f}, MACD_Hist={macd_hist:.4f}")
                        except (TypeError, ValueError) as e:
                            logger.info(f"FORCED SHORT SIGNAL: {ticker} at {timestamp}, price={current_price}, " +
                                f"RSI={rsi_value}, MACD_Hist={macd_hist}")
                    else:
                        logger.info(f"Technical indicator SHORT condition not met - need RSI>65 and MACD_Hist<0")
                    
                    if (pd.notna(rsi_value) and pd.notna(macd_hist) and 
                        rsi_value < 35 and macd_hist > 0 and volume_ok and trend_conditions.get('long', False)):
                        # RSI approaching oversold and MACD bullish - potential long opportunity
                        self.position = 'long'
                        self.entry_price = current_price
                        self.entry_time = timestamp
                        signal['action'] = 'enter_long'
                        signal['position'] = 'long'
                        signal['position_size'] = self.params['position_size'] * 1.0
                        signal['forced_signal'] = True
                        self.daily_trades += 1
                        
                        # Safe logging with try/except
                        try:
                            logger.info(f"FORCED LONG SIGNAL: {ticker} at {timestamp}, price={current_price:.2f}, " +
                                f"RSI={rsi_value:.1f}, MACD_Hist={macd_hist:.4f}")
                        except (TypeError, ValueError) as e:
                            logger.info(f"FORCED LONG SIGNAL: {ticker} at {timestamp}, price={current_price}, " +
                                f"RSI={rsi_value}, MACD_Hist={macd_hist}")
                    else:
                        logger.info(f"Technical indicator LONG condition not met - need RSI<35 and MACD_Hist>0")
                
                # Log if no signals were generated
                if not signal['action']:
                    logger.info(f"No trading signal generated - conditions not met")
            
            # Handle exit signals
            else:
                if self.position == 'long':
                    # Store entry_price before it gets reset
                    stored_entry_price = self.entry_price
                    
                    # Calculate price change safely
                    if stored_entry_price is not None:
                        price_change_from_entry = (current_price - stored_entry_price) / stored_entry_price
                        logger.info(f"LONG position - Price change from entry: {price_change_from_entry*100:.4f}%")
                        logger.info(f"Exit thresholds - Stop loss: -{thresholds['stop_loss']*100:.4f}%, Take profit: {thresholds['exit']*100:.4f}%")
                        
                        if price_change_from_entry <= -thresholds['stop_loss']:
                            logger.info(f"LONG position - Stop loss triggered")
                            # Set signal data before resetting position variables
                            signal['action'] = 'exit_long'
                            signal['position'] = 'long'
                            signal['exit_reason'] = 'stop_loss'
                            
                            # Safe logging with try/except - capture values before reset
                            try:
                                logger.info(f"EXIT LONG (STOP_LOSS): {ticker} at {timestamp}, " +
                                    f"entry={stored_entry_price:.2f}, exit={current_price:.2f}, " +
                                    f"profit={price_change_from_entry*100:.2f}%")
                            except (TypeError, ValueError) as e:
                                logger.info(f"EXIT LONG (STOP_LOSS): {ticker} at {timestamp}, " +
                                    f"entry={stored_entry_price}, exit={current_price}, " +
                                    f"profit={price_change_from_entry*100}%")
                            
                            # Reset position variables after logging
                            self.position = None
                            self.entry_price = None
                            self.entry_time = None
                            
                        elif price_change_from_entry >= thresholds['exit']:
                            logger.info(f"LONG position - Take profit triggered")
                            # Set signal data before resetting position variables
                            signal['action'] = 'exit_long'
                            signal['position'] = 'long'
                            signal['exit_reason'] = 'take_profit'
                            
                            # Safe logging with try/except - capture values before reset
                            try:
                                logger.info(f"EXIT LONG (TAKE_PROFIT): {ticker} at {timestamp}, " +
                                    f"entry={stored_entry_price:.2f}, exit={current_price:.2f}, " +
                                    f"profit={price_change_from_entry*100:.2f}%")
                            except (TypeError, ValueError) as e:
                                logger.info(f"EXIT LONG (TAKE_PROFIT): {ticker} at {timestamp}, " +
                                    f"entry={stored_entry_price}, exit={current_price}, " +
                                    f"profit={price_change_from_entry*100}%")
                            
                            # Reset position variables after logging
                            self.position = None
                            self.entry_price = None
                            self.entry_time = None
                            
                        else:
                            logger.info(f"LONG position - Holding (no exit signal)")
                    else:
                        logger.warning(f"LONG position has invalid entry price. Resetting position.")
                        self.position = None
                        self.entry_price = None
                        self.entry_time = None
                
                elif self.position == 'short':
                    # Store entry_price before it gets reset
                    stored_entry_price = self.entry_price
                    
                    # Calculate price change safely
                    if stored_entry_price is not None:
                        price_change_from_entry = (stored_entry_price - current_price) / stored_entry_price
                        logger.info(f"SHORT position - Price change from entry: {price_change_from_entry*100:.4f}%")
                        logger.info(f"Exit thresholds - Stop loss: -{thresholds['stop_loss']*100:.4f}%, Take profit: {thresholds['exit']*100:.4f}%")
                        
                        if price_change_from_entry <= -thresholds['stop_loss']:
                            logger.info(f"SHORT position - Stop loss triggered")
                            # Set signal data before resetting position variables
                            signal['action'] = 'exit_short'
                            signal['position'] = 'short'
                            signal['exit_reason'] = 'stop_loss'
                            
                            # Safe logging with try/except - capture values before reset
                            try:
                                logger.info(f"EXIT SHORT (STOP_LOSS): {ticker} at {timestamp}, " +
                                    f"entry={stored_entry_price:.2f}, exit={current_price:.2f}, " +
                                    f"profit={price_change_from_entry*100:.2f}%")
                            except (TypeError, ValueError) as e:
                                logger.info(f"EXIT SHORT (STOP_LOSS): {ticker} at {timestamp}, " +
                                    f"entry={stored_entry_price}, exit={current_price}, " +
                                    f"profit={price_change_from_entry*100}%")
                            
                            # Reset position variables after logging
                            self.position = None
                            self.entry_price = None
                            self.entry_time = None
                            
                        elif price_change_from_entry >= thresholds['exit']:
                            logger.info(f"SHORT position - Take profit triggered")
                            # Set signal data before resetting position variables
                            signal['action'] = 'exit_short'
                            signal['position'] = 'short'
                            signal['exit_reason'] = 'take_profit'
                            
                            # Safe logging with try/except - capture values before reset
                            try:
                                logger.info(f"EXIT SHORT (TAKE_PROFIT): {ticker} at {timestamp}, " +
                                    f"entry={stored_entry_price:.2f}, exit={current_price:.2f}, " +
                                    f"profit={price_change_from_entry*100:.2f}%")
                            except (TypeError, ValueError) as e:
                                logger.info(f"EXIT SHORT (TAKE_PROFIT): {ticker} at {timestamp}, " +
                                    f"entry={stored_entry_price}, exit={current_price}, " +
                                    f"profit={price_change_from_entry*100}%")
                            
                            # Reset position variables after logging
                            self.position = None
                            self.entry_price = None
                            self.entry_time = None
                            
                        else:
                            logger.info(f"SHORT position - Holding (no exit signal)")
                    else:
                        logger.warning(f"SHORT position has invalid entry price. Resetting position.")
                        self.position = None
                        self.entry_price = None
                        self.entry_time = None
                        
        except Exception as e:
            logger.error(f"Error generating signals: {str(e)}")
            traceback.print_exc()  # Added stack trace for better debugging
            return self._create_signal(ticker, timestamp, current_price)
        
        return signal
        


class PolygonDataFetcher:
    """Handles data retrieval from the Polygon.io API with robust error handling and rate limiting."""
    
    BASE_URL = "https://api.polygon.io/v2"
    # Maximum number of rows requested per aggregates call. A response that
    # returns exactly this many rows may be truncated, so fetch_stock_data
    # keeps paging until a short page arrives.
    PAGE_LIMIT = 50000

    def __init__(self, api_key: str):
        """Create a fetcher whose HTTP session sends ``api_key`` as a Bearer token.

        Args:
            api_key: Polygon.io API key.
        """
        self.api_key = api_key
        self.session = requests.Session()
        self.session.headers.update({"Authorization": f"Bearer {api_key}"})
        self.validation_metrics = []

    def _make_request(self, endpoint: str, params: Optional[Dict] = None, retries: int = 3) -> Dict:
        """Make a GET request relative to ``BASE_URL`` with retry logic and rate limiting.

        Args:
            endpoint: Path appended to ``BASE_URL``.
            params: Optional query-string parameters.
            retries: Total number of attempts.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.HTTPError: Immediately for 4xx client errors (other than 429),
                or for any other failure once all attempts are exhausted.
            Exception: When every attempt was consumed by HTTP 429 rate limiting.
        """
        url = f"{self.BASE_URL}/{endpoint}"
        
        for attempt in range(retries):
            try:
                response = self.session.get(url, params=params, timeout=30)
                
                if response.status_code == 429:  # Rate limit exceeded
                    wait_time = min(60 * (attempt + 1), 300)  # Max 5 minutes
                    logger.warning(f"Rate limit exceeded. Waiting {wait_time} seconds...")
                    time.sleep(wait_time)
                    continue
                    
                response.raise_for_status()
                return response.json()
                
            except Exception as e:
                # Client errors (bad key, unknown ticker, malformed range) will not
                # succeed on retry, so fail fast instead of sleeping between attempts.
                failed_response = getattr(e, "response", None)
                status = getattr(failed_response, "status_code", None)
                if status is not None and 400 <= status < 500:
                    logger.error(f"Request failed with client error {status}: {e}")
                    raise
                if attempt == retries - 1:
                    logger.error(f"Request failed after {retries} attempts: {e}")
                    raise
                logger.warning(f"Request failed (Attempt {attempt + 1}/{retries}): {e}. Retrying in 30 seconds...")
                time.sleep(30)
                
        raise Exception(f"Failed after {retries} attempts")

    def fetch_stock_data(self, symbol: str, start_date: str, end_date: str, 
                        timespan: str = "minute", multiplier: int = 1) -> pd.DataFrame:
        """
        Fetch aggregate stock data for a specific ticker and date range.
        
        Large ranges are paged: whenever a page comes back full (``PAGE_LIMIT``
        rows), the next request starts at the millisecond timestamp of the last
        bar received. Bars repeated across page boundaries are de-duplicated.
        
        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            timespan: Timespan of the aggregates ('minute', 'hour', 'day')
            multiplier: The size of the timespan multiplier
            
        Returns:
            DataFrame containing aggregated stock data, indexed by a 'timestamp'
            DatetimeIndex (converted from epoch milliseconds). An empty DataFrame
            is returned when no usable data was retrieved.
        """
        params = {"adjusted": "true", "sort": "asc", "limit": self.PAGE_LIMIT}
        all_results = []
        # The window start is part of the URL path. A `from` query parameter is
        # not how this endpoint is paged, which is why the path is rebuilt on
        # every iteration. After the first page it holds the epoch-ms timestamp
        # of the last bar seen.
        window_start = start_date
        previous_last_timestamp = None

        with tqdm(desc=f"Fetching {symbol}", unit='rows') as pbar:
            while True:
                endpoint = f"aggs/ticker/{symbol}/range/{multiplier}/{timespan}/{window_start}/{end_date}"
                data = self._make_request(endpoint, params)
                results = data.get('results') or []
                
                if results and 'c' not in results[0]:
                    logger.error(f"Data for {symbol} does not contain 'c' (close) field.")
                    return pd.DataFrame()

                all_results.extend(results)
                pbar.update(len(results))
                    
                if len(results) < self.PAGE_LIMIT:
                    break
                    
                # Full page: resume from the last bar received. Any overlapping
                # bar is removed by the duplicate-timestamp filter below.
                last_timestamp = results[-1]['t']
                if last_timestamp == previous_last_timestamp:
                    # No forward progress; stop rather than request the same page forever.
                    logger.warning(f"Pagination for {symbol} stalled at timestamp {last_timestamp}; stopping early.")
                    break
                previous_last_timestamp = last_timestamp
                window_start = last_timestamp

        # Create DataFrame and process data
        if not all_results:
            logger.warning(f"No data retrieved for {symbol}")
            return pd.DataFrame()
            
        df = pd.DataFrame(all_results)
        
        if 't' in df.columns:
            # Remove duplicate timestamps (including page-boundary overlaps)
            duplicates_before = len(df) - df.drop_duplicates(subset=['t'], keep='last').shape[0]
            if duplicates_before > 0:
                logger.info(f"Removed {duplicates_before} duplicate timestamps for {symbol}.")
            df = df.drop_duplicates(subset=['t'], keep='last')
        else:
            logger.error("'t' column not found in fetched data.")
            return pd.DataFrame()

        # Rename columns for clarity
        df = df.rename(columns={
            't': 'timestamp',
            'o': 'open',
            'h': 'high',
            'l': 'low',
            'c': 'close',
            'v': 'volume',
            'n': 'transactions'
        })

        # Convert timestamp to datetime and set as index
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        df.set_index('timestamp', inplace=True)

        return df
        
    def fetch_market_index_data(self, index_symbol: str = "SPY", 
                             start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetch market index data to add market context.

        Args:
            index_symbol: Ticker of the index proxy to fetch.
            start_date: Start date (YYYY-MM-DD). Defaults to 30 days before ``end_date``.
            end_date: End date (YYYY-MM-DD). Defaults to today.

        Returns:
            The DataFrame produced by ``fetch_stock_data`` for ``index_symbol``.
        """
        # Fill each missing bound independently so a caller-supplied date is never discarded.
        if not end_date:
            end_date = datetime.now().strftime('%Y-%m-%d')
        if not start_date:
            start_date = (pd.Timestamp(end_date) - timedelta(days=30)).strftime('%Y-%m-%d')
            
        return self.fetch_stock_data(index_symbol, start_date, end_date)

def enhance_features(ticker_df: pd.DataFrame, market_index_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """
    Add technical-indicator, regime and (optionally) market-context features.
    
    Args:
        ticker_df: DataFrame with at least 'open', 'high', 'low', 'close' and
            'volume' columns. All rolling windows are row-based, so rows are
            presumably in chronological order — TODO confirm for all callers.
        market_index_df: Optional DataFrame with a 'close' column for a market
            index. It is de-duplicated and sorted on its index, then
            forward-filled onto ``ticker_df``'s index before market features
            are computed.
    
    Returns:
        A copy of ``ticker_df`` with the additional feature columns. Rows that
        contain any missing or infinite value are dropped. These include the
        warm-up rows of the rolling windows and rows hit by a division by zero.
    
    Raises:
        ValueError: If any required OHLCV column is missing.
    """
    # Check for required columns
    required_cols = ['open', 'high', 'low', 'close', 'volume']
    if not all(col in ticker_df.columns for col in required_cols):
        missing = [col for col in required_cols if col not in ticker_df.columns]
        raise ValueError(f"Missing required columns: {missing}")
        
    enhanced_df = ticker_df.copy()
    
    # Add basic technical indicators
    enhanced_df['returns'] = enhanced_df['close'].pct_change()
    enhanced_df['log_returns'] = np.log(enhanced_df['close'] / enhanced_df['close'].shift(1))
    
    # Price moving averages
    enhanced_df['MA5'] = enhanced_df['close'].rolling(window=5).mean()
    enhanced_df['SMA_20'] = enhanced_df['close'].rolling(window=20).mean()
    enhanced_df['EMA_20'] = enhanced_df['close'].ewm(span=20).mean()
    
    # Momentum indicators
    enhanced_df['RSI'] = ta.momentum.rsi(enhanced_df['close'], window=14)
    
    # Volatility indicators
    bb_indicator = BollingerBands(close=enhanced_df['close'])
    enhanced_df['Bollinger_High'] = bb_indicator.bollinger_hband()
    enhanced_df['Bollinger_Low'] = bb_indicator.bollinger_lband()
    enhanced_df['Bollinger_Width'] = (enhanced_df['Bollinger_High'] - enhanced_df['Bollinger_Low']) / enhanced_df['SMA_20']
    
    # Trend indicators
    macd_indicator = MACD(close=enhanced_df['close'])
    enhanced_df['MACD'] = macd_indicator.macd()
    enhanced_df['MACD_Signal'] = macd_indicator.macd_signal()
    enhanced_df['MACD_Histogram'] = enhanced_df['MACD'] - enhanced_df['MACD_Signal']
    
    # Price patterns (per bar; 'daily_range' is simply high - low of each row)
    enhanced_df['daily_range'] = enhanced_df['high'] - enhanced_df['low']
    enhanced_df['gap_up'] = (enhanced_df['open'] - enhanced_df['close'].shift(1)) / enhanced_df['close'].shift(1)
    enhanced_df['body_size'] = abs(enhanced_df['close'] - enhanced_df['open']) / enhanced_df['open']
    
    # Volume analysis
    enhanced_df['volume_ma'] = enhanced_df['volume'].rolling(window=20).mean()
    enhanced_df['volume_ratio'] = enhanced_df['volume'] / enhanced_df['volume_ma']
    enhanced_df['price_volume'] = enhanced_df['returns'].abs() * enhanced_df['volume_ratio']

    # Add relative features rather than absolute values
    enhanced_df['price_acceleration'] = enhanced_df['returns'].diff()
    enhanced_df['ma_cross'] = (enhanced_df['MA5'] > enhanced_df['SMA_20']).astype(int)
    enhanced_df['bb_position'] = (enhanced_df['close'] - enhanced_df['Bollinger_Low']) / (enhanced_df['Bollinger_High'] - enhanced_df['Bollinger_Low'])
    
    # Market regime detection
    enhanced_df['volatility'] = enhanced_df['returns'].rolling(20).std() * np.sqrt(252)  # Annualized
    # Percentile rank (0-100) of the latest volatility within its trailing
    # 60-row window. rolling() only calls the function on full windows, so the
    # window is never empty.
    enhanced_df['volatility_percentile'] = enhanced_df['volatility'].rolling(60).apply(
        lambda window: percentileofscore(window, window.iloc[-1]), raw=False
    )
    
    # Trend regime
    enhanced_df['trend_strength'] = (enhanced_df['MA5'] - enhanced_df['SMA_20']) / enhanced_df['SMA_20']
    enhanced_df['trend_regime'] = pd.cut(
        enhanced_df['trend_strength'].fillna(0),
        bins=[-float('inf'), -0.02, -0.005, 0.005, 0.02, float('inf')],
        labels=['strong_downtrend', 'downtrend', 'neutral', 'uptrend', 'strong_uptrend']
    )
    
    # Up/down move features
    enhanced_df['up_day'] = (enhanced_df['close'] > enhanced_df['close'].shift(1)).astype(int)
    enhanced_df['down_day'] = (enhanced_df['close'] < enhanced_df['close'].shift(1)).astype(int)
    enhanced_df['up_volume'] = enhanced_df['volume'] * enhanced_df['up_day']
    enhanced_df['down_volume'] = enhanced_df['volume'] * enhanced_df['down_day']

    # Enhanced price distance features
    enhanced_df['price_dist_from_mean_5d'] = (enhanced_df['close'] - enhanced_df['MA5']) / enhanced_df['close']
    enhanced_df['price_dist_from_mean_20d'] = (enhanced_df['close'] - enhanced_df['SMA_20']) / enhanced_df['close']
    
    # Enhanced momentum features
    enhanced_df['price_momentum_5d'] = enhanced_df['returns'].rolling(5).sum()
    enhanced_df['price_momentum_20d'] = enhanced_df['returns'].rolling(20).sum()
    
    # Non-linear transformations
    enhanced_df['log_abs_return'] = np.log(np.abs(enhanced_df['returns']) + 1e-6)
    enhanced_df['return_sign'] = np.sign(enhanced_df['returns'])
    enhanced_df['return_squared'] = enhanced_df['returns'] ** 2
    
    # Range features
    enhanced_df['daily_range_ratio'] = enhanced_df['daily_range'] / enhanced_df['close']
    enhanced_df['range_momentum'] = enhanced_df['daily_range_ratio'].rolling(5).mean()
    
    # Add market index context if available
    if market_index_df is not None and not market_index_df.empty:
        # reindex(method='ffill') requires a unique, monotonic source index.
        market_index_df = market_index_df[~market_index_df.index.duplicated(keep='last')].sort_index()
        # Align indices
        market_index_df = market_index_df.reindex(enhanced_df.index, method='ffill')
        
        # Add market return features
        enhanced_df['market_return'] = market_index_df['close'].pct_change()
        enhanced_df['market_ma20'] = market_index_df['close'].rolling(20).mean()
        enhanced_df['market_vol'] = market_index_df['close'].pct_change().rolling(20).std()
        
        # Correlation metrics
        enhanced_df['market_correlation'] = enhanced_df['returns'].rolling(20).corr(enhanced_df['market_return'])
        enhanced_df['beta'] = enhanced_df['returns'].rolling(20).cov(enhanced_df['market_return']) / enhanced_df['market_return'].rolling(20).var()
        
        # Relative strength
        enhanced_df['relative_strength'] = enhanced_df['close'] / enhanced_df['close'].shift(20)
        enhanced_df['market_relative_strength'] = market_index_df['close'] / market_index_df['close'].shift(20)
        enhanced_df['rs_ratio'] = enhanced_df['relative_strength'] / enhanced_df['market_relative_strength']
    
    # Ratio features above (volume_ratio, bb_position, beta, rs_ratio, ...) can
    # evaluate to +/-inf when a denominator is exactly zero. dropna() does not
    # remove infinities, and e.g. scikit-learn scalers raise on them, so treat
    # them as missing. Only numeric columns are touched; 'trend_regime' is
    # categorical.
    numeric_cols = enhanced_df.select_dtypes(include=[np.number]).columns
    enhanced_df[numeric_cols] = enhanced_df[numeric_cols].replace([np.inf, -np.inf], np.nan)
    
    # Drop NaN values created by rolling windows (and the infinities above)
    enhanced_df.dropna(inplace=True)
    
    return enhanced_df

def create_stratified_dataset(data_df: pd.DataFrame, min_regime_count: int = 50, 
                             min_total_samples: int = 1000,
                             max_sampling_fraction: float = 0.7,
                             max_oversample_factor: float = 5.0,
                             random_state: Optional[int] = None) -> pd.DataFrame:
    """
    Create a dataset with balanced market regimes for more even training.
    
    If ``data_df`` has no 'regime' column, each row is labelled from its
    'volatility' and 'returns' values:
    - Volatility at or above the 65th percentile counts as high volatility.
    - A return above +0.0008 is an up move; below -0.0008 is a down move.
    - Labels are 'high_vol_up', 'high_vol_down', 'low_vol_up', 'low_vol_down',
      or 'neutral' when the return is inside the band.
    If any regime has fewer than 10 rows, every row is relabelled 'neutral'.
    Each regime is then down-sampled without replacement, or up-sampled with
    replacement, towards a common per-regime count. The result is shuffled.
    
    Args:
        data_df: DataFrame with 'volatility' and 'returns' columns, or a
            precomputed 'regime' column. The caller's DataFrame is not modified.
        min_regime_count: Minimum number of samples per regime
        min_total_samples: Minimum total samples in the resulting dataset
        max_sampling_fraction: The per-regime target is raised so that the
            balanced dataset keeps at least this fraction of the original rows
        max_oversample_factor: Maximum factor for oversampling any regime
        random_state: Optional seed for reproducible sampling. ``None`` (the
            default) uses NumPy's global random state.
    
    Returns:
        Shuffled, balanced DataFrame with a fresh RangeIndex and a 'regime'
        column. If no regime could be counted, a labelled copy of the input is
        returned instead.
    """
    # Work on a copy so regime labelling never adds or overwrites a column in
    # the caller's DataFrame.
    data_df = data_df.copy()
    # DataFrame.sample(random_state=None) draws from NumPy's global RNG. A
    # seeded RandomState makes every draw below reproducible on its own.
    rng = np.random.RandomState(random_state) if random_state is not None else None

    # Define market regimes if not already done
    if 'regime' not in data_df.columns:
        # Initialize regime column
        data_df['regime'] = 'neutral'
        
        # Identify high volatility periods (more sensitive thresholds)
        volatility_threshold = data_df['volatility'].quantile(0.65)
        high_vol = data_df['volatility'] >= volatility_threshold
        
        # Identify up and down days (more sensitive thresholds)
        up_day = data_df['returns'] > 0.0008
        down_day = data_df['returns'] < -0.0008
        
        # Combine into regimes
        data_df.loc[high_vol & up_day, 'regime'] = 'high_vol_up'
        data_df.loc[high_vol & down_day, 'regime'] = 'high_vol_down'
        data_df.loc[(~high_vol) & up_day, 'regime'] = 'low_vol_up'
        data_df.loc[(~high_vol) & down_day, 'regime'] = 'low_vol_down'
    
    # Get counts by regime
    regime_counts = data_df['regime'].value_counts()
    logger.info(f"Original regime distribution: {regime_counts.to_dict()}")
    
    # Check if we have sufficient data in any regime
    non_zero_counts = regime_counts[regime_counts > 0]
    if len(non_zero_counts) == 0:
        logger.warning("No valid regimes found in dataset")
        return data_df
        
    # If we have very few samples in any regime, consider all data as neutral
    if non_zero_counts.min() < 10:
        logger.warning(f"Very few samples in some regimes. Classifying all as neutral.")
        data_df['regime'] = 'neutral'
        regime_counts = data_df['regime'].value_counts()
        non_zero_counts = regime_counts
    
    # Find minimum count, but ensure it's at least min_regime_count
    min_count = max(int(non_zero_counts.min() * 0.8), min_regime_count)
    
    # Calculate total samples after balancing. nunique() ignores NaN labels,
    # which cannot be selected by the equality mask below anyway.
    num_regimes = data_df['regime'].nunique()
    total_balanced_samples = min_count * num_regimes
    
    # Adjust if total samples would be too small compared to original
    if total_balanced_samples < min_total_samples or total_balanced_samples < max_sampling_fraction * len(data_df):
        # Increase min_count to meet the minimum total requirement
        required_count = max(
            min_total_samples // num_regimes,
            int(max_sampling_fraction * len(data_df) // num_regimes)
        )
        min_count = max(min_count, required_count)
        logger.info(f"Adjusted regime count to {min_count} to maintain sufficient data volume")
    
    # Sample from each regime; frames are collected and concatenated once.
    sampled_frames = []
    for regime in data_df['regime'].unique():
        regime_data = data_df[data_df['regime'] == regime]
        available = len(regime_data)
        if available == 0:
            continue
            
        if available >= min_count:
            # Enough rows: sample without replacement so no row is duplicated
            # (this includes the exact-fit case).
            sampled_data = regime_data.sample(min_count, random_state=rng)
            logger.info(f"Sampled without replacement for regime '{regime}': {available} → {min_count}")
        elif available * max_oversample_factor < min_count:
            # Cap the number of samples based on max oversampling factor
            actual_samples = int(available * max_oversample_factor)
            logger.warning(f"Limiting oversampling for regime '{regime}': {available} → {actual_samples} (capped at {max_oversample_factor}x)")
            sampled_data = regime_data.sample(actual_samples, replace=True, random_state=rng)
        else:
            # Sample with replacement up to min_count
            sampled_data = regime_data.sample(min_count, replace=True, random_state=rng)
            logger.info(f"Sampled with replacement for regime '{regime}': {available} → {min_count}")
            
        sampled_frames.append(sampled_data)
    
    if not sampled_frames:
        logger.warning("No regime produced any samples")
        return data_df
    
    # Shuffle the final dataset
    balanced_df = pd.concat(sampled_frames).sample(frac=1, random_state=rng).reset_index(drop=True)
    
    # Log final distribution
    final_regime_counts = balanced_df['regime'].value_counts()
    logger.info(f"Balanced regime distribution: {final_regime_counts.to_dict()}")
    logger.info(f"Original dataset: {len(data_df)} samples → Balanced dataset: {len(balanced_df)} samples")
    
    return balanced_df

class DirectionalPredictionLoss(nn.Module):
    """Loss mixing direction mistakes, absolute error and bias penalties.

    Per sample the loss is ``direction_weight * weighted_direction_error +
    magnitude_weight * |y_pred - y_true| + bias_correction``. A batch-level
    penalty ``3.0 * |up_ratio - 0.45|`` is added before taking the mean, where
    ``up_ratio`` is the share of predictions above the reference price.

    NOTE(review): the direction term is built from ``torch.sign`` comparisons,
    and the bias/distribution penalties are computed from detached values.
    They therefore shift the reported loss but contribute no gradient; only
    the magnitude term drives parameter updates.
    """

    def __init__(self, direction_weight=2.5, magnitude_weight=0.8,
                 short_penalty_multiplier=1.5, bias_correction_weight=0.8):
        """
        Args:
            direction_weight: Multiplier for the direction-mistake term.
            magnitude_weight: Multiplier for the absolute-error term.
            short_penalty_multiplier: Extra factor on direction mistakes for
                samples whose true movement is downward. It is applied on top
                of the ticker's ``down_weight``.
            bias_correction_weight: Multiplier for the recency-weighted mean
                prediction error, which is added as a bias penalty.
        """
        super().__init__()
        self.direction_weight = direction_weight
        self.magnitude_weight = magnitude_weight
        self.short_penalty_multiplier = short_penalty_multiplier
        self.bias_correction_weight = bias_correction_weight

        # Per-ticker multipliers for direction mistakes on up / down moves.
        self.ticker_adapters = {
            'MSFT': {'up_weight': 1.0, 'down_weight': 1.2},
            'GOOGL': {'up_weight': 1.0, 'down_weight': 1.2},
            'TSLA': {'up_weight': 1.0, 'down_weight': 1.3},
            'NVDA': {'up_weight': 1.0, 'down_weight': 1.2},
            'TQQQ': {'up_weight': 1.0, 'down_weight': 1.3},
            'QQQ': {'up_weight': 1.0, 'down_weight': 1.2}
        }

        # Used for tickers not listed above (and when no ticker is set).
        self.default_adapter = {'up_weight': 1.0, 'down_weight': 1.2}

        self.current_ticker = None

        # Rolling histories (oldest first) of the batch mean error and the
        # batch direction bias, capped at ``max_history_length`` entries.
        self.prediction_history = []
        self.direction_bias_history = []
        self.max_history_length = 50

    def set_ticker(self, ticker: str):
        """Select the ticker-specific weights and reset the bias histories."""
        self.current_ticker = ticker
        self.prediction_history = []
        self.direction_bias_history = []

    def forward(self, y_pred, y_true, current_prices=None):
        """Compute the mean loss for a batch.

        Args:
            y_pred: Predicted next prices.
            y_true: Actual next prices.
            current_prices: Reference prices defining the movement direction.
                When omitted, ``y_true`` shifted by one position is used, and
                the first element is compared against itself.

        Returns:
            Scalar tensor holding the mean of the per-sample loss.
        """
        # Compare same-sized inputs element-wise. Without this, a (B, 1) model
        # output against a (B,) target would broadcast to a (B, B) loss.
        if y_pred.numel() == y_true.numel():
            y_pred = y_pred.reshape(-1)
            y_true = y_true.reshape(-1)
            if current_prices is not None and current_prices.numel() == y_true.numel():
                current_prices = current_prices.reshape(-1)

        if current_prices is None:
            current_prices = y_true.roll(1)
            current_prices[0] = y_true[0]  # no predecessor for the first element

        true_direction = torch.sign(y_true - current_prices)
        pred_direction = torch.sign(y_pred - current_prices)

        # 1.0 where the predicted direction differs from the true one.
        direction_error = (pred_direction != true_direction).float()
        magnitude_error = torch.abs(y_pred - y_true)

        # Record batch statistics for the bias penalty.
        with torch.no_grad():
            self.prediction_history.append((y_pred - y_true).mean().item())
            # Share of "up" predictions rescaled from [0, 1] to [-1, 1].
            batch_direction_bias = ((pred_direction > 0).float().mean() - 0.5) * 2
            self.direction_bias_history.append(batch_direction_bias.item())

            if len(self.prediction_history) > self.max_history_length:
                self.prediction_history.pop(0)
                self.direction_bias_history.pop(0)

        if len(self.prediction_history) > 10:
            window = self.prediction_history[-10:]  # oldest first
            n = len(window)
            # The newest error gets weight 0.8**0, the one before 0.8**1, etc.
            # The weights must be reversed relative to the oldest-first window.
            decay = torch.tensor([0.8 ** (n - 1 - i) for i in range(n)], device=y_pred.device)
            decay = decay / decay.sum()

            recent_errors = torch.tensor(window, device=y_pred.device)
            weighted_bias = (recent_errors * decay).sum()
            bias_correction = torch.abs(weighted_bias) * self.bias_correction_weight

            # Extra penalty only when recent batches lean towards "up".
            mean_direction_bias = float(np.mean(self.direction_bias_history[-10:]))
            if mean_direction_bias > 0:
                bias_correction = bias_correction + mean_direction_bias * 0.5
        else:
            bias_correction = torch.tensor(0.0, device=y_pred.device)

        adapter = self.ticker_adapters.get(self.current_ticker, self.default_adapter)

        up_mask = true_direction > 0
        down_mask = true_direction < 0
        direction_error[up_mask] *= adapter['up_weight']
        direction_error[down_mask] *= adapter['down_weight']
        # Additional penalty for missed downward moves.
        direction_error[down_mask] *= self.short_penalty_multiplier

        loss = (self.direction_weight * direction_error
                + self.magnitude_weight * magnitude_error
                + bias_correction)

        # Batch-level penalty for drifting away from a 45% "up" prediction share.
        with torch.no_grad():
            up_ratio = (pred_direction > 0).float().mean()
            distribution_penalty = 3.0 * torch.abs(up_ratio - 0.45)

        return (loss + distribution_penalty).mean()

class HFTDataset(Dataset):
    """Sliding-window dataset of scaled features for next-price regression.

    Each item is ``(sequence, next_value[, current_value])``:
    - ``sequence`` holds ``sequence_length`` consecutive feature rows ending at
      the "current" row.
    - ``current_value`` is that row's target-column value.
    - ``next_value`` is the target-column value of the following row.
    """

    def __init__(self, data: pd.DataFrame, sequence_length: int,
                 target_column: str = 'close',
                 scaler: Optional[StandardScaler] = None,
                 include_current_price: bool = True,
                 relative_normalization: bool = True,
                 ticker: Optional[str] = None):
        """Select feature columns, build next-row labels and scale features.

        Args:
            data: Source rows. Presumably in chronological order; not checked.
            sequence_length: Number of feature rows per window.
            target_column: Column whose next-row value is the label.
            scaler: Scaler to use. It is always fitted on ``data`` via
                ``fit_transform``, even if it was fitted before. A new
                ``StandardScaler`` is created when None.
            include_current_price: Also return the current row's target value.
            relative_normalization: Express each window relative to its first
                row in ``__getitem__``.
            ticker: Optional symbol, stored for reference.
        """
        self.sequence_length = sequence_length
        self.include_current_price = include_current_price
        self.scaler = scaler or StandardScaler()
        self.relative_normalization = relative_normalization
        self.ticker = ticker

        # Base feature columns.
        feature_cols = [
            'close', 'volume', 'SMA_20', 'EMA_20', 'RSI', 'MA5',
            'Bollinger_High', 'Bollinger_Low', 'MACD', 'MACD_Signal'
        ]

        # Optional engineered features, used only when present.
        enhanced_features = [
            'volatility', 'trend_strength', 'body_size', 'volume_ratio',
            'price_volume', 'price_acceleration', 'price_dist_from_mean_5d',
            'price_dist_from_mean_20d', 'price_momentum_5d', 'price_momentum_20d',
            'log_abs_return', 'daily_range_ratio', 'range_momentum'
        ]

        for feat in enhanced_features:
            if feat in data.columns:
                feature_cols.append(feat)

        # Optional market-context features, used only when present.
        market_features = ['market_return', 'market_vol', 'beta', 'rs_ratio']
        for feat in market_features:
            if feat in data.columns:
                feature_cols.append(feat)

        # Drop (with a warning) any base column the frame does not have.
        missing_cols = [col for col in feature_cols if col not in data.columns]
        if missing_cols:
            logger.warning(f"Missing feature columns: {missing_cols}")
            feature_cols = [col for col in feature_cols if col in data.columns]

        self.features = data[feature_cols].values
        self.target_col_idx = data.columns.get_loc(target_column)

        # labels[i] is the target value of row i + 1.
        self.labels = data[target_column].shift(-1).values

        if include_current_price:
            self.current_prices = data[target_column].values

        # The last row has no next row (NaN label), so drop it everywhere.
        self.features = self.features[:-1]
        self.labels = self.labels[:-1]
        if include_current_price:
            self.current_prices = self.current_prices[:-1]

        self.features = self.scaler.fit_transform(self.features)

        # Re-standardize price-like columns.
        # NOTE(review): the scaler above already standardizes every column, so
        # this is effectively a no-op. The name match is also case-sensitive,
        # so 'Bollinger_High' / 'Bollinger_Low' are not matched.
        price_col_indices = [i for i, name in enumerate(feature_cols)
                             if 'close' in name or 'price' in name or 'high' in name or 'low' in name]

        if price_col_indices and relative_normalization:
            for idx in price_col_indices:
                price_series = self.features[:, idx]
                price_mean = np.mean(price_series)
                price_std = np.std(price_series)

                if price_std > 0:
                    self.features[:, idx] = (price_series - price_mean) / price_std

        # Scaled feature matrix used to slice windows.
        self.scaled_features = self.features.copy()

        self.feature_names = feature_cols

    def __len__(self) -> int:
        """Return the number of windows that have both a current and next value."""
        return max(0, len(self.labels) - self.sequence_length)

    def __getitem__(self, idx: int) -> Tuple:
        """Return the window ending at row ``idx + sequence_length`` and its label.

        The window covers rows ``idx + 1 .. idx + sequence_length`` inclusive.
        Its last row is the row whose target value is returned as the current
        price, so the label is exactly one step ahead of the last row the model
        sees. Previously the window stopped one row earlier, so the current
        price was never in the input and the label was two steps ahead.
        """
        end = idx + self.sequence_length  # row of the "current" value
        seq = self.scaled_features[idx + 1:end + 1].copy()

        if self.relative_normalization:
            # Express each column relative to the window's first row.
            # NOTE(review): the features are standardized, so first values can
            # be near zero, and dividing by |first| + 1e-8 can blow up.
            first_values = seq[0].copy().reshape(1, -1)
            epsilon = 1e-8
            seq = (seq - first_values) / (np.abs(first_values) + epsilon)

        y = self.labels[end]

        if self.include_current_price:
            current_price = self.current_prices[end]
            return (
                torch.tensor(seq, dtype=torch.float32),
                torch.tensor(y, dtype=torch.float32),
                torch.tensor(current_price, dtype=torch.float32)
            )
        return (
            torch.tensor(seq, dtype=torch.float32),
            torch.tensor(y, dtype=torch.float32)
        )

class LSTMModel(nn.Module):
    """Bidirectional LSTM regressor with attention pooling and a variance branch."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int,
                 output_size: int = 1, dropout: float = 0.5):
        """
        Build the LSTM, attention, batch-norm and output layers.

        Args:
            input_size: Number of input features per time step.
            hidden_size: Hidden size of each LSTM direction.
            num_layers: Number of stacked LSTM layers.
            output_size: Size of the final output.
            dropout: Dropout probability. It is used between LSTM layers (only
                when num_layers > 1), after the pooled context and hidden layer,
                and at half strength inside the attention scorer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )

        # Forward and backward outputs are concatenated.
        bidirectional_hidden_size = hidden_size * 2

        self.dropout = nn.Dropout(dropout)

        # Scores each time step; softmax over time gives the attention weights.
        self.attention = nn.Sequential(
            nn.Linear(bidirectional_hidden_size, bidirectional_hidden_size // 2),
            nn.Tanh(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(bidirectional_hidden_size // 2, 1)
        )

        self.batch_norm1 = nn.BatchNorm1d(bidirectional_hidden_size)
        self.batch_norm2 = nn.BatchNorm1d(bidirectional_hidden_size // 2)

        # Positive "variance" feature, fed to the output transform together
        # with the base prediction.
        self.fc_variance = nn.Linear(bidirectional_hidden_size, 1)

        self.fc1 = nn.Linear(bidirectional_hidden_size, bidirectional_hidden_size // 2)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(bidirectional_hidden_size // 2, output_size)

        # Maps (base_pred, variance) to the final output. The first layer takes
        # 2 inputs, so this assumes output_size == 1.
        self.output_transform = nn.Sequential(
            nn.Linear(2, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, output_size)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier/orthogonal initialisation with zero biases."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if 'weight_ih' in name:
                    nn.init.xavier_uniform_(param)
                elif 'weight_hh' in name:
                    nn.init.orthogonal_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        ``x`` is (batch, seq, input_size) because the LSTM uses
        ``batch_first=True``. Returns (batch, output_size).
        """
        batch_size = x.size(0)
        # Zero initial states; "* 2" accounts for the two directions.
        state_shape = (self.num_layers * 2, batch_size, self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        out, _ = self.lstm(x, (h0, c0))

        # Attention-weighted sum over time steps.
        attn_weights = F.softmax(self.attention(out).squeeze(-1), dim=1)
        context = torch.bmm(attn_weights.unsqueeze(1), out).squeeze(1)

        # Training-mode BatchNorm cannot compute batch statistics from one
        # sample. Eval mode uses running statistics and must be applied for
        # every batch size, otherwise single-sample inference would differ
        # from batched inference.
        use_batch_norm = batch_size > 1 or not self.training

        if use_batch_norm:
            context = self.batch_norm1(context)
        context = self.dropout(context)

        hidden = self.relu(self.fc1(context))
        if use_batch_norm:
            hidden = self.batch_norm2(hidden)
        hidden = self.dropout(hidden)
        base_pred = self.fc2(hidden)

        pred_variance = torch.exp(self.fc_variance(context))

        combined = torch.cat((base_pred, pred_variance), dim=1)
        return self.output_transform(combined)

class EnsembleModel:
    """
    Weighted-average ensemble of PyTorch regression models.
    """

    def __init__(self, models: List[nn.Module], weights: Optional[List[float]] = None):
        """
        Store the models and normalise their weights to sum to 1.

        Args:
            models: Trained PyTorch models (at least one).
            weights: Optional per-model weights. Defaults to equal weights.

        Raises:
            ValueError: If ``models`` is empty, the number of weights does not
                match the number of models, or the weights sum to zero.
        """
        if not models:
            raise ValueError("EnsembleModel requires at least one model")
        self.models = models

        if weights is None:
            self.weights = [1.0 / len(models)] * len(models)
        else:
            if len(weights) != len(models):
                raise ValueError("Number of weights must match number of models")
            total = sum(weights)
            if total == 0:
                raise ValueError("Ensemble weights must not sum to zero")
            self.weights = [w / total for w in weights]

    def predict(self, data_loader: DataLoader, device: torch.device
                ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """
        Generate ensemble predictions by weighted averaging.

        All models are run on the same batch during a single pass over the
        loader. Predictions therefore stay aligned with each other and with the
        targets even when the loader shuffles.

        Args:
            data_loader: Yields ``(x, y)`` or ``(x, y, current_price)`` batches.
            device: Device the models and inputs are moved to.

        Returns:
            Tuple ``(true_values, ensemble_predictions, current_prices)``.
            ``current_prices`` is None when the batches carry no prices.
        """
        for model in self.models:
            model.eval()
            model.to(device)

        per_model_chunks = [[] for _ in self.models]
        y_chunks = []
        price_chunks = []

        with torch.no_grad():
            for batch in data_loader:
                batch_x = batch[0].to(device)
                for chunks, model in zip(per_model_chunks, self.models):
                    chunks.append(model(batch_x).cpu().numpy().reshape(-1))
                y_chunks.append(batch[1].cpu().numpy())
                if len(batch) == 3:
                    price_chunks.append(batch[2].cpu().numpy())

        def _concat(chunks):
            # Empty loader: keep the previous np.array([]) result.
            return np.concatenate(chunks) if chunks else np.array([])

        all_predictions = [_concat(chunks) for chunks in per_model_chunks]
        true_values = _concat(y_chunks)
        current_prices = _concat(price_chunks) if price_chunks else None

        ensemble_pred = np.zeros_like(all_predictions[0])
        for weight, preds in zip(self.weights, all_predictions):
            ensemble_pred += preds * weight

        return true_values, ensemble_pred, current_prices

    def evaluate(self, data_loader: DataLoader, device: torch.device,
                 evaluator: Optional["ModelEvaluator"] = None,
                 ticker: str = None
                 ) -> Tuple[Dict[str, float], np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """
        Evaluate the ensemble, optionally bias-correcting and clipping predictions.

        Args:
            data_loader: Loader with evaluation data.
            device: Device used for inference.
            evaluator: Metrics helper. A new ``ModelEvaluator`` is created if None.
            ticker: Optional ticker. When given, the evaluator's adaptive
                correction for it is subtracted from the predictions.

        Returns:
            Tuple ``(metrics, y_true, y_pred, current_prices)``. ``y_pred`` is
            clipped to within ±1.5% of the current price when prices exist.
        """
        y_true, y_pred, current_prices = self.predict(data_loader, device)

        if evaluator is None:
            evaluator = ModelEvaluator()

        if ticker is not None:
            correction = evaluator.get_adaptive_correction(ticker)
            if correction != 0:
                logger.info(f"Applying adaptive correction of {correction:.4f} for {ticker}")
                y_pred = y_pred - correction

        if current_prices is not None:
            # Limit predictions to a ±1.5% band around the current price.
            max_pct_change = 0.015
            min_bound = current_prices * (1 - max_pct_change)
            max_bound = current_prices * (1 + max_pct_change)
            y_pred = np.clip(y_pred, min_bound, max_bound)

        metrics = evaluator.calculate_metrics(
            y_true=y_true,
            y_pred=y_pred,
            current_prices=current_prices,
            ticker=ticker
        )

        return metrics, y_true, y_pred, current_prices

def train_ensemble_models(
    data_df: pd.DataFrame,
    model_configs: List[Dict],
    device: torch.device,
    sequence_length: int = 60,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    batch_size: int = 64,
    ticker: str = None,
    chronological_split: bool = True) -> EnsembleModel:
    """
    Train one LSTMModel per configuration and combine them into an ensemble.

    Each model is weighted by ``max(0.1, val_direction_accuracy - 49)``.

    Args:
        data_df: DataFrame with features and target.
        model_configs: One dict per model. ``hidden_size``, ``num_layers`` and
            ``dropout`` are required. Loss, optimizer, scheduler and
            early-stopping keys are optional.
        device: Computation device (CPU/GPU).
        sequence_length: LSTM sequence length.
        train_ratio: Fraction of windows used for training.
        val_ratio: Fraction of windows used for validation.
        batch_size: Batch size for all loaders.
        ticker: Optional stock ticker symbol.
        chronological_split: When True (default), train/val/test are
            consecutive, non-shuffled blocks of windows, so validation and test
            data come after the training data. Windows adjacent to a split
            boundary still share some input rows. When False, the previous
            ``random_split`` behaviour is used; it mixes future windows into
            training.

    Returns:
        The trained EnsembleModel.

    Raises:
        ValueError: If the ratios leave the train or validation split empty,
            or they sum to more than 1.
    """
    dataset = HFTDataset(
        data=data_df,
        sequence_length=sequence_length,
        include_current_price=True,
        ticker=ticker
    )

    train_size = int(train_ratio * len(dataset))
    val_size = int(val_ratio * len(dataset))
    test_size = len(dataset) - train_size - val_size
    if train_size == 0 or val_size == 0 or test_size < 0:
        raise ValueError(
            f"Cannot split {len(dataset)} sequences into "
            f"train={train_size}, val={val_size}, test={test_size}"
        )

    if chronological_split:
        train_dataset = torch.utils.data.Subset(dataset, range(0, train_size))
        val_dataset = torch.utils.data.Subset(dataset, range(train_size, train_size + val_size))
        test_dataset = torch.utils.data.Subset(dataset, range(train_size + val_size, len(dataset)))
    else:
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size, test_size]
        )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    trained_models = []
    validation_scores = []

    for i, config in enumerate(model_configs):
        logger.info(f"Training model {i+1}/{len(model_configs)}")
        logger.info(f"Configuration: {config}")

        input_size = len(dataset.feature_names)
        model = LSTMModel(
            input_size=input_size,
            hidden_size=config['hidden_size'],
            num_layers=config['num_layers'],
            dropout=config['dropout']
        )

        # Forward only the loss settings the config specifies, so that unset
        # keys fall back to DirectionalPredictionLoss's own defaults.
        loss_keys = ('direction_weight', 'magnitude_weight',
                     'short_penalty_multiplier', 'bias_correction_weight')
        criterion = DirectionalPredictionLoss(
            **{key: config[key] for key in loss_keys if key in config}
        )

        if ticker:
            criterion.set_ticker(ticker)

        optimizer = optim.Adam(
            model.parameters(),
            lr=config.get('learning_rate', 0.001),
            weight_decay=config.get('weight_decay', 1e-2)
        )

        # `verbose` is deprecated for PyTorch LR schedulers, so it is omitted.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=config.get('lr_factor', 0.3),
            patience=config.get('lr_patience', 2),
            min_lr=config.get('min_lr', 1e-6)
        )

        training_history = TrainingHistory()
        model = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=config.get('num_epochs', 100),
            device=device,
            training_history=training_history,
            early_stopping_patience=config.get('early_stopping_patience', 15),
            scheduler=scheduler,
            ticker=ticker
        )

        evaluator = ModelEvaluator()
        y_true, y_pred, current_prices = evaluate_model(
            model,
            val_loader,
            device,
            return_predictions=True,
            evaluator=evaluator,
            ticker=ticker
        )

        metrics = evaluator.calculate_metrics(
            y_true=y_true,
            y_pred=y_pred,
            current_prices=current_prices,
            ticker=ticker
        )

        trained_models.append(model)
        validation_scores.append(metrics.get('direction_accuracy', 0))

        logger.info(f"Model {i+1} validation direction accuracy: {validation_scores[-1]:.2f}%")

    # Models barely above chance (~50%) get little weight; the floor is 0.1.
    weights = [max(0.1, score - 49) for score in validation_scores]

    ensemble = EnsembleModel(trained_models, weights)

    if test_size > 0:
        ensemble_metrics, _, _, _ = ensemble.evaluate(test_loader, device, ticker=ticker)
        logger.info(f"Ensemble direction accuracy: {ensemble_metrics.get('direction_accuracy', 0):.2f}%")
    else:
        logger.warning("Test split is empty; skipping ensemble evaluation")

    return ensemble

class MultiTickerMonitor:
    """Track per-ticker state and forward predictions to a signal generator."""

    def __init__(self, signal_generator: "AdaptiveSignalGenerator"):
        """Store the generator and the per-ticker direction confidences."""
        self.tracked_tickers = {}
        self.signal_generator = signal_generator

        # Per-ticker confidence for up / down predictions.
        self.ticker_confidence = {
            'MSFT': {'up': 0.75, 'down': 0.75},
            'GOOGL': {'up': 0.75, 'down': 0.75},
            'NVDA': {'up': 0.75, 'down': 0.75},
            'TSLA': {'up': 0.75, 'down': 0.75},
            'TQQQ': {'up': 0.70, 'down': 0.80},   # slightly favour down predictions
            'QQQ': {'up': 0.75, 'down': 0.75},
        }
        # Template for tickers not listed above. It is copied, never mutated.
        self.default_confidence = {'up': 0.75, 'down': 0.75}

    def add_ticker(self, ticker: str, initial_price: float):
        """Start tracking ``ticker`` with an empty signal list."""
        self.tracked_tickers[ticker] = {
            'current_price': initial_price,
            'signals': [],
            'last_update': datetime.now()
        }

    def update_ticker(self, ticker: str, current_price: float,
                      predicted_price: float, timestamp: pd.Timestamp,
                      market_data: pd.DataFrame,
                      order_book_data: Dict = None,
                      prediction_metrics: Dict = None) -> Optional[Dict]:
        """Record the latest price and ask the generator for a signal.

        If ``prediction_metrics`` contains ``up_direction_accuracy`` /
        ``down_direction_accuracy`` (percentages), the ticker's confidence for
        that direction becomes ``0.5 + accuracy / 200``. The update is stored
        for this ticker only.

        Returns:
            The generated signal if it has a truthy ``'action'``, else None.
            Returned signals are also appended to the ticker's history.
        """
        if ticker not in self.tracked_tickers:
            self.add_ticker(ticker, current_price)

        state = self.tracked_tickers[ticker]
        state['current_price'] = current_price
        state['last_update'] = timestamp

        # Work on a copy so unlisted tickers never write into the shared
        # default. That was the bug: one ticker's accuracy leaked to all others.
        confidence = dict(self.ticker_confidence.get(ticker, self.default_confidence))

        if prediction_metrics:
            if 'up_direction_accuracy' in prediction_metrics:
                confidence['up'] = 0.5 + (prediction_metrics['up_direction_accuracy'] / 200)
            if 'down_direction_accuracy' in prediction_metrics:
                confidence['down'] = 0.5 + (prediction_metrics['down_direction_accuracy'] / 200)
            # Persist the update for this ticker only.
            self.ticker_confidence[ticker] = confidence

        signal = self.signal_generator.generate_signals(
            ticker=ticker,
            current_price=current_price,
            predicted_price=predicted_price,
            timestamp=timestamp,
            market_data=market_data,
            order_book_data=order_book_data,
            prediction_confidence=confidence
        )

        if signal and signal.get('action'):
            state['signals'].append(signal)
            return signal

        return None

    def get_signals(self, ticker: str) -> List[Dict]:
        """Return the recorded signals for ``ticker`` (empty if untracked)."""
        return self.tracked_tickers.get(ticker, {}).get('signals', [])

def apply_dynamic_scaling(model: LSTMModel, train_loader: DataLoader, device: torch.device):
    """Rescale the model's last output layer according to the target range.

    Scans every batch of ``train_loader`` (the target is the second element
    of each batch) to find the global min/max target. Then it multiplies the
    last layer of ``model.output_transform`` so that its mean absolute weight
    equals 10% of that range.

    Scaling is skipped, with a warning, when the loader yields no targets or
    the model has no ``output_transform``. Previously the first case raised
    on ``min([])`` and the second raised NameError on ``scale_factor``.

    Args:
        model: Model whose ``output_transform[-1]`` weights are modified in place.
        train_loader: Yields ``(x, y)`` or ``(x, y, current_price)`` batches.
        device: Unused; kept for interface compatibility.
    """
    global_min = float('inf')
    global_max = float('-inf')
    with torch.no_grad():
        for batch in train_loader:
            batch_y = batch[1]
            if batch_y.numel() == 0:
                continue
            global_min = min(global_min, batch_y.min().item())
            global_max = max(global_max, batch_y.max().item())

    if global_min > global_max:  # no targets were seen
        logger.warning("apply_dynamic_scaling: training loader yielded no targets; skipping scaling")
        return

    price_range = global_max - global_min

    if not hasattr(model, 'output_transform'):
        logger.warning("apply_dynamic_scaling: model has no output_transform; skipping scaling")
        return

    last_layer = model.output_transform[-1]
    with torch.no_grad():
        current_range = last_layer.weight.abs().mean().item()
        target_range = price_range / 10  # 10% of the observed target range
        scale_factor = target_range / (current_range + 1e-8)
        last_layer.weight.mul_(scale_factor)

    logger.info(f"Applied dynamic scaling with price range: {price_range:.2f}, scale factor: {scale_factor:.4f}")

def _train_unpack_batch(batch, has_current_price: bool, device: torch.device):
    """Move one ``(x, y[, current_price])`` batch to ``device``; y and prices get a trailing unit dim."""
    if has_current_price:
        batch_x, batch_y, current_prices = batch
        current_prices = current_prices.to(device).unsqueeze(1)
    else:
        batch_x, batch_y = batch
        current_prices = None
    return batch_x.to(device), batch_y.to(device).unsqueeze(1), current_prices


def _train_batch_loss(criterion: nn.Module, outputs, batch_y, current_prices):
    """Evaluate ``criterion``; current prices are passed only to a ``DirectionalPredictionLoss``."""
    if current_prices is not None and isinstance(criterion, DirectionalPredictionLoss):
        return criterion(outputs, batch_y, current_prices)
    return criterion(outputs, batch_y)


def _train_direction_accuracy(outputs, batch_y, current_prices):
    """Return ``(overall, long, short)`` direction accuracy in percent for one batch.

    Direction is the sign of the move relative to the current price. ``long`` /
    ``short`` are ``None`` when the batch contains no true up / down moves.
    """
    true_direction = torch.sign(batch_y - current_prices)
    pred_direction = torch.sign(outputs - current_prices)
    direction_match = (true_direction == pred_direction).float()
    overall = direction_match.mean().item() * 100

    long_mask = (true_direction > 0).squeeze()
    long_acc = direction_match[long_mask].mean().item() * 100 if long_mask.any() else None

    short_mask = (true_direction < 0).squeeze()
    short_acc = direction_match[short_mask].mean().item() * 100 if short_mask.any() else None
    return overall, long_acc, short_acc


def _train_average_direction(stats):
    """Average per-batch ``(overall, long, short)`` tuples column-wise, skipping ``None``; ``None`` if no data."""
    def _avg(column: int):
        values = [s[column] for s in stats if s[column] is not None]
        return np.mean(values) if values else None
    return _avg(0), _avg(1), _avg(2)


def train_model(model: nn.Module, 
              train_loader: DataLoader, 
              val_loader: DataLoader, 
              criterion: nn.Module, 
              optimizer: torch.optim.Optimizer, 
              num_epochs: int, 
              device: torch.device,
              training_history: Optional[TrainingHistory] = None,
              early_stopping_patience: int = 10,
              scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
              ticker: str = None) -> nn.Module:
    """
    Train the LSTM model with enhanced monitoring and early stopping.

    Before training, the output layer is rescaled via ``apply_dynamic_scaling``.
    Each epoch runs a training pass (gradient norm clipped to 0.5) and a
    validation pass. When the loaders yield ``(x, y, current_price)``
    batches, direction accuracies (overall / long / short) are also tracked.
    The weights with the lowest validation loss are restored at the end.

    Args:
        model: The LSTM model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        criterion: Loss function
        optimizer: Optimization algorithm
        num_epochs: Number of training epochs
        device: Computation device (CPU/GPU)
        training_history: Optional history tracker; receives the *training*
            loss/accuracies and the validation loss each epoch
        early_stopping_patience: Number of epochs without validation
            improvement before stopping
        scheduler: Optional learning rate scheduler (``ReduceLROnPlateau`` is
            stepped with the validation loss)
        ticker: Optional ticker symbol for ticker-specific adaptations

    Returns:
        Trained model (best validation-loss weights when any epoch improved)

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.to(device)
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    # Apply dynamic scaling for output range
    apply_dynamic_scaling(model, train_loader, device)

    # Set ticker for directional loss if applicable
    if ticker and isinstance(criterion, DirectionalPredictionLoss):
        criterion.set_ticker(ticker)
        logger.info(f"Set ticker '{ticker}' for directional loss")

    # A 3-tuple batch means the dataset also provides the current price.
    sample_batch = next(iter(train_loader), None)
    if sample_batch is None:
        raise ValueError("train_model: train_loader yielded no batches")
    has_current_price = len(sample_batch) == 3

    for epoch in range(1, num_epochs + 1):
        # Training phase
        model.train()
        train_losses = []
        train_direction_stats = []

        for batch in train_loader:
            batch_x, batch_y, current_prices = _train_unpack_batch(batch, has_current_price, device)

            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = _train_batch_loss(criterion, outputs, batch_y, current_prices)
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

            optimizer.step()
            train_losses.append(loss.item())

            if has_current_price:
                with torch.no_grad():
                    train_direction_stats.append(
                        _train_direction_accuracy(outputs, batch_y, current_prices))

        avg_train_loss = np.mean(train_losses)
        avg_direction_accuracy, avg_long_accuracy, avg_short_accuracy = \
            _train_average_direction(train_direction_stats)

        # Validation phase
        model.eval()
        val_losses = []
        val_direction_stats = []

        with torch.no_grad():
            for batch in val_loader:
                batch_x, batch_y, current_prices = _train_unpack_batch(batch, has_current_price, device)
                outputs = model(batch_x)
                loss = _train_batch_loss(criterion, outputs, batch_y, current_prices)
                val_losses.append(loss.item())

                if has_current_price:
                    val_direction_stats.append(
                        _train_direction_accuracy(outputs, batch_y, current_prices))

        avg_val_loss = np.mean(val_losses)
        avg_val_direction_accuracy, avg_val_long_accuracy, avg_val_short_accuracy = \
            _train_average_direction(val_direction_stats)

        # Update training history
        if training_history is not None:
            training_history.update(
                epoch, 
                avg_train_loss, 
                avg_val_loss,
                avg_direction_accuracy,
                avg_long_accuracy,
                avg_short_accuracy
            )

        # Learning rate scheduler step
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        # Log before the early-stopping check so the final epoch is reported too.
        log_msg = f"Epoch [{epoch}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}"
        if avg_direction_accuracy is not None:
            log_msg += f", Direction Acc: {avg_direction_accuracy:.2f}%"
        if avg_long_accuracy is not None and avg_short_accuracy is not None:
            log_msg += f", Long/Short Acc: {avg_long_accuracy:.2f}%/{avg_short_accuracy:.2f}%"
        if avg_val_direction_accuracy is not None:
            log_msg += f", Val Direction Acc: {avg_val_direction_accuracy:.2f}%"
        if avg_val_long_accuracy is not None and avg_val_short_accuracy is not None:
            log_msg += f", Val Long/Short Acc: {avg_val_long_accuracy:.2f}%/{avg_val_short_accuracy:.2f}%"
        logger.info(log_msg)

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Snapshot on CPU so the best weights survive later updates.
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                logger.info(f"Early stopping triggered at epoch {epoch}")
                break

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        model = model.to(device)

    return model

def evaluate_model(model: nn.Module, 
                  test_loader: DataLoader, 
                  device: torch.device,
                  return_predictions: bool = False,
                  evaluator: Optional[ModelEvaluator] = None,
                  ticker: str = None,
                  max_pct_change: float = 0.015) -> Tuple:
    """
    Evaluate the trained model on test data with enhanced metrics and range clipping.

    Runs the model over ``test_loader`` without gradients. If both ``evaluator``
    and ``ticker`` are given, it subtracts the evaluator's adaptive correction.
    When the loader also yields current prices, every prediction is clipped to
    ``current_price * (1 +/- max_pct_change)``, and range, clipping and
    direction-bias statistics are logged.

    Args:
        model: Trained model; switched to eval mode.
        test_loader: Yields ``(x, y)`` or ``(x, y, current_price)`` batches.
        device: Device the inputs are moved to before the forward pass.
        return_predictions: Has no effect; the same tuple is always returned.
            Kept for backward compatibility.
        evaluator: Optional object providing ``get_adaptive_correction(ticker)``.
        ticker: Ticker symbol used for the adaptive correction.
        max_pct_change: Maximum relative move allowed from the current price
            (default 0.015, i.e. +/-1.5%).

    Returns:
        ``(y_true, y_pred, current_prices)`` as 1-D NumPy arrays;
        ``current_prices`` is ``None`` when the loader yields 2-tuples.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()

    # A 3-tuple batch means the dataset also provides the current price.
    sample_batch = next(iter(test_loader), None)
    if sample_batch is None:
        raise ValueError("evaluate_model: test_loader yielded no batches")
    has_current_price = len(sample_batch) == 3

    true_chunks, pred_chunks, current_chunks = [], [], []
    with torch.no_grad():
        for batch in test_loader:
            if has_current_price:
                batch_x, batch_y, batch_current = batch
                current_chunks.append(batch_current.numpy().reshape(-1))
            else:
                batch_x, batch_y = batch
            outputs = model(batch_x.to(device)).cpu().numpy()
            # Flatten everything so targets, predictions and prices align element-wise.
            true_chunks.append(batch_y.numpy().reshape(-1))
            pred_chunks.append(outputs.reshape(-1))

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)

    # Apply adaptive correction if evaluator is provided
    if evaluator is not None and ticker is not None:
        correction = evaluator.get_adaptive_correction(ticker)
        if correction != 0:
            logger.info(f"Applying adaptive correction of {correction:.4f} for {ticker}")
            y_pred = y_pred - correction

    current_prices = None
    if has_current_price:
        current_prices = np.concatenate(current_chunks)

        # Calculate prediction range statistics before clipping
        pred_range_before = np.max(y_pred) - np.min(y_pred)
        true_range = np.max(y_true) - np.min(y_true)
        range_ratio_before = pred_range_before / true_range if true_range > 0 else 0

        logger.info(f"Prediction range before clipping - True range: {true_range:.2f}, " 
                    f"Predicted range: {pred_range_before:.2f}, Ratio: {range_ratio_before:.2f}")

        # Clip each prediction to a band of +/- max_pct_change around its current price.
        min_bound = current_prices * (1 - max_pct_change)
        max_bound = current_prices * (1 + max_pct_change)

        y_pred_before = y_pred.copy()  # Save original predictions
        y_pred = np.clip(y_pred, min_bound, max_bound)

        # Calculate how many predictions were clipped, and by how much
        clipped_mask = y_pred != y_pred_before
        num_clipped = np.sum(clipped_mask)
        pct_clipped = (num_clipped / len(y_pred)) * 100

        if num_clipped > 0:
            avg_clip_amount = np.mean(np.abs(y_pred - y_pred_before)[clipped_mask])
            avg_clip_pct = avg_clip_amount / np.mean(current_prices) * 100
        else:
            avg_clip_amount = 0
            avg_clip_pct = 0

        # Calculate prediction range statistics after clipping
        pred_range_after = np.max(y_pred) - np.min(y_pred)
        range_ratio_after = pred_range_after / true_range if true_range > 0 else 0

        logger.info(f"Prediction range after clipping - Predicted range: {pred_range_after:.2f}, Ratio: {range_ratio_after:.2f}")
        logger.info(f"Clipped {num_clipped} predictions ({pct_clipped:.2f}%), average clip amount: {avg_clip_amount:.4f} ({avg_clip_pct:.4f}%)")

        # Calculate prediction bias metrics before and after clipping
        pred_direction_before = np.sign(y_pred_before - current_prices)
        up_pred_pct_before = np.mean(pred_direction_before > 0) * 100
        down_pred_pct_before = np.mean(pred_direction_before < 0) * 100

        pred_direction_after = np.sign(y_pred - current_prices)
        up_pred_pct_after = np.mean(pred_direction_after > 0) * 100
        down_pred_pct_after = np.mean(pred_direction_after < 0) * 100

        logger.info(f"Direction bias before clipping - Up: {up_pred_pct_before:.2f}%, Down: {down_pred_pct_before:.2f}%")
        logger.info(f"Direction bias after clipping - Up: {up_pred_pct_after:.2f}%, Down: {down_pred_pct_after:.2f}%")

    return y_true, y_pred, current_prices

def update_ticker(self, ticker: str, current_price: float,
                 predicted_price: float, timestamp: pd.Timestamp,
                 market_data: pd.DataFrame,
                 order_book_data: Dict = None,
                 prediction_metrics: Dict = None,
                 max_pct_change: float = 0.015) -> Optional[Dict]:
    """Update ticker information and generate signals with range clipping.

    Records the latest price/timestamp for ``ticker`` (registering it via
    ``self.add_ticker`` if needed). It then resolves the per-direction
    confidence and clips ``predicted_price`` to
    ``current_price * (1 +/- max_pct_change)``. Finally it asks
    ``self.signal_generator`` for a signal.

    Args:
        ticker: Ticker symbol.
        current_price: Latest observed price.
        predicted_price: Model prediction for the next step.
        timestamp: Time of this update.
        market_data: Passed through to the signal generator.
        order_book_data: Passed through to the signal generator.
        prediction_metrics: Optional ``up_direction_accuracy`` /
            ``down_direction_accuracy`` (percent). When given, they update this
            ticker's stored confidence as ``0.5 + accuracy / 200``.
        max_pct_change: Maximum relative move allowed from ``current_price``
            (default 0.015, i.e. +/-1.5%).

    Returns:
        The generated signal if it has a truthy ``'action'`` (it is also
        appended to the ticker's signal list), otherwise ``None``.
    """
    if ticker not in self.tracked_tickers:
        self.add_ticker(ticker, current_price)

    ticker_state = self.tracked_tickers[ticker]
    ticker_state['current_price'] = current_price
    ticker_state['last_update'] = timestamp

    # Work on a copy: updating the dict in place would silently overwrite the
    # shared default_confidence for every ticker without its own entry.
    confidence = dict(self.ticker_confidence.get(ticker, self.default_confidence))

    # Update confidence if metrics provided, and persist it for this ticker only
    if prediction_metrics:
        if 'up_direction_accuracy' in prediction_metrics:
            confidence['up'] = 0.5 + (prediction_metrics['up_direction_accuracy'] / 200)
        if 'down_direction_accuracy' in prediction_metrics:
            confidence['down'] = 0.5 + (prediction_metrics['down_direction_accuracy'] / 200)
        self.ticker_confidence[ticker] = confidence

    # Apply range clipping to predicted price
    min_bound = current_price * (1 - max_pct_change)
    max_bound = current_price * (1 + max_pct_change)

    original_prediction = predicted_price
    clipped_prediction = np.clip(predicted_price, min_bound, max_bound)

    # Log if clipping was applied
    if clipped_prediction != original_prediction:
        clipping_amount = abs(clipped_prediction - original_prediction)
        clipping_pct = (clipping_amount / current_price) * 100
        logger.info(f"Prediction clipped for {ticker}: {original_prediction:.4f} -> {clipped_prediction:.4f} ({clipping_pct:.4f}%)")

    # Use clipped prediction for signal generation
    signal = self.signal_generator.generate_signals(
        ticker=ticker,
        current_price=current_price,
        predicted_price=clipped_prediction,
        timestamp=timestamp,
        market_data=market_data,
        order_book_data=order_book_data,
        prediction_confidence=confidence
    )

    if signal and signal['action']:
        ticker_state['signals'].append(signal)
        return signal

    return None

def _backtest_apply_slippage(action: str, price: float, slippage_factor: float) -> Tuple[float, float]:
    """Return ``(execution_price, per_unit_slippage)`` for a fill triggered by ``action``.

    Buy-side actions (``enter_long*``, ``exit_short*``) fill higher; sell-side
    actions (``enter_short*``, ``exit_long*``) fill lower; anything else fills
    at ``price`` with zero slippage.
    """
    if action.startswith(('enter_long', 'exit_short')):
        slippage_amount = price * slippage_factor
        return price + slippage_amount, slippage_amount
    if action.startswith(('enter_short', 'exit_long')):
        slippage_amount = price * slippage_factor
        return price - slippage_amount, slippage_amount
    return price, 0


def _backtest_close_trade(entry: Dict, exit_time: pd.Timestamp, exit_price: float,
                          exit_commission: float, exit_slippage: float,
                          exit_reason: str) -> Dict:
    """Close the open position ``entry`` at ``exit_price`` and return the trade record."""
    total_trade_commission = entry['commission'] + exit_commission
    total_trade_slippage = entry['slippage'] + exit_slippage
    position_size = entry.get('position_size', 1.0)

    # Same prefix rule as entry detection, so e.g. 'enter_long_strong' is a long.
    if entry['action'].startswith('enter_long'):
        profit = (exit_price - entry['price'] - total_trade_commission) * position_size
    else:  # Short trade
        profit = (entry['price'] - exit_price - total_trade_commission) * position_size

    return {
        'entry_time': entry['timestamp'],
        'entry_price': entry['price'],
        'exit_time': exit_time,
        'exit_price': exit_price,
        'position': entry['action'],
        'position_size': position_size,
        'profit': profit,
        'transaction_costs': total_trade_commission,
        'slippage_costs': total_trade_slippage,
        'duration': (exit_time - entry['timestamp']).total_seconds() / 60,
        'exit_reason': exit_reason,
    }


def backtest_trades_with_costs(signals: List[Dict], df: pd.DataFrame, 
                             commission_rate: float = 0.001, 
                             slippage_factor: float = 0.0002) -> Tuple[List[Dict], float, float]:
    """
    Execute backtesting with realistic transaction costs and slippage.

    At most one position is open at a time. An ``enter_*`` signal opens a
    position only when flat, and an ``exit_*`` signal closes it only when a
    position is open. All other signals are ignored and incur no costs. Any
    position still open at the end is closed at the last ``close`` in ``df``.

    Args:
        signals: List of trading signals (``timestamp``, ``price``, ``action``,
            optional ``position_size`` / ``exit_reason``)
        df: DataFrame with market data; only its last index value and last
            ``close`` are read, and it is not modified
        commission_rate: Trading commission as a fraction of the fill price
        slippage_factor: Price slippage as a fraction of the signal price

    Returns:
        Tuple of trades list, total commission, and total slippage (both
        per-unit sums over executed fills)
    """
    trades = []
    entry_signal = None
    total_commission = 0
    total_slippage = 0

    for signal in signals:
        action = signal.get('action')
        if not action:
            continue

        is_entry = action.startswith('enter_')
        is_exit = action.startswith('exit_')
        executes = (is_entry and entry_signal is None) or (is_exit and entry_signal is not None)
        if not executes:
            # Redundant entries, exits while flat, and unknown actions are not filled.
            continue

        signal_time = pd.to_datetime(signal['timestamp'])
        price_with_slippage, slippage_amount = _backtest_apply_slippage(
            action, signal['price'], slippage_factor)
        commission = price_with_slippage * commission_rate
        total_slippage += slippage_amount
        total_commission += commission

        if is_entry:
            entry_signal = {
                'timestamp': signal_time,
                'price': price_with_slippage,
                'action': action,
                'commission': commission,
                'slippage': slippage_amount,
                # Use position_size from signal if available
                'position_size': signal.get('position_size', 1.0),
            }
        else:
            trades.append(_backtest_close_trade(
                entry_signal, signal_time, price_with_slippage, commission,
                slippage_amount, signal.get('exit_reason', 'unspecified')))
            entry_signal = None

    # Close any open position at the end of the dataset
    if entry_signal is not None:
        last_time = pd.to_datetime(df.index[-1])
        last_price = df['close'].iloc[-1]

        # Longs sell to close, shorts buy to close.
        closing_action = 'exit_long' if entry_signal['action'].startswith('enter_long') else 'exit_short'
        price_with_slippage, slippage_amount = _backtest_apply_slippage(
            closing_action, last_price, slippage_factor)
        commission = price_with_slippage * commission_rate
        total_slippage += slippage_amount
        total_commission += commission

        trades.append(_backtest_close_trade(
            entry_signal, last_time, price_with_slippage, commission,
            slippage_amount, 'end_of_data'))

    return trades, total_commission, total_slippage

def plot_candlestick_analysis(df: pd.DataFrame, signals: Optional[List[Dict]] = None, 
                            trades: Optional[List[Dict]] = None,
                            ticker: str = '') -> None:
    """
    Plot candlestick chart with trade signals and performance metrics.

    Candles are drawn by mplfinance with ``show_nontrading=False``, which places
    bar ``i`` at x-coordinate ``i`` rather than at a date. Volume, the indicator
    panel, signal markers and trade lines are therefore drawn at integer bar
    positions, so they line up with the candles on the shared x-axis.
    Timestamps are mapped to the nearest bar; this assumes ``df`` has a sorted
    ``DatetimeIndex`` (mplfinance requires a DatetimeIndex).

    Errors are logged and printed rather than raised.

    Args:
        df: DataFrame with OHLCV data
        signals: Optional list of trading signals
        trades: Optional list of executed trades
        ticker: Stock ticker symbol
    """
    fig = None
    try:
        df_plot = df[['open', 'high', 'low', 'close', 'volume']].copy()
        df_plot.index.name = 'Date'
        positions = np.arange(len(df_plot))

        def bar_position(timestamp) -> int:
            """Return the integer x-position of the bar nearest to ``timestamp``."""
            return int(df_plot.index.get_indexer([pd.to_datetime(timestamp)], method='nearest')[0])

        # Create subplots
        fig = plt.figure(figsize=(20, 12))
        gs = gridspec.GridSpec(3, 1, height_ratios=[3, 1, 1], hspace=0.05)
        ax1 = plt.subplot(gs[0])
        ax2 = plt.subplot(gs[1], sharex=ax1)
        ax3 = plt.subplot(gs[2], sharex=ax1)

        # Plot candlesticks
        mc = mpf.make_marketcolors(up='g', down='r', inherit=True)
        s = mpf.make_mpf_style(marketcolors=mc)
        mpf.plot(df_plot, type='candle', style=s, ax=ax1, volume=False, show_nontrading=False)

        # Plot Volume
        ax2.plot(positions, df_plot['volume'].to_numpy(), color='blue', alpha=0.5)
        ax2.set_ylabel('Volume')
        ax2.grid(True)

        # Plot additional indicator (e.g., RSI)
        if 'RSI' in df.columns:
            ax3.plot(positions, df['RSI'].to_numpy(), color='purple')
            ax3.set_ylabel('RSI')
            ax3.grid(True)
            ax3.axhline(y=70, color='r', linestyle='--', alpha=0.5)
            ax3.axhline(y=30, color='g', linestyle='--', alpha=0.5)
        else:
            # Plot close price with SMA
            ax3.plot(positions, df['close'].to_numpy(), color='blue')
            if 'SMA_20' in df.columns:
                ax3.plot(positions, df['SMA_20'].to_numpy(), color='orange')
            ax3.set_ylabel('Price')
            ax3.grid(True)

        # Add signals if provided
        if signals:
            for signal in signals:
                action = signal.get('action')
                if action is None:
                    continue

                x = bar_position(signal['timestamp'])
                price = signal['price']

                if 'enter_long' in action:
                    ax1.scatter(x, price, marker='^', color='g', s=100, label='Buy Signal')
                elif 'enter_short' in action:
                    ax1.scatter(x, price, marker='v', color='r', s=100, label='Sell Signal')
                elif 'exit_long' in action or 'exit_short' in action:
                    ax1.scatter(x, price, marker='o', color='k', s=100, label='Exit Signal')

        # Plot trades if available
        if trades:
            for trade in trades:
                entry_x = bar_position(trade['entry_time'])
                exit_x = bar_position(trade['exit_time'])
                entry_price = trade['entry_price']
                exit_price = trade['exit_price']
                profit = trade.get('profit', 0)
                position_size = trade.get('position_size', 1.0)

                # Choose color based on profit
                color = 'g' if profit > 0 else 'r'

                # Draw line connecting entry and exit; width scales with position size
                ax1.plot([entry_x, exit_x], [entry_price, exit_price], 
                         color=color, linestyle='-', linewidth=1.5 * position_size, alpha=0.7)

                # Add profit annotation midway between entry and exit
                if 'profit' in trade:
                    ax1.annotate(f"${profit:.2f}", 
                                 xy=((entry_x + exit_x) / 2, max(entry_price, exit_price)), 
                                 xytext=(0, 5), textcoords='offset points',
                                 fontsize=8, color=color)

        # Add legend (deduplicated by label); skip it when nothing is labelled
        handles, labels = ax1.get_legend_handles_labels()
        if handles:
            by_label = dict(zip(labels, handles))
            ax1.legend(by_label.values(), by_label.keys())

        # Title the candle panel explicitly; plt.title() targets whichever axes is current.
        ax1.set_title(f"Candlestick Chart with Trade Signals for {ticker}")
        plt.tight_layout()
        plt.show()

    except Exception as e:
        logger.error(f"Error in plot_candlestick_analysis: {e}")
        traceback.print_exc()
        if fig is not None:
            plt.close(fig)

def plot_trading_metrics(metrics: Dict[str, float], ticker: str) -> None:
    """
    Draw a 2x2 dashboard of bar charts summarising trading metrics.

    Metrics missing from ``metrics`` are drawn as 0. The matplotlib style is
    reset to ``'default'`` before drawing, and the figure is shown.

    Args:
        metrics: Mapping from metric name to value
        ticker: Stock ticker symbol, used in the figure title
    """
    # Each panel: (grid cell, metric names, bar colour(s), title, y-label or None)
    panels = [
        ((0, 0), ['win_rate', 'loss_rate', 'long_win_rate', 'short_win_rate'],
         ['green', 'red', 'lightgreen', 'lightcoral'], 'Trading Rates Comparison', 'Percentage (%)'),
        ((0, 1), ['total_profit', 'average_profit_per_trade'],
         'blue', 'Profit Metrics', 'Amount ($)'),
        ((1, 0), ['maximum_drawdown', 'sharpe_ratio', 'profit_factor'],
         'purple', 'Risk Metrics', None),
        ((1, 1), ['total_trades', 'long_trades', 'short_trades', 'average_trade_duration'],
         'orange', 'Trade Analysis', None),
    ]

    plt.style.use('default')
    figure = plt.figure(figsize=(20, 12))
    grid = gridspec.GridSpec(2, 2)

    for (row, col), names, bar_color, panel_title, y_label in panels:
        axis = figure.add_subplot(grid[row, col])
        heights = [metrics.get(name, 0) for name in names]
        axis.bar(names, heights, color=bar_color)
        axis.set_title(panel_title)
        if y_label is not None:
            axis.set_ylabel(y_label)
        plt.setp(axis.xaxis.get_majorticklabels(), rotation=45)

    plt.suptitle(f'Trading Performance Metrics for {ticker}', fontsize=16)
    plt.tight_layout()
    plt.show()

def plot_learning_curves(training_history: TrainingHistory, ticker: str) -> None:
    """
    Plot learning curves showing training metrics over epochs.

    Each series is plotted against its own epoch numbers (1..len(series)), so
    history lists of different lengths do not cause a shape mismatch. ``None``
    entries are drawn as gaps (NaN) instead of raising.

    Args:
        training_history: Object containing training and validation metrics
        ticker: Stock ticker symbol for plot title
    """
    def _series(values):
        """Return ``(epochs, values_as_float_array)`` for one history list; ``None`` becomes NaN."""
        data = np.asarray(values, dtype=float)
        return np.arange(1, len(data) + 1), data

    plt.figure(figsize=(18, 12))

    # Create subplot grid
    gs = gridspec.GridSpec(2, 2)

    # Plot training and validation loss
    ax1 = plt.subplot(gs[0, 0])
    train_epochs, train_loss = _series(training_history.loss_history)
    val_epochs, val_loss = _series(training_history.validation_loss_history)
    ax1.plot(train_epochs, train_loss, 'b-', label='Training Loss')
    ax1.plot(val_epochs, val_loss, 'r-', label='Validation Loss')
    ax1.set_title('Loss Curves')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    ax1.set_yscale('log')  # Use log scale for better visualization

    # Plot direction accuracy if available
    if hasattr(training_history, 'direction_accuracy_history') and training_history.direction_accuracy_history:
        ax2 = plt.subplot(gs[0, 1])
        dir_epochs, dir_acc = _series(training_history.direction_accuracy_history)
        ax2.plot(dir_epochs, dir_acc, 'g-', label='Direction Accuracy')
        ax2.set_title('Direction Prediction Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.grid(True)
        ax2.legend()

    # Plot long/short accuracy if available
    if (hasattr(training_history, 'long_accuracy_history') and training_history.long_accuracy_history and
        hasattr(training_history, 'short_accuracy_history') and training_history.short_accuracy_history):
        ax3 = plt.subplot(gs[1, 0])
        long_epochs, long_acc = _series(training_history.long_accuracy_history)
        short_epochs, short_acc = _series(training_history.short_accuracy_history)
        ax3.plot(long_epochs, long_acc, 'g-', label='Long Accuracy')
        ax3.plot(short_epochs, short_acc, 'r-', label='Short Accuracy')
        ax3.set_title('Long vs. Short Prediction Accuracy')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Accuracy (%)')
        ax3.grid(True)
        ax3.legend()

        # Plot accuracy ratio (long/short); defined as 1.0 where short accuracy is not positive
        if len(long_acc) == len(short_acc):
            ax4 = plt.subplot(gs[1, 1])
            ratio = np.ones_like(long_acc)
            np.divide(long_acc, short_acc, out=ratio, where=short_acc > 0)
            ax4.plot(long_epochs, ratio, 'b-', label='Long/Short Accuracy Ratio')
            ax4.axhline(y=1.0, color='k', linestyle='--', alpha=0.5)
            ax4.set_title('Direction Prediction Balance')
            ax4.set_xlabel('Epoch')
            ax4.set_ylabel('Ratio')
            ax4.grid(True)
            ax4.legend()

    plt.suptitle(f'Learning Curves for {ticker}', fontsize=16)
    plt.tight_layout()
    plt.show()

def plot_prediction_analysis(y_true: np.ndarray, y_pred: np.ndarray, 
                          current_prices: Optional[np.ndarray] = None,
                          ticker: str = '') -> None:
    """
    Plot analysis of prediction performance with focus on directional accuracy.

    The top row (true-vs-predicted scatter and error histogram) is always drawn.
    The four lower panels (direction accuracy bars, true movement pie, true vs
    predicted movement counts, error box plot by true direction) are drawn only
    when ``current_prices`` is given, because a "direction" here is the sign of
    (price - current price).

    All inputs are flattened to 1-D first, so column vectors of shape (n, 1)
    cannot silently broadcast against (n,) arrays into an (n, n) matrix.

    Args:
        y_true: True values
        y_pred: Predicted values (same number of elements as ``y_true``)
        current_prices: Current prices (for direction calculation); must have
            the same number of elements as ``y_true`` when given
        ticker: Stock ticker symbol (used in the figure title)

    Raises:
        ValueError: If the flattened inputs have different lengths.
    """
    # Normalize to flat float arrays before any arithmetic (see docstring).
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements "
            f"({y_true.size} != {y_pred.size})"
        )
    if y_true.size == 0:
        # np.min/np.max below raise on empty arrays; bail out before creating a figure.
        logger.warning(f"No predictions to analyze for {ticker}; skipping prediction analysis plot.")
        return
    if current_prices is not None:
        current_prices = np.asarray(current_prices, dtype=float).ravel()
        if current_prices.size != y_true.size:
            raise ValueError(
                f"current_prices must have the same number of elements as y_true "
                f"({current_prices.size} != {y_true.size})"
            )

    plt.figure(figsize=(20, 16))
    
    # Create subplot grid
    gs = gridspec.GridSpec(3, 2)
    
    # Plot 1: True vs Predicted prices, with the y = x reference line
    ax1 = plt.subplot(gs[0, 0])
    ax1.scatter(y_true, y_pred, alpha=0.3)
    min_val = min(np.min(y_true), np.min(y_pred))
    max_val = max(np.max(y_true), np.max(y_pred))
    ax1.plot([min_val, max_val], [min_val, max_val], 'r--')
    ax1.set_xlabel('True Price')
    ax1.set_ylabel('Predicted Price')
    ax1.set_title('True vs Predicted Prices')
    ax1.grid(True)
    
    # Plot 2: Prediction Error Distribution (error = predicted - true)
    ax2 = plt.subplot(gs[0, 1])
    errors = y_pred - y_true
    mean_error = np.mean(errors)
    std_error = np.std(errors)
    ax2.hist(errors, bins=50, alpha=0.7, color='blue')
    ax2.axvline(x=0, color='r', linestyle='--')
    ax2.set_xlabel('Prediction Error')
    ax2.set_ylabel('Frequency')
    ax2.set_title(f'Error Distribution (Mean: {mean_error:.4f}, Std: {std_error:.4f})')
    ax2.grid(True)
    
    # Direction panels need the reference (current) price
    if current_prices is not None:
        # Plot 3: Direction Accuracy
        ax3 = plt.subplot(gs[1, 0])
        
        # Direction is +1 / -1 / 0 relative to the current price
        true_directions = np.sign(y_true - current_prices)
        pred_directions = np.sign(y_pred - current_prices)
        matches = (true_directions == pred_directions)
        
        # Separate into up and down movements
        up_indices = np.where(true_directions > 0)[0]
        down_indices = np.where(true_directions < 0)[0]
        
        # Per-class accuracies fall back to 0 when a class has no samples
        overall_accuracy = np.mean(matches) * 100
        up_accuracy = np.mean(matches[up_indices]) * 100 if len(up_indices) > 0 else 0
        down_accuracy = np.mean(matches[down_indices]) * 100 if len(down_indices) > 0 else 0
        
        accuracies = [overall_accuracy, up_accuracy, down_accuracy]
        ax3.bar(['Overall', 'Up Movement', 'Down Movement'], accuracies,
                color=['blue', 'green', 'red'])
        ax3.set_ylabel('Direction Accuracy (%)')
        ax3.set_title('Direction Prediction Accuracy')
        ax3.grid(True)
        
        # Value labels just above each bar
        for i, acc in enumerate(accuracies):
            ax3.text(i, acc + 1, f'{acc:.1f}%', ha='center')
            
        # Plot 4: Movement Distribution
        ax4 = plt.subplot(gs[1, 1])
        
        up_count = len(up_indices)
        down_count = len(down_indices)
        no_move_count = len(true_directions) - up_count - down_count
        
        movement_labels = ['Up', 'Down', 'No Change']
        movement_colors = ['green', 'red', 'gray']
        true_counts = [up_count, down_count, no_move_count]
        
        # Drop zero-count slices so they don't render as overlapping "0.0%" labels.
        # At least one slice is non-empty because y_true is non-empty.
        pie_slices = [
            (label, count, color)
            for label, count, color in zip(movement_labels, true_counts, movement_colors)
            if count > 0
        ]
        pie_labels, pie_sizes, pie_colors = zip(*pie_slices)
        ax4.pie(pie_sizes, labels=pie_labels, colors=pie_colors, autopct='%1.1f%%', startangle=90)
        ax4.axis('equal')
        ax4.set_title('True Price Movement Distribution')
        
        # Plot 5: Prediction Bias (true vs predicted movement counts)
        ax5 = plt.subplot(gs[2, 0])
        
        pred_up = np.sum(pred_directions > 0)
        pred_down = np.sum(pred_directions < 0)
        pred_no_move = len(pred_directions) - pred_up - pred_down
        pred_counts = [pred_up, pred_down, pred_no_move]
        
        x = np.arange(len(movement_labels))
        width = 0.35
        
        ax5.bar(x - width/2, true_counts, width, label='True')
        ax5.bar(x + width/2, pred_counts, width, label='Predicted')
        
        ax5.set_xlabel('Movement Direction')
        ax5.set_ylabel('Count')
        ax5.set_title('True vs Predicted Movement Distribution')
        ax5.set_xticks(x)
        ax5.set_xticklabels(movement_labels)
        ax5.legend()
        ax5.grid(True)
        
        # Plot 6: Error by True Direction
        ax6 = plt.subplot(gs[2, 1])
        
        up_errors = errors[up_indices]
        down_errors = errors[down_indices]
        
        # NOTE: `tick_labels` is the matplotlib >= 3.9 name (formerly `labels`)
        ax6.boxplot([up_errors, down_errors], tick_labels=['Up Movement', 'Down Movement'])
        ax6.axhline(y=0, color='r', linestyle='--')
        ax6.set_ylabel('Prediction Error')
        ax6.set_title('Error Distribution by True Direction')
        ax6.grid(True)
    
    plt.suptitle(f'Prediction Analysis for {ticker}', fontsize=16)
    plt.tight_layout()
    plt.show()

def validate_model_quality(evaluation_metrics: Dict[str, float], ticker: str,
                           min_direction_accuracy: float = 50.0,
                           max_direction_bias: float = 15.0,
                           max_range_ratio: float = 2.0,
                           max_accuracy_ratio: float = 1.5) -> bool:
    """
    Validate model quality to ensure it meets minimum standards before generating signals.

    Checks, in order (the first failure is logged as a warning and returns False):
      1. ``direction_accuracy`` (default 0) >= ``min_direction_accuracy``
      2. ``abs(direction_bias)`` (default 0) <= ``max_direction_bias``
      3. ``range_ratio`` (default 1.0) <= ``max_range_ratio``
      4. If both ``up_direction_accuracy`` and ``down_direction_accuracy`` are
         positive finite numbers, neither may exceed the other by more than a
         factor of ``max_accuracy_ratio`` (symmetric check).
    Metrics 1-3 that are NaN, infinite or non-numeric fail validation. Without
    this, NaN would pass, because every comparison with NaN is False.

    Args:
        evaluation_metrics: Dictionary containing model evaluation metrics
        ticker: Stock ticker symbol (used only in log messages)
        min_direction_accuracy: Minimum acceptable direction accuracy in percent
        max_direction_bias: Maximum acceptable absolute direction bias in percent
        max_range_ratio: Maximum acceptable prediction range ratio
        max_accuracy_ratio: Maximum allowed ratio between the larger and the
            smaller of the up/down direction accuracies

    Returns:
        Boolean indicating whether model meets quality standards
    """
    def _is_finite_number(value) -> bool:
        """Return True if value converts to a finite float."""
        try:
            return bool(np.isfinite(float(value)))
        except (TypeError, ValueError):
            return False

    # Check direction accuracy
    direction_accuracy = evaluation_metrics.get('direction_accuracy', 0)
    if not _is_finite_number(direction_accuracy):
        logger.warning(f"Model for {ticker} failed validation: non-finite direction accuracy {direction_accuracy!r}")
        return False
    if direction_accuracy < min_direction_accuracy:
        logger.warning(f"Model for {ticker} failed validation: direction accuracy {direction_accuracy:.2f}% < {min_direction_accuracy}%")
        return False
    
    # Check for excessive directional bias (sign of the bias is irrelevant)
    raw_direction_bias = evaluation_metrics.get('direction_bias', 0)
    if not _is_finite_number(raw_direction_bias):
        logger.warning(f"Model for {ticker} failed validation: non-finite direction bias {raw_direction_bias!r}")
        return False
    direction_bias = abs(raw_direction_bias)
    if direction_bias > max_direction_bias:
        logger.warning(f"Model for {ticker} failed validation: excessive direction bias {direction_bias:.2f}% > {max_direction_bias}%")
        return False
    
    # Check for unreasonable prediction range
    range_ratio = evaluation_metrics.get('range_ratio', 1.0)
    if not _is_finite_number(range_ratio):
        logger.warning(f"Model for {ticker} failed validation: non-finite prediction range ratio {range_ratio!r}")
        return False
    if range_ratio > max_range_ratio:
        logger.warning(f"Model for {ticker} failed validation: unreasonable prediction range ratio {range_ratio:.2f} > {max_range_ratio}")
        return False
    
    # Check balance between up/down direction accuracies. A value of 0 (the
    # default for a missing key) or a non-finite value means "not available",
    # so the check is skipped in that case.
    up_accuracy = evaluation_metrics.get('up_direction_accuracy', 0)
    down_accuracy = evaluation_metrics.get('down_direction_accuracy', 0)
    
    if (_is_finite_number(up_accuracy) and _is_finite_number(down_accuracy)
            and up_accuracy > 0 and down_accuracy > 0):
        accuracy_ratio = up_accuracy / down_accuracy
        # Symmetric: up 1.5x down fails exactly like down 1.5x up
        if max(accuracy_ratio, 1.0 / accuracy_ratio) > max_accuracy_ratio:
            logger.warning(f"Model for {ticker} failed validation: imbalanced direction accuracies (up: {up_accuracy:.2f}%, down: {down_accuracy:.2f}%)")
            return False
    
    logger.info(f"Model for {ticker} passed validation criteria")
    return True

def get_optimizer_and_scheduler(model: nn.Module, params: Dict = None) -> Tuple[
    torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """
    Create an Adam optimizer and a ReduceLROnPlateau learning-rate scheduler.
    
    Args:
        model: PyTorch model whose parameters are optimized
        params: Optional overrides. Recognized keys, with defaults:
            ``learning_rate`` (0.001), ``weight_decay`` (1e-2),
            ``scheduler_factor`` (0.3), ``scheduler_patience`` (2),
            ``min_lr`` (1e-6), ``scheduler_threshold`` (1e-4),
            ``scheduler_threshold_mode`` ('rel').
    
    Returns:
        Tuple of (optimizer, scheduler). The scheduler is in ``mode='min'``, so
        it must be stepped with a metric to minimize (e.g. validation loss).
    """
    if params is None:
        params = {}
    
    # Extract parameters with defaults
    lr = params.get('learning_rate', 0.001)
    weight_decay = params.get('weight_decay', 1e-2)
    scheduler_factor = params.get('scheduler_factor', 0.3)      # LR is multiplied by this on a plateau
    scheduler_patience = params.get('scheduler_patience', 2)    # epochs without improvement before reducing
    min_lr = params.get('min_lr', 1e-6)
    threshold = params.get('scheduler_threshold', 0.0001)       # minimum change that counts as improvement
    threshold_mode = params.get('scheduler_threshold_mode', 'rel')
    
    # Adam's `weight_decay` is a coupled L2 penalty added to the gradients (not
    # AdamW-style decoupled decay). No gradient clipping is configured here.
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    scheduler_kwargs = dict(
        mode='min',
        factor=scheduler_factor,
        patience=scheduler_patience,
        min_lr=min_lr,
        threshold=threshold,
        threshold_mode=threshold_mode,
    )
    try:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, **scheduler_kwargs)
    except TypeError:
        # `verbose` is deprecated since PyTorch 2.2 and may be rejected by newer
        # releases; fall back to the same scheduler without console messages.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)
    
    return optimizer, scheduler

def perform_time_series_cross_validation(
    data_df: pd.DataFrame, 
    model_class: nn.Module, 
    seq_length: int, 
    n_splits: int = 3, 
    test_size: int = 1000,
    device: torch.device = torch.device('cpu'),
    model_params: Dict = None) -> Tuple[List[float], List[Dict], nn.Module]:
    """
    Perform forward-chaining time-series cross-validation for model evaluation.
    
    Each fold trains on a prefix of ``data_df`` and evaluates on (up to)
    ``test_size`` rows immediately after it. The last fold uses the final
    ``test_size`` rows as its test window. A test window is clamped to the end
    of the data instead of running past it.

    NOTE(review): the test loader is also passed to ``train_model`` as
    ``val_loader`` (presumably for early stopping), so per-fold scores are
    likely optimistic.
    
    Args:
        data_df: DataFrame with features and targets
        model_class: PyTorch model class (not an instance), instantiated per fold
        seq_length: Sequence length for LSTM
        n_splits: Number of splits for cross-validation (>= 1)
        test_size: Size of each test set (>= 1)
        device: Computation device (CPU/GPU)
        model_params: Dictionary with model parameters. Missing keys fall back
            to hidden_size=256, num_layers=2, dropout=0.5.
        
    Returns:
        Tuple of (per-fold direction accuracies, per-fold metric dicts, model).
        The model is the last fold's instance, with the weights of the fold
        that had the highest direction accuracy loaded into it.

    Raises:
        ValueError: If ``n_splits < 1``, ``test_size < 1``, or ``data_df`` has
            no rows beyond ``test_size``.
    """
    default_model_params = {
        'hidden_size': 256,
        'num_layers': 2,
        'dropout': 0.5
    }
    # Merge so a partial dict (e.g. only {'dropout': 0.3}) does not raise KeyError
    model_params = {**default_model_params, **(model_params or {})}
    
    total_samples = len(data_df)
    if n_splits < 1:
        raise ValueError(f"n_splits must be >= 1, got {n_splits}")
    if test_size < 1 or total_samples <= test_size:
        raise ValueError(
            f"Cross-validation needs more than test_size={test_size} rows, got {total_samples}"
        )
    
    # Define validation metrics storage
    val_accuracies = []
    fold_metrics = []
    best_state = None
    best_accuracy = float('-inf')
    
    # Forward chaining: every fold's training window starts at row 0
    fold_size = (total_samples - test_size) // n_splits
    
    for fold in range(n_splits):
        logger.info(f"Processing fold {fold+1}/{n_splits}")
        
        # Define train/test boundaries for this fold
        if fold < n_splits - 1:
            train_end = test_size + fold_size * (fold + 1)
            # Clamp: when test_size > fold_size the window would overrun the data
            test_end = min(train_end + test_size, total_samples)
        else:
            # Last fold tests on the final test_size rows and trains on everything before
            train_end = total_samples - test_size
            test_end = total_samples
        
        train_data = data_df.iloc[:train_end].copy()
        test_data = data_df.iloc[train_end:test_end].copy()
        
        train_dataset = HFTDataset(
            data=train_data,
            sequence_length=seq_length,
            include_current_price=True
        )
        
        test_dataset = HFTDataset(
            data=test_data,
            sequence_length=seq_length,
            scaler=train_dataset.scaler,  # Reuse the training scaler to avoid test leakage
            include_current_price=True
        )
        
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=64)
        
        # Initialize model
        input_size = len(train_dataset.feature_names)
        model = model_class(
            input_size=input_size,
            hidden_size=model_params['hidden_size'],
            num_layers=model_params['num_layers'],
            dropout=model_params['dropout']
        )
        
        criterion = DirectionalPredictionLoss(
            direction_weight=2.5,
            magnitude_weight=0.8,
            short_penalty_multiplier=3.5,
            bias_correction_weight=1.5
        )
        
        # Shared helper; its defaults (Adam lr=1e-3, wd=1e-2; plateau factor=0.3,
        # patience=2, min_lr=1e-6) are the settings this fold loop always used.
        optimizer, scheduler = get_optimizer_and_scheduler(model)
        
        training_history = TrainingHistory()
        model = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=50,  # Reduced for cross-validation
            device=device,
            training_history=training_history,
            early_stopping_patience=5,  # Reduced for cross-validation
            scheduler=scheduler
        )
        
        # Evaluate model on this fold's test window
        evaluator = ModelEvaluator()
        y_true, y_pred, current_prices = evaluate_model(
            model, 
            test_loader, 
            device,
            return_predictions=True,
            evaluator=evaluator
        )
        
        metrics = evaluator.calculate_metrics(
            y_true=y_true,
            y_pred=y_pred,
            current_prices=current_prices
        )
        
        fold_metrics.append(metrics)
        direction_accuracy = metrics.get('direction_accuracy', 0)
        val_accuracies.append(direction_accuracy)
        
        if direction_accuracy > best_accuracy:
            best_accuracy = direction_accuracy
            # state_dict() returns references to the live tensors; detach+clone
            # on the CPU so this is a real snapshot and does not pin GPU memory.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            
        logger.info(f"Fold {fold+1} direction accuracy: {direction_accuracy:.2f}%")
        
    # Log average metrics across folds
    avg_accuracy = np.mean(val_accuracies)
    logger.info(f"Average direction accuracy across {n_splits} folds: {avg_accuracy:.2f}%")
    
    # Load the best fold's weights into the last fold's model instance
    if best_state is not None:
        model.load_state_dict(best_state)
        
    return val_accuracies, fold_metrics, model

def optimize_trading_parameters(ticker: str, stock_df: pd.DataFrame, market_index_df: pd.DataFrame,
                              device: torch.device, evaluator: ModelEvaluator, n_trials: int = 30):
    """
    Use time-series cross-validation and Optuna to optimize trading parameters.
    
    Each trial samples model, loss, signal-generator and bias-correction
    parameters. It trains and evaluates one model per forward-chaining fold,
    generates signals, backtests them with transaction costs, and scores the
    trial on direction accuracy, direction balance, trade count, win rate and
    long/short balance (higher is better). A trial that raises in any fold
    scores -1000.
    
    Args:
        ticker: Stock ticker symbol
        stock_df: DataFrame with stock price data
        market_index_df: DataFrame with market index data
        device: PyTorch device (CPU/GPU)
        evaluator: ModelEvaluator instance (shared by all trials and folds)
        n_trials: Number of optimization trials
        
    Returns:
        Dictionary of the best trial's sampled parameters plus fixed
        architecture/training parameters for the final training run.
    """
    import optuna
    from functools import partial
    
    # Prepare data with feature engineering
    data_df = enhance_features(stock_df, market_index_df)
    
    # Forward-chaining time-series split: 3 folds, each test window is 12.5% of the data
    n_splits = 3
    test_size = len(data_df) // 8
    
    def time_series_split(df, n_splits, test_size):
        """Create forward-chaining (train_indices, test_indices) range pairs."""
        splits = []
        total_size = len(df)
        fold_size = (total_size - test_size) // n_splits
        
        for fold in range(n_splits):
            if fold < n_splits - 1:
                train_end = test_size + fold_size * (fold + 1)
                train_indices = range(0, train_end)
                test_indices = range(train_end, train_end + test_size)
            else:
                # Last fold tests on the final test_size rows
                train_indices = range(0, total_size - test_size)
                test_indices = range(total_size - test_size, total_size)
            
            splits.append((train_indices, test_indices))
        
        return splits
    
    ts_splits = time_series_split(data_df, n_splits, test_size)
    
    def objective(trial, data_df, ts_splits, ticker, device, evaluator):
        """Optuna objective: cross-validated trading score for one parameter sample."""
        # Parameters to optimize
        params = {
            # Model parameters
            'dropout': trial.suggest_float('dropout', 0.3, 0.7),
            'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True),
            'weight_decay': trial.suggest_float('weight_decay', 1e-4, 1e-2, log=True),
            
            # Loss function parameters
            'direction_weight': trial.suggest_float('direction_weight', 1.5, 3.5),
            'magnitude_weight': trial.suggest_float('magnitude_weight', 0.5, 1.2),
            'short_penalty_multiplier': trial.suggest_float('short_penalty_multiplier', 1.0, 2.0),
            'bias_correction_weight': trial.suggest_float('bias_correction_weight', 0.5, 1.2),
            
            # Signal generator parameters
            'base_entry_threshold': trial.suggest_float('base_entry_threshold', 0.0005, 0.0025),
            'short_entry_threshold_factor': trial.suggest_float('short_entry_threshold_factor', 0.7, 1.0),
            'base_exit_threshold': trial.suggest_float('base_exit_threshold', 0.0004, 0.0020),
            'base_stop_loss': trial.suggest_float('base_stop_loss', 0.0008, 0.0040),
            'atr_multiplier': trial.suggest_float('atr_multiplier', 0.8, 1.8),
            
            # Signal balance parameters
            'downward_bias_correction_high': trial.suggest_float('downward_bias_correction_high', 0.001, 0.004),
            'downward_bias_correction_low': trial.suggest_float('downward_bias_correction_low', 0.0005, 0.002),
            'adaptive_factor_cap': trial.suggest_float('adaptive_factor_cap', 0.001, 0.003),
        }
        
        # Fixed parameters
        params.update({
            'sequence_length': 60,
            'hidden_size': 256,
            'num_layers': 2,
            'batch_size': 64,
            'num_epochs': 50,  # Reduced for optimization runs
            'early_stopping_patience': 10,
        })
        
        # Cross-validation scores
        cv_direction_accuracies = []
        cv_direction_balances = []
        cv_trade_metrics = []
        
        for fold, (train_idx, test_idx) in enumerate(ts_splits):
            try:
                train_data = data_df.iloc[train_idx].copy()
                test_data = data_df.iloc[test_idx].copy()
                
                train_dataset = HFTDataset(
                    data=train_data,
                    sequence_length=params['sequence_length'],
                    include_current_price=True,
                    ticker=ticker
                )
                
                test_dataset = HFTDataset(
                    data=test_data,
                    sequence_length=params['sequence_length'],
                    scaler=train_dataset.scaler,  # Reuse the training scaler
                    include_current_price=True,
                    ticker=ticker
                )
                
                train_loader = DataLoader(
                    train_dataset, 
                    batch_size=params['batch_size'],
                    shuffle=True
                )
                
                test_loader = DataLoader(
                    test_dataset,
                    batch_size=params['batch_size']
                )
                
                input_size = len(train_dataset.feature_names)
                model = LSTMModel(
                    input_size=input_size,
                    hidden_size=params['hidden_size'],
                    num_layers=params['num_layers'],
                    dropout=params['dropout']
                )
                
                criterion = DirectionalPredictionLoss(
                    direction_weight=params['direction_weight'],
                    magnitude_weight=params['magnitude_weight'],
                    short_penalty_multiplier=params['short_penalty_multiplier'],
                    bias_correction_weight=params['bias_correction_weight']
                )
                criterion.set_ticker(ticker)
                
                optimizer = optim.Adam(
                    model.parameters(),
                    lr=params['learning_rate'],
                    weight_decay=params['weight_decay']
                )
                
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer,
                    mode='min',
                    factor=0.3,
                    patience=3,
                    verbose=False
                )
                
                model = train_model(
                    model=model,
                    train_loader=train_loader,
                    val_loader=test_loader,
                    criterion=criterion,
                    optimizer=optimizer,
                    num_epochs=params['num_epochs'],
                    device=device,
                    training_history=None,
                    early_stopping_patience=params['early_stopping_patience'],
                    scheduler=scheduler,
                    ticker=ticker
                )
                
                y_true, y_pred, current_prices = evaluate_model(
                    model=model,
                    test_loader=test_loader,
                    device=device,
                    return_predictions=True,
                    evaluator=evaluator,
                    ticker=ticker
                )
                
                # Clip predictions to +/-1.5% of the current price
                max_pct_change = 0.015
                min_bound = current_prices * (1 - max_pct_change)
                max_bound = current_prices * (1 + max_pct_change)
                y_pred = np.clip(y_pred, min_bound, max_bound)
                
                metrics = evaluator.calculate_metrics(
                    y_true=y_true,
                    y_pred=y_pred,
                    current_prices=current_prices,
                    ticker=ticker
                )
                
                # Signal generator with this trial's parameters
                signal_generator = AdaptiveSignalGenerator(ticker)
                signal_generator.params.update({
                    'base_entry_threshold': params['base_entry_threshold'],
                    'short_entry_threshold_factor': params['short_entry_threshold_factor'],
                    'base_exit_threshold': params['base_exit_threshold'],
                    'base_stop_loss': params['base_stop_loss'],
                    'atr_multiplier': params['atr_multiplier'],
                })
                
                monitor = MultiTickerMonitor(signal_generator)
                
                # NOTE(review): sample idx is mapped to row idx of test_data. If
                # HFTDataset emits its first sample after `sequence_length` rows,
                # timestamps and market windows lag the predictions. Confirm
                # against HFTDataset.
                all_signals = []
                for idx in range(len(y_true)):
                    timestamp = test_data.index[idx]
                    predicted = y_pred[idx]
                    current_price = current_prices[idx]
                    
                    # Up to 21 rows of market data ending at this sample's row
                    window_start_idx = max(0, test_idx[0] + idx - 20)
                    window_end_idx = test_idx[0] + idx + 1
                    market_data_window = data_df.iloc[window_start_idx:window_end_idx].copy()
                    
                    market_regime = signal_generator.detect_market_regime(market_data_window)
                    price_changes = market_data_window['close'].pct_change().dropna()
                    mean_price_change = price_changes.iloc[-20:].mean() if len(price_changes) >= 20 else 0
                    
                    # Subtract a regime-dependent fraction of the current price from the prediction
                    if market_regime in ['trending_up', 'high_volatility']:
                        downward_bias_correction = current_price * params['downward_bias_correction_high']
                    else:
                        downward_bias_correction = current_price * params['downward_bias_correction_low']
                    
                    # In rising markets, add a further correction capped at adaptive_factor_cap
                    if mean_price_change > 0:
                        adaptive_factor = min(params['adaptive_factor_cap'], mean_price_change)
                        downward_bias_correction += current_price * adaptive_factor
                    
                    corrected_prediction = predicted - downward_bias_correction
                    
                    # Give the signal generator the price known at decision time.
                    # y_true[idx] is the future target; passing it as the current
                    # price would leak the outcome into the signal (lookahead bias).
                    signal = monitor.update_ticker(
                        ticker=ticker,
                        current_price=current_price,
                        predicted_price=corrected_prediction,
                        timestamp=timestamp,
                        market_data=market_data_window,
                        prediction_metrics=metrics
                    )
                    
                    if signal and signal['action']:
                        all_signals.append(signal)
                
                trades, _, _ = backtest_trades_with_costs(
                    all_signals,
                    test_data,
                    commission_rate=0.001,
                    slippage_factor=0.0002
                )
                
                if trades:
                    trade_metrics = evaluator.calculate_trade_performance_metrics(
                        trades=trades,
                        initial_capital=10000.0,
                        include_transaction_costs=True
                    )
                else:
                    trade_metrics = {
                        'total_trades': 0,
                        'win_rate': 0,
                        'long_trades': 0,
                        'short_trades': 0,
                        'total_profit': 0
                    }
                
                cv_direction_accuracies.append(metrics.get('direction_accuracy', 0))
                cv_direction_balances.append(abs(metrics.get('up_direction_accuracy', 50) - metrics.get('down_direction_accuracy', 50)))
                cv_trade_metrics.append(trade_metrics)
                
            except Exception as e:
                logger.error(f"Error in optimization fold {fold}: {e}", exc_info=True)
                # Return poor score for failed trials
                return -1000
        
        # Average metrics across folds
        avg_direction_accuracy = np.mean(cv_direction_accuracies)
        avg_direction_balance = np.mean(cv_direction_balances)
        
        # .get() keeps a missing key in the evaluator's output from raising here,
        # outside the per-fold try, which would abort the whole study.
        total_trades = sum(m.get('total_trades', 0) for m in cv_trade_metrics) / len(cv_trade_metrics)
        avg_win_rate = np.mean(
            [m.get('win_rate', 0) for m in cv_trade_metrics if m.get('total_trades', 0) > 0] or [0]
        )
        
        # Long/short balance over all folds combined
        long_trades = sum(m.get('long_trades', 0) for m in cv_trade_metrics)
        short_trades = sum(m.get('short_trades', 0) for m in cv_trade_metrics)
        total_trade_count = long_trades + short_trades
        
        if total_trade_count > 0:
            long_pct = (long_trades / total_trade_count) * 100
            trade_balance_score = abs(50 - long_pct)  # 0 means perfect balance
        else:
            trade_balance_score = 100
        
        # We want: high direction accuracy, low direction imbalance, more trades
        # (capped at 100), higher win rate, balanced trade types
        score = (
            avg_direction_accuracy * 10 
            - avg_direction_balance * 5 
            + min(100, total_trades) * 0.5
            + avg_win_rate * 2
            - trade_balance_score * 5
        )
        
        logger.info(f"Trial {trial.number} - Score: {score:.2f}, Direction Accuracy: {avg_direction_accuracy:.2f}%, "
                   f"Trades: {total_trades:.1f}, Win Rate: {avg_win_rate:.2f}%, "
                   f"Long/Short: {long_trades}/{short_trades}")
        
        return score
    
    logger.info(f"Starting parameter optimization for {ticker} with {n_trials} trials")
    study = optuna.create_study(direction='maximize')
    objective_func = partial(objective, data_df=data_df, ts_splits=ts_splits, 
                            ticker=ticker, device=device, evaluator=evaluator)
    
    study.optimize(objective_func, n_trials=n_trials)
    
    # study.best_params builds a fresh dict, so updating it below is safe
    best_params = study.best_params
    best_score = study.best_value
    
    logger.info(f"Optimization completed for {ticker}")
    logger.info(f"Best score: {best_score:.2f}")
    logger.info(f"Best parameters: {best_params}")
    
    # Add fixed parameters for the final training run
    best_params.update({
        'sequence_length': 60,
        'hidden_size': 256,
        'num_layers': 2,
        'batch_size': 64,
        'num_epochs': 100,  # Restore full epochs for final training
        'early_stopping_patience': 15,
    })
    
    return best_params


def main():
    """Main execution function for the high-frequency trading system with optimization and improved model performance."""
    
    # Configuration
    TICKERS = ['MSFT', 'GOOGL', 'TSLA', 'NVDA', 'TQQQ', 'SQQQ', 'QQQ', 'PSQ', 'QLD']
    MARKET_INDEX = 'SPY'
    END_DATE = '2025-03-14'  # Current date
    
    # Set the number of days for training, validation, and testing
    TRAINING_DAYS = 90
    
    # Set hyperparameters (these will be overridden by optimization if enabled)
    MODEL_PARAMS = {
        'sequence_length': 60,
        'hidden_size': 256,
        'num_layers': 2,
        'batch_size': 64,
        'num_epochs': 100,
        'learning_rate': 0.001,
        'dropout': 0.5,
        'early_stopping_patience': 15
    }
    
    # Transaction cost parameters for realistic backtesting
    TRANSACTION_COSTS = {
        'commission_rate': 0.001,  # 0.1% commission
        'slippage_factor': 0.0002  # 0.02% slippage
    }
    
    # Optimization control
    ENABLE_OPTIMIZATION = True  # Set to False to skip optimization
    OPTIMIZATION_TRIALS = 3    # Number of trials for optimization

    # Initialize components
    fetcher = PolygonDataFetcher(API_KEY)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    evaluator = ModelEvaluator()

    # Define date range
    end_date = datetime.strptime(END_DATE, '%Y-%m-%d')
    start_date = end_date - timedelta(days=TRAINING_DAYS)
    start_date_str = start_date.strftime('%Y-%m-%d')
    end_date_str = end_date.strftime('%Y-%m-%d')
    
    # Fetch market index data for context
    logger.info(f"Fetching market index data from {start_date_str} to {end_date_str}")
    market_index_df = fetcher.fetch_market_index_data(
        index_symbol=MARKET_INDEX,
        start_date=start_date_str,
        end_date=end_date_str
    )
    
    if market_index_df.empty:
        logger.warning(f"No market index data available. Proceeding without market context.")

    # Create output directories if they don't exist
    os.makedirs('balanced_lstm_models', exist_ok=True)
    os.makedirs('plots', exist_ok=True)
    
    # Process each ticker
    for ticker in TICKERS:
        try:
            logger.info(f"Processing {ticker}")
            print(f"\nProcessing {ticker}...")
            
            # Create ticker-specific output directory
            ticker_dir = f'balanced_lstm_models/{ticker}/'
            os.makedirs(ticker_dir, exist_ok=True)

            # Fetch and prepare market data
            logger.info(f"Fetching market data from {start_date_str} to {end_date_str}")
            stock_df = fetcher.fetch_stock_data(ticker, start_date_str, end_date_str)

            if stock_df.empty:
                logger.warning(f"No data fetched for {ticker}. Skipping.")
                continue

            # Apply enhanced feature engineering
            stock_df = enhance_features(stock_df, market_index_df)
            logger.info(f"Enhanced features created for {ticker}")
            
            # Create balanced dataset with stratified sampling
            balanced_df = create_stratified_dataset(
                stock_df, 
                min_regime_count=50,
                min_total_samples=1000,
                max_sampling_fraction=0.7,
                max_oversample_factor=5.0
            )
            logger.info(f"Created balanced dataset with {len(balanced_df)} samples")
            
            # Check if stratified dataset has sufficient data
            if len(balanced_df) < 200:
                logger.warning(f"WARNING: Stratified dataset for {ticker} has only {len(balanced_df)} samples")
                
                # Fallback to original dataset if samples are extremely low
                if len(balanced_df) < 100:
                    logger.warning(f"Using original dataset instead for {ticker} due to insufficient stratified samples")
                    balanced_df = stock_df
                    logger.info(f"Using original dataset with {len(balanced_df)} samples")
            
            # Initialize stock-specific components
            signal_generator = AdaptiveSignalGenerator(ticker)
            
            # Hyperparameter Optimization with Time-Series Cross-Validation
            if ENABLE_OPTIMIZATION:
                print(f"Optimizing trading parameters for {ticker}...")
                try:
                    optimized_params = optimize_trading_parameters(
                        ticker=ticker,
                        stock_df=stock_df,
                        market_index_df=market_index_df,
                        device=device,
                        evaluator=evaluator,
                        n_trials=OPTIMIZATION_TRIALS
                    )
                    
                    # Update model parameters with optimized values
                    MODEL_PARAMS.update({
                        'dropout': optimized_params['dropout'],
                        'learning_rate': optimized_params['learning_rate'],
                        'weight_decay': optimized_params['weight_decay']
                    })
                    
                    # Update loss function parameters (will be used when creating criterion)
                    loss_params = {
                        'direction_weight': optimized_params['direction_weight'],
                        'magnitude_weight': optimized_params['magnitude_weight'],
                        'short_penalty_multiplier': optimized_params['short_penalty_multiplier'],
                        'bias_correction_weight': optimized_params['bias_correction_weight']
                    }
                    
                    # Update signal generator parameters
                    signal_generator.params.update({
                        'base_entry_threshold': optimized_params['base_entry_threshold'],
                        'short_entry_threshold_factor': optimized_params['short_entry_threshold_factor'],
                        'base_exit_threshold': optimized_params['base_exit_threshold'],
                        'base_stop_loss': optimized_params['base_stop_loss'],
                        'atr_multiplier': optimized_params['atr_multiplier']
                    })
                    
                    # Set bias correction parameters for signal generation later
                    bias_correction_params = {
                        'downward_bias_correction_high': optimized_params['downward_bias_correction_high'],
                        'downward_bias_correction_low': optimized_params['downward_bias_correction_low'],
                        'adaptive_factor_cap': optimized_params['adaptive_factor_cap']
                    }
                    
                    # Save optimized parameters
                    params_filename = f'{ticker_dir}/optimized_params_{ticker}_{end_date_str}.pkl'
                    with open(params_filename, 'wb') as f:
                        pickle.dump(optimized_params, f)
                    
                    print(f"Optimized parameters saved to {params_filename}")
                    logger.info(f"Optimization completed successfully with score: {optimized_params.get('score', 'N/A')}")
                    
                except Exception as e:
                    logger.error(f"Error during parameter optimization: {e}")
                    print(f"Optimization failed with error: {str(e)}. Using default parameters.")
                    # Continue with default parameters if optimization fails
                    loss_params = {
                        'direction_weight': 2.5,
                        'magnitude_weight': 0.8, 
                        'short_penalty_multiplier': 1.5,
                        'bias_correction_weight': 0.8
                    }
                    bias_correction_params = {
                        'downward_bias_correction_high': 0.0025,
                        'downward_bias_correction_low': 0.001,
                        'adaptive_factor_cap': 0.0015
                    }
            else:
                # Use default parameters without optimization
                loss_params = {
                    'direction_weight': 2.5,
                    'magnitude_weight': 0.8, 
                    'short_penalty_multiplier': 1.5,
                    'bias_correction_weight': 0.8
                }
                bias_correction_params = {
                    'downward_bias_correction_high': 0.0025,
                    'downward_bias_correction_low': 0.001,
                    'adaptive_factor_cap': 0.0015
                }
            
            # Initialize monitor with configured signal generator
            monitor = MultiTickerMonitor(signal_generator)
            
            # Prepare datasets with current price for directional loss
            dataset = HFTDataset(
                data=balanced_df,
                sequence_length=MODEL_PARAMS['sequence_length'],
                scaler=StandardScaler(),
                include_current_price=True,
                relative_normalization=True,
                ticker=ticker
            )
            logger.info(f"Dataset created with {len(dataset)} sequences and {len(dataset.feature_names)} features")

            # Split data with proper sizes
            train_size = int(0.7 * len(dataset))
            val_size = int(0.15 * len(dataset))
            test_size = len(dataset) - train_size - val_size

            if train_size <= 0 or val_size <= 0 or test_size <= 0:
                logger.warning(f"Insufficient data for {ticker} after preprocessing. Skipping.")
                continue

            # Set random seed before splitting for reproducibility
            torch.manual_seed(42)
            train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size, test_size]
            )

            # Create data loaders
            train_loader = DataLoader(
                train_dataset,
                batch_size=MODEL_PARAMS['batch_size'],
                shuffle=True,
                drop_last=False,
                num_workers=4 if torch.cuda.is_available() else 0
            )
            val_loader = DataLoader(
                val_dataset,
                batch_size=MODEL_PARAMS['batch_size'],
                drop_last=False,
                num_workers=4 if torch.cuda.is_available() else 0
            )
            test_loader = DataLoader(
                test_dataset,
                batch_size=MODEL_PARAMS['batch_size'],
                drop_last=False,
                num_workers=4 if torch.cuda.is_available() else 0
            )

            # For problematic tickers, use the ensemble approach
            if ticker in ['MSFT', 'GOOGL', 'NVDA', 'TQQQ']:
                print(f"Using ensemble approach for {ticker}...")
                
                # Define model configurations for ensemble with improved parameters
                model_configs = [
                    # Base configuration with optimized parameters
                    {
                        'hidden_size': MODEL_PARAMS['hidden_size'],
                        'num_layers': MODEL_PARAMS['num_layers'],
                        'dropout': MODEL_PARAMS['dropout'],
                        'learning_rate': MODEL_PARAMS['learning_rate'],
                        'weight_decay': MODEL_PARAMS.get('weight_decay', 1e-2),
                        'direction_weight': loss_params['direction_weight'],
                        'magnitude_weight': loss_params['magnitude_weight'],
                        'short_penalty_multiplier': loss_params['short_penalty_multiplier'],
                        'bias_correction_weight': loss_params['bias_correction_weight'],
                        'lr_factor': 0.3,
                        'lr_patience': 2,
                        'num_epochs': MODEL_PARAMS['num_epochs'],
                        'early_stopping_patience': MODEL_PARAMS['early_stopping_patience']
                    },
                    # Variation with higher dropout
                    {
                        'hidden_size': MODEL_PARAMS['hidden_size'],
                        'num_layers': MODEL_PARAMS['num_layers'],
                        'dropout': min(0.7, MODEL_PARAMS['dropout'] * 1.2),  # Increase dropout but cap at 0.7
                        'learning_rate': MODEL_PARAMS['learning_rate'],
                        'weight_decay': MODEL_PARAMS.get('weight_decay', 1.5e-2),
                        'direction_weight': loss_params['direction_weight'] * 1.1,
                        'magnitude_weight': loss_params['magnitude_weight'] * 0.9,
                        'short_penalty_multiplier': loss_params['short_penalty_multiplier'],
                        'bias_correction_weight': loss_params['bias_correction_weight'],
                        'lr_factor': 0.3,
                        'lr_patience': 2,
                        'num_epochs': MODEL_PARAMS['num_epochs'],
                        'early_stopping_patience': MODEL_PARAMS['early_stopping_patience']
                    },
                    # Variation with different architecture
                    {
                        'hidden_size': MODEL_PARAMS['hidden_size'] * 2,
                        'num_layers': 1,
                        'dropout': MODEL_PARAMS['dropout'],
                        'learning_rate': MODEL_PARAMS['learning_rate'] * 0.5,
                        'weight_decay': MODEL_PARAMS.get('weight_decay', 1e-2),
                        'direction_weight': loss_params['direction_weight'],
                        'magnitude_weight': loss_params['magnitude_weight'],
                        'short_penalty_multiplier': loss_params['short_penalty_multiplier'],
                        'bias_correction_weight': loss_params['bias_correction_weight'],
                        'lr_factor': 0.3,
                        'lr_patience': 2,
                        'num_epochs': MODEL_PARAMS['num_epochs'],
                        'early_stopping_patience': MODEL_PARAMS['early_stopping_patience']
                    }
                ]
                
                # Train ensemble models
                ensemble_models = []
                ensemble_weights = []
                best_validation_score = 0
                best_model = None
                
                # Initialize training history for plotting
                training_history = TrainingHistory()
                
                for i, config in enumerate(model_configs):
                    logger.info(f"Training ensemble model {i+1}/{len(model_configs)} with config: {config}")
                    print(f"Training ensemble model {i+1}/{len(model_configs)}...")
                    
                    # Initialize model
                    input_size = len(dataset.feature_names)
                    model = LSTMModel(
                        input_size=input_size,
                        hidden_size=config['hidden_size'],
                        num_layers=config['num_layers'],
                        dropout=config['dropout']
                    )
                    
                    # Initialize training components with improved parameters
                    criterion = DirectionalPredictionLoss(
                        direction_weight=config['direction_weight'],
                        magnitude_weight=config['magnitude_weight'],
                        short_penalty_multiplier=config['short_penalty_multiplier'],
                        bias_correction_weight=config['bias_correction_weight']
                    )
                    
                    # Set ticker for loss function
                    criterion.set_ticker(ticker)
                    
                    # Create optimizer and scheduler
                    optimizer = optim.Adam(
                        model.parameters(), 
                        lr=config['learning_rate'],
                        weight_decay=config['weight_decay']
                    )
                    
                    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                        optimizer, 
                        mode='min', 
                        factor=config['lr_factor'],
                        patience=config['lr_patience'],
                        verbose=True,
                        min_lr=1e-6
                    )
                    
                    # Track history for this model
                    model_history = TrainingHistory()
                    
                    # Train model
                    trained_model = train_model(
                        model=model,
                        train_loader=train_loader,
                        val_loader=val_loader,
                        criterion=criterion,
                        optimizer=optimizer,
                        num_epochs=config['num_epochs'],
                        device=device,
                        training_history=model_history,
                        early_stopping_patience=config['early_stopping_patience'],
                        scheduler=scheduler,
                        ticker=ticker
                    )
                    
                    # Update main training history for plotting
                    if i == 0 or (model_history.validation_loss_history and 
                                 min(model_history.validation_loss_history) < 
                                 min(training_history.validation_loss_history or [float('inf')])):
                        training_history = model_history
                    
                    # Evaluate on validation set with range clipping
                    y_true_val, y_pred_val, current_prices_val = evaluate_model(
                        trained_model, 
                        val_loader, 
                        device,
                        return_predictions=True,
                        evaluator=evaluator,
                        ticker=ticker
                    )
                    
                    # Apply range clipping to predictions
                    max_pct_change = 0.015  # 1.5% maximum change per minute
                    min_bound = current_prices_val * (1 - max_pct_change)
                    max_bound = current_prices_val * (1 + max_pct_change)
                    y_pred_val = np.clip(y_pred_val, min_bound, max_bound)
                    
                    # Calculate validation metrics
                    val_metrics = evaluator.calculate_metrics(
                        y_true=y_true_val,
                        y_pred=y_pred_val,
                        current_prices=current_prices_val,
                        ticker=ticker
                    )
                    
                    # Get validation score for weighting
                    direction_accuracy = val_metrics.get('direction_accuracy', 0)
                    
                    # Save best individual model
                    if direction_accuracy > best_validation_score:
                        best_validation_score = direction_accuracy
                        best_model = trained_model
                    
                    # Add model to ensemble with more balanced weight
                    ensemble_models.append(trained_model)
                    
                    # Weight based on accuracy above random (50%)
                    model_weight = max(0.1, (direction_accuracy - 49) / 10)
                    ensemble_weights.append(model_weight)
                    
                    logger.info(f"Model {i+1} validation accuracy: {direction_accuracy:.2f}%, weight: {model_weight:.2f}")
                
                # Normalize weights
                total_weight = sum(ensemble_weights)
                ensemble_weights = [w / total_weight for w in ensemble_weights]
                
                # Create ensemble
                ensemble = EnsembleModel(ensemble_models, ensemble_weights)
                
                # Plot learning curves from best model's history
                logger.info("Plotting learning curves...")
                plot_learning_curves(training_history, ticker)
                plt.savefig(f"{ticker_dir}/learning_curves_{ticker}.png")
                
                # Evaluate ensemble on test set with range clipping
                logger.info("Evaluating ensemble model...")
                ensemble_metrics, y_true, y_pred, current_prices = ensemble.evaluate(
                    test_loader, 
                    device, 
                    evaluator=evaluator,
                    ticker=ticker
                )
                
                # Apply range clipping
                max_pct_change = 0.015  # 1.5% maximum change per minute
                min_bound = current_prices * (1 - max_pct_change)
                max_bound = current_prices * (1 + max_pct_change)
                y_pred = np.clip(y_pred, min_bound, max_bound)
                
                # Recalculate metrics after clipping
                evaluation_metrics = evaluator.calculate_metrics(
                    y_true=y_true,
                    y_pred=y_pred,
                    current_prices=current_prices,
                    ticker=ticker
                )
                
                # Plot prediction analysis
                plot_prediction_analysis(y_true, y_pred, current_prices, ticker)
                plt.savefig(f"{ticker_dir}/prediction_analysis_{ticker}.png")
                
                # Use best individual model for inference
                model = best_model
                
            else:
                # Standard single model approach for other tickers
                # Initialize model
                input_size = len(dataset.feature_names)
                model = LSTMModel(
                    input_size=input_size,
                    hidden_size=MODEL_PARAMS['hidden_size'],
                    num_layers=MODEL_PARAMS['num_layers'],
                    dropout=MODEL_PARAMS['dropout']
                )

                # Initialize training components with improved loss function
                criterion = DirectionalPredictionLoss(
                    direction_weight=loss_params['direction_weight'],
                    magnitude_weight=loss_params['magnitude_weight'],
                    short_penalty_multiplier=loss_params['short_penalty_multiplier'],
                    bias_correction_weight=loss_params['bias_correction_weight']
                )
                
                # Set ticker for loss function
                criterion.set_ticker(ticker)
                
                # Create optimizer with appropriate regularization
                optimizer = optim.Adam(
                    model.parameters(), 
                    lr=MODEL_PARAMS['learning_rate'],
                    weight_decay=MODEL_PARAMS.get('weight_decay', 1e-2)
                )
                
                # Create responsive scheduler
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, 
                    mode='min', 
                    factor=0.3,
                    patience=2,
                    verbose=True,
                    min_lr=1e-6
                )
                
                # Track training history
                training_history = TrainingHistory()
                
                # Train model
                logger.info("Starting model training...")
                print("Training model with enhanced direction-specific loss...")
                model = train_model(
                    model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    criterion=criterion,
                    optimizer=optimizer,
                    num_epochs=MODEL_PARAMS['num_epochs'],
                    device=device,
                    training_history=training_history,
                    early_stopping_patience=MODEL_PARAMS['early_stopping_patience'],
                    scheduler=scheduler,
                    ticker=ticker
                )

                # Plot learning curves
                logger.info("Plotting learning curves...")
                plot_learning_curves(training_history, ticker)
                plt.savefig(f"{ticker_dir}/learning_curves_{ticker}.png")

                # Evaluate model with range clipping
                logger.info("Evaluating model...")
                y_true, y_pred, current_prices = evaluate_model(
                    model, 
                    test_loader, 
                    device,
                    return_predictions=True,
                    evaluator=evaluator,
                    ticker=ticker
                )
                
                # Apply range clipping
                max_pct_change = 0.015  # 1.5% maximum change per minute
                min_bound = current_prices * (1 - max_pct_change)
                max_bound = current_prices * (1 + max_pct_change)
                y_pred = np.clip(y_pred, min_bound, max_bound)
                
                # Calculate metrics after clipping
                evaluation_metrics = evaluator.calculate_metrics(
                    y_true=y_true,
                    y_pred=y_pred,
                    current_prices=current_prices,
                    ticker=ticker
                )
                
                # Plot prediction analysis
                plot_prediction_analysis(y_true, y_pred, current_prices, ticker)
                plt.savefig(f"{ticker_dir}/prediction_analysis_{ticker}.png")
            
            # Log evaluation metrics
            logger.info(f"Model Evaluation Metrics for {ticker}:")
            for metric, value in evaluation_metrics.items():
                logger.info(f"  {metric}: {value}")
            
            # Validate model quality before generating signals
            model_valid = validate_model_quality(evaluation_metrics, ticker)
            if not model_valid:
                print(f"Model for {ticker} failed quality validation. Skipping signal generation.")
                continue
            
            # Update signal generator with direction-specific accuracies
            signal_generator.update_direction_confidence(evaluation_metrics)
            
            # Generate trading signals using stock-specific signal generator
            logger.info("Generating trading signals...")
            
            # Use original (non-balanced) data for signal generation and backtesting
            stock_df_orig = enhance_features(stock_df, market_index_df)
            
            # Prepare test period for signal generation
            test_period = stock_df_orig.iloc[-len(y_true):]
            test_timestamps = test_period.index
            
            print(f"Generating signals for {len(test_timestamps)} timestamps...")
            
            # Reset monitor for clean signal generation
            monitor = MultiTickerMonitor(signal_generator)
            all_signals = []
            
            # Process test data in batches for signal generation
            batch_size = 1000
            num_batches = (len(test_timestamps) + batch_size - 1) // batch_size
            
            for batch_idx in range(num_batches):
                start_idx = batch_idx * batch_size
                end_idx = min((batch_idx + 1) * batch_size, len(test_timestamps))
                
                print(f"Processing batch {batch_idx+1}/{num_batches} (indices {start_idx}-{end_idx})...")
                
                for idx in range(start_idx, end_idx):
                    timestamp = test_timestamps[idx]
                    actual = y_true[idx - start_idx]
                    predicted = y_pred[idx - start_idx]
                    current_price = current_prices[idx - start_idx] if current_prices is not None else actual
                    
                    # Print progress update for large datasets
                    if (idx - start_idx) % 1000 == 0:
                        print(f"Processing index {idx}: Current price = {actual:.2f}, Predicted price = {predicted:.2f}")
                    
                    # Get market data window for signal generation
                    window_start_idx = max(0, stock_df_orig.index.get_indexer([timestamp])[0] - 20)
                    window_end_idx = stock_df_orig.index.get_indexer([timestamp])[0] + 1
                    market_data_window = stock_df_orig.iloc[window_start_idx:window_end_idx].copy()
                    
                    if len(market_data_window) == 0:
                        continue
                    
                    # Apply optimized bias correction if available
                    if ENABLE_OPTIMIZATION and 'bias_correction_params' in locals():
                        # Custom bias correction using optimized parameters
                        market_regime = signal_generator.detect_market_regime(market_data_window)
                        price_changes = market_data_window['close'].pct_change().dropna()
                        mean_price_change = price_changes[-20:].mean() if len(price_changes) >= 20 else 0
                        
                        if market_regime in ['trending_up', 'high_volatility']:
                            correction = current_price * bias_correction_params['downward_bias_correction_high']
                        else:
                            correction = current_price * bias_correction_params['downward_bias_correction_low']
                        
                        if mean_price_change > 0:
                            adaptive_factor = min(bias_correction_params['adaptive_factor_cap'], mean_price_change)
                            correction += current_price * adaptive_factor
                    else:
                        # Apply default correction based on evaluator
                        correction = evaluator.get_adaptive_correction(ticker) * 0.5  # Reduced by 50%
                    
                    corrected_prediction = predicted - correction
                    
                    # Apply range clipping
                    max_pct_change = 0.015  # 1.5% maximum change per minute
                    min_bound = current_price * (1 - max_pct_change)
                    max_bound = current_price * (1 + max_pct_change)
                    clipped_prediction = np.clip(corrected_prediction, min_bound, max_bound)
                    
                    try:
                        # Generate signal with proper error handling
                        signal = monitor.update_ticker(
                            ticker=ticker,
                            current_price=actual,
                            predicted_price=clipped_prediction,
                            timestamp=timestamp,
                            market_data=market_data_window,
                            prediction_metrics=evaluation_metrics
                        )
                        
                        if signal and signal['action']:
                            if (idx - start_idx) < 10 or (idx - start_idx) % 500 == 0:  # Limit logging for clarity
                                print(f"Generated signal at {timestamp}: {signal['action']}")
                            all_signals.append(signal)
                    except Exception as e:
                        logger.error(f"Error generating signal at {timestamp}: {e}")
                        # Continue processing despite errors
                        continue
            
            print(f"Generated {len(all_signals)} signals")
            
            # Add diagnostic information if no signals were generated
            if len(all_signals) == 0:
                print("\nDiagnostic information for signal generation:")
                # Check the first few test samples
                for idx in range(min(5, len(test_timestamps))):
                    timestamp = test_timestamps[idx]
                    actual = y_true[idx]
                    predicted = y_pred[idx]
                    price_change_pct = (predicted - actual) / actual
                    
                    # Get market data window
                    window_start_idx = max(0, stock_df_orig.index.get_indexer([timestamp])[0] - 20)
                    window_end_idx = stock_df_orig.index.get_indexer([timestamp])[0] + 1
                    market_data_window = stock_df_orig.iloc[window_start_idx:window_end_idx].copy()
                    
                    # Check trend condition
                    if 'MA5' not in market_data_window.columns:
                        market_data_window['MA5'] = market_data_window['close'].rolling(5).mean()
                    if 'SMA_20' not in market_data_window.columns:
                        market_data_window['SMA_20'] = market_data_window['close'].rolling(20).mean()
                    
                    ma5 = market_data_window['MA5'].iloc[-1] if not market_data_window['MA5'].empty else None
                    ma20 = market_data_window['SMA_20'].iloc[-1] if not market_data_window['SMA_20'].empty else None
                    
                    if ma5 is not None and ma20 is not None:
                        trend_strength = abs((ma5 - ma20) / ma20)
                        trend_threshold = signal_generator.params['trend_threshold']
                        trend_ok = trend_strength <= trend_threshold
                    else:
                        trend_strength = None
                        trend_threshold = signal_generator.params['trend_threshold']
                        trend_ok = False
                    
                    # Check volume condition
                    if 'volume' in market_data_window.columns and not market_data_window['volume'].empty:
                        current_volume = market_data_window['volume'].iloc[-1]
                        avg_volume = market_data_window['volume'].rolling(20).mean().iloc[-1]
                        volume_ratio = current_volume / avg_volume if not pd.isna(avg_volume) and avg_volume > 0 else 0
                        volume_threshold = signal_generator.params['volume_threshold']
                        volume_ok = volume_ratio >= volume_threshold
                    else:
                        volume_ratio = None
                        volume_threshold = signal_generator.params['volume_threshold']
                        volume_ok = False
                    
                    # Check signal thresholds
                    thresholds = signal_generator.calculate_thresholds(
                        signal_generator.current_atr or 0.001,
                        'neutral',
                        signal_generator.direction_confidence
                    )
                    
                    print(f"\nTimestamp {idx}: {timestamp}")
                    print(f"  Current price: {actual:.2f}, Predicted: {predicted:.2f}")
                    print(f"  Price change %: {price_change_pct*100:.4f}% (need ±{thresholds['long_entry']*100:.4f}% for long, ±{thresholds['short_entry']*100:.4f}% for short)")
                    print(f"  Trend check: {'PASS' if trend_ok else 'FAIL'} (strength: {trend_strength:.4f if trend_strength is not None else 'N/A'}, threshold: {trend_threshold:.4f})")
                    print(f"  Volume check: {'PASS' if volume_ok else 'FAIL'} (ratio: {volume_ratio:.2f if volume_ratio is not None else 'N/A'}, threshold: {volume_threshold:.2f})")
                    print(f"  Signal would be generated: {'Yes' if abs(price_change_pct) > thresholds['long_entry'] and trend_ok and volume_ok else 'No'}")
            
            # Perform backtesting with realistic transaction costs
            try:
                trades, total_commission, total_slippage = backtest_trades_with_costs(
                    all_signals, 
                    stock_df_orig,
                    commission_rate=TRANSACTION_COSTS['commission_rate'],
                    slippage_factor=TRANSACTION_COSTS['slippage_factor']
                )
                print(f"Generated {len(trades)} trades")
                
                # Calculate trade performance metrics
                trade_metrics = evaluator.calculate_trade_performance_metrics(
                    trades=trades,
                    initial_capital=10000.0,
                    include_transaction_costs=True
                )
                
                # Log costs
                logger.info(f"Total commission: ${total_commission:.2f}, Total slippage: ${total_slippage:.2f}")
                
                # Print performance summary
                print("\nTrading Performance Summary:")
                print(f"Total Trades: {trade_metrics['total_trades']}")
                print(f"Win Rate: {trade_metrics['win_rate']:.2f}%")
                print(f"Long Trades: {trade_metrics.get('long_trades', 0)}, Short Trades: {trade_metrics.get('short_trades', 0)}")
                print(f"Long Win Rate: {trade_metrics.get('long_win_rate', 0):.2f}%, Short Win Rate: {trade_metrics.get('short_win_rate', 0):.2f}%")
                print(f"Total Profit: ${trade_metrics['total_profit']:.2f}")
                print(f"Maximum Drawdown: ${trade_metrics['maximum_drawdown']:.2f}")
                print(f"Sharpe Ratio: {trade_metrics['sharpe_ratio']:.2f}")
                print(f"Transaction Costs: ${total_commission + total_slippage:.2f}")
                
                # Generate visualizations
                plot_trading_metrics(trade_metrics, ticker)
                plt.savefig(f"{ticker_dir}/trading_metrics_{ticker}.png")
                
                plot_candlestick_analysis(stock_df_orig, signals=all_signals, trades=trades, ticker=ticker)
                plt.savefig(f"{ticker_dir}/candlestick_analysis_{ticker}.png")
                
            except Exception as e:
                logger.error(f"Error in backtesting for {ticker}: {e}")
                traceback.print_exc()
            
            # Save model and artifacts
            try:
                model_filename = f'{ticker_dir}/lstm_model_{ticker}_{end_date_str}.pth'
                scaler_filename = f'{ticker_dir}/scaler_{ticker}_{end_date_str}.pkl'
                history_filename = f'{ticker_dir}/training_history_{ticker}_{end_date_str}.pkl'
                metrics_filename = f'{ticker_dir}/performance_metrics_{ticker}_{end_date_str}.pkl'
                signals_filename = f'{ticker_dir}/signals_{ticker}_{end_date_str}.pkl'
                
                # Save model
                torch.save(model.state_dict(), model_filename)
                
                # Save scaler
                with open(scaler_filename, 'wb') as f:
                    pickle.dump(dataset.scaler, f)
                    
                # Save training history
                with open(history_filename, 'wb') as f:
                    pickle.dump(training_history, f)
                
                # Save signals and trades
                with open(signals_filename, 'wb') as f:
                    signals_data = {
                        'signals': all_signals,
                        'trades': trades if 'trades' in locals() else []
                    }
                    pickle.dump(signals_data, f)
                    
                # Save performance metrics
                with open(metrics_filename, 'wb') as f:
                    metrics_data = {
                        'evaluation_metrics': evaluation_metrics,
                        'trade_metrics': trade_metrics if 'trade_metrics' in locals() else {},
                        'transaction_costs': {
                            'commission': total_commission if 'total_commission' in locals() else 0,
                            'slippage': total_slippage if 'total_slippage' in locals() else 0
                        }
                    }
                    pickle.dump(metrics_data, f)
                    
                print(f"\nSaved model and artifacts to {ticker_dir}")
                
            except Exception as e:
                logger.error(f"Error saving artifacts for {ticker}: {e}")
                traceback.print_exc()
            
            logger.info(f"Completed processing for {ticker}")
            
        except Exception as e:
            logger.error(f"Error processing {ticker}: {e}")
            traceback.print_exc()
            continue
            
        finally:
            # Clean up memory
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    logger.info("Completed processing all tickers")
    print("\nTrading model training pipeline completed!")

if __name__ == "__main__":
    # Script entry point: invoke main() only when this module is run directly
    # (a notebook cell also runs with __name__ == "__main__"), not when it is
    # imported as a module.
    main()

Fetching SPY: 48068rows [00:00, 77752.28rows/s]



Processing MSFT...


Fetching MSFT: 38094rows [00:02, 16268.63rows/s]


Optimizing trading parameters for MSFT...


[I 2025-03-17 12:59:50,741] A new study created in memory with name: no-name-aed55e7c-c529-48c1-af7a-d7b777d97a3e
