In [None]:
# Essential Imports
import requests
import pandas as pd
import numpy as np
import time
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import mplfinance as mpf
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import ta
import logging
import gc
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Tuple
from sklearn.metrics import mean_squared_error, mean_absolute_error
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import os
from collections import defaultdict
import matplotlib.gridspec as gridspec
from ta.volatility import BollingerBands
from ta.trend import MACD
import warnings
from dotenv import load_dotenv
import torch
import torch.nn as nn
import torch.nn.functional as F  # This import was missing
import mplfinance

# Suppress mplfinance warnings for too much data
# (only UserWarning instances raised from modules matching "mplfinance" are ignored)
warnings.filterwarnings("ignore", category=UserWarning, module="mplfinance")

# Configure logging: INFO and above is appended to a log file in the current
# working directory, one "timestamp:LEVEL:message" record per line.
logging.basicConfig(
    level=logging.INFO,
    filename='structured_prediction_hft.log',
    filemode='a',
    format='%(asctime)s:%(levelname)s:%(message)s'
)
# Module-level logger used by the classes and functions below
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility (NumPy, PyTorch CPU, and the current
# CUDA device when one is available)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# API Configuration
# The Polygon.io key is read from the process environment; API_KEY is None
# when POLYGON_API_KEY is unset.
# NOTE(review): load_dotenv is imported above but is not called in the visible
# code — confirm a .env file is loaded before this lookup if the key lives there.
API_KEY = os.getenv("POLYGON_API_KEY")

@dataclass
class TrainingHistory:
    """Accumulates per-epoch training and validation losses.

    Both histories are plain lists that grow by one entry per ``update`` call.
    """
    # Training loss, one value per recorded epoch
    loss_history: List[float] = field(default_factory=list)
    # Validation loss, one value per recorded epoch
    validation_loss_history: List[float] = field(default_factory=list)

    def update(self, epoch: int, loss: float, val_loss: float):
        """Record the losses of one epoch.

        Args:
            epoch: Epoch number; accepted but not stored.
            loss: Training loss to append to ``loss_history``.
            val_loss: Validation loss to append to ``validation_loss_history``.
        """
        recordings = (
            (self.loss_history, loss),
            (self.validation_loss_history, val_loss),
        )
        for series, value in recordings:
            series.append(value)

class ModelEvaluator:
    """Scores price predictions and summarises the outcome of executed trades."""

    def __init__(self, prediction_threshold: float = 0.001):
        """Store the prediction threshold and create empty score containers.

        Args:
            prediction_threshold: Stored on the instance; not used by the
                methods defined in this class.
        """
        self.prediction_threshold = prediction_threshold
        self.mse_scores = []
        self.mae_scores = []
        self.seasonal_errors = []

    def calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
        """Calculate direction, correlation and timing metrics for predictions.

        Args:
            y_true: Observed values (any shape; flattened to 1-D).
            y_pred: Predicted values (any shape; flattened to 1-D).

        Returns:
            Dict with:
              * ``direction_accuracy`` - percentage of consecutive steps whose
                sign of change matches between the two series (0.0 when there
                are fewer than two samples).
              * ``magnitude_correlation`` - Pearson correlation of the two
                series (NaN when it is undefined: fewer than two samples or a
                constant series).
              * ``timing_accuracy`` - percentage of steps where the true series
                moved (non-zero change) at which the prediction also moved.

        Raises:
            ValueError: On length mismatch, NaN values or infinite values.
        """
        metrics = {}

        # Ensure inputs are 1-D and matching length (also accepts lists/tuples)
        y_true = np.asarray(y_true).ravel()
        y_pred = np.asarray(y_pred).ravel()

        if len(y_true) != len(y_pred):
            raise ValueError("y_true and y_pred must have the same length")

        # Handle invalid values
        if np.isnan(y_true).any() or np.isnan(y_pred).any():
            raise ValueError("Inputs contain NaN values")
        if np.isinf(y_true).any() or np.isinf(y_pred).any():
            raise ValueError("Inputs contain infinite values")

        try:
            # Direction Accuracy: compare the sign of each step-to-step change
            direction_true = np.sign(np.diff(y_true))
            direction_pred = np.sign(np.diff(y_pred))
            if direction_true.size:
                direction_accuracy = np.mean(direction_true == direction_pred) * 100
            else:
                # No steps to compare; avoid the NaN/RuntimeWarning of mean([])
                direction_accuracy = 0.0
            metrics['direction_accuracy'] = direction_accuracy

            # Magnitude Correlation: undefined for <2 samples or constant input,
            # where np.corrcoef would emit warnings and return NaN anyway.
            undefined = (
                y_true.size < 2
                or y_true.max() == y_true.min()
                or y_pred.max() == y_pred.min()
            )
            if undefined:
                magnitude_correlation = float('nan')
            else:
                magnitude_correlation = np.corrcoef(y_true, y_pred)[0, 1]
            metrics['magnitude_correlation'] = magnitude_correlation

            # Timing Accuracy: share of true moves that coincide with a predicted move
            moves_true = np.flatnonzero(direction_true)
            moves_pred = np.flatnonzero(direction_pred)
            if moves_true.size:
                shared = np.intersect1d(moves_true, moves_pred).size
                timing_accuracy = shared / moves_true.size * 100
            else:
                timing_accuracy = 0.0
            metrics['timing_accuracy'] = timing_accuracy

        except Exception as e:
            logger.error(f"Error in calculate_metrics: {e}")
            raise

        return metrics

    def calculate_trade_performance_metrics(self, trades: List[Dict], initial_capital: float = 10000.0) -> Dict[str, float]:
        """Calculate comprehensive performance metrics for trades.

        Profit is measured per share in price units
        (``exit_price - entry_price`` for longs, the reverse otherwise).

        Args:
            trades: Dicts with keys ``position``, ``entry_price``,
                ``exit_price`` and ``exit_time`` (an object with ``.date()``).
            initial_capital: Accepted for interface compatibility; not used in
                any of the calculations.

        Returns:
            Dict with ``total_trades``, ``total_profit``, ``win_rate``,
            ``loss_rate``, ``average_profit_per_trade``, ``maximum_drawdown``
            (largest peak-to-trough fall of cumulative profit, starting from
            zero), ``profit_factor`` (``inf`` when there are no losing trades)
            and ``sharpe_ratio`` (annualised with sqrt(252) from per-day profit
            totals; 0.0 when fewer than two days or no day-to-day variation).
        """
        if not trades:
            return {
                'total_trades': 0,
                'total_profit': 0.0,
                'win_rate': 0.0,
                'loss_rate': 0.0,
                'average_profit_per_trade': 0.0,
                'maximum_drawdown': 0.0,
                'profit_factor': 0.0,
                'sharpe_ratio': 0.0
            }

        profits = []
        daily_returns = defaultdict(float)

        # Calculate trade results
        for trade in trades:
            # NOTE(review): any position other than 'enter_long' is treated as a
            # short. Signals built in this file store 'long'/'short' under
            # 'position' and 'enter_long' under 'action' — confirm the schema
            # of the trade dicts produced by the caller.
            if trade['position'] == 'enter_long':
                profit = trade['exit_price'] - trade['entry_price']
            else:  # enter_short
                profit = trade['entry_price'] - trade['exit_price']

            profits.append(profit)
            daily_returns[trade['exit_time'].date()] += profit

        # Calculate metrics
        total_profit = sum(profits)
        winning_trades = [p for p in profits if p > 0]
        losing_trades = [p for p in profits if p < 0]

        win_rate = len(winning_trades) / len(profits) * 100
        loss_rate = len(losing_trades) / len(profits) * 100

        # Maximum drawdown: largest drop of the equity curve below its running
        # peak. The curve starts at 0, so an initial loss counts as a drawdown.
        equity = np.cumsum(profits)
        running_peak = np.maximum.accumulate(np.concatenate(([0.0], equity)))[1:]
        max_drawdown = float(np.max(running_peak - equity))

        # Sharpe ratio on per-day profit totals; guard against zero variance
        daily = np.array(list(daily_returns.values()), dtype=float)
        if daily.size > 1 and daily.max() > daily.min():
            sharpe_ratio = float(daily.mean() / daily.std() * np.sqrt(252))
        else:
            sharpe_ratio = 0.0

        if losing_trades:
            profit_factor = sum(winning_trades) / abs(sum(losing_trades))
        else:
            profit_factor = float('inf')

        return {
            'total_trades': len(trades),
            'total_profit': round(total_profit, 2),
            'win_rate': round(win_rate, 2),
            'loss_rate': round(loss_rate, 2),
            'average_profit_per_trade': round(float(np.mean(profits)), 2),
            'maximum_drawdown': round(max_drawdown, 2),
            'profit_factor': round(profit_factor, 2),
            'sharpe_ratio': round(sharpe_ratio, 2)
        }

class AdaptiveSignalGenerator:
    """Generate entry/exit signals for one ticker from a per-ticker profile.

    The instance tracks a single open position (``'long'``, ``'short'`` or
    ``None``), its entry price/time, and the number of entries made on the
    current date.
    """

    def __init__(self, ticker: str):
        """Load the trading profile for ``ticker`` and reset position state.

        Args:
            ticker: Symbol whose profile is looked up in ``stock_profiles``.

        Raises:
            KeyError: If ``ticker`` has no profile.
        """
        self.ticker = ticker
        self.stock_profiles = {
            'MSFT': {
                'base_entry_threshold': 0.0008,
                'base_exit_threshold': 0.0006,
                'base_stop_loss': 0.0015,
                'atr_multiplier': 1.2,
                'min_hold_time': 5,
                'max_daily_trades': 20,
                'position_size': 0.2,
                'volatility_threshold': 0.8,
                'volume_threshold': 1.0,
                'trend_threshold': 0.02
            },
            'GOOGL': {
                'base_entry_threshold': 0.001,
                'base_exit_threshold': 0.0007,
                'base_stop_loss': 0.002,
                'atr_multiplier': 1.3,
                'min_hold_time': 3,
                'max_daily_trades': 25,
                'position_size': 0.2,
                'volatility_threshold': 0.9,
                'volume_threshold': 1.0,
                'trend_threshold': 0.02
            },
            'TSLA': {
                'base_entry_threshold': 0.003,
                'base_exit_threshold': 0.002,
                'base_stop_loss': 0.004,
                'atr_multiplier': 2.0,
                'min_hold_time': 2,
                'max_daily_trades': 15,
                'position_size': 0.1,
                'volatility_threshold': 1.5,
                'volume_threshold': 1.5,
                'trend_threshold': 0.03
            },
            'NVDA': {
                'base_entry_threshold': 0.0025,
                'base_exit_threshold': 0.0018,
                'base_stop_loss': 0.0035,
                'atr_multiplier': 1.8,
                'min_hold_time': 2,
                'max_daily_trades': 18,
                'position_size': 0.12,
                'volatility_threshold': 1.3,
                'volume_threshold': 1.3,
                'trend_threshold': 0.025
            },
            'TQQQ': {
                'base_entry_threshold': 0.005,
                'base_exit_threshold': 0.004,
                'base_stop_loss': 0.007,
                'atr_multiplier': 3.0,
                'min_hold_time': 2,
                'max_daily_trades': 10,
                'position_size': 0.06,
                'volatility_threshold': 2.5,
                'volume_threshold': 2.0,
                'trend_threshold': 0.05
            },
            'SQQQ': {
                'base_entry_threshold': 0.005,
                'base_exit_threshold': 0.004,
                'base_stop_loss': 0.007,
                'atr_multiplier': 3.0,
                'min_hold_time': 2,
                'max_daily_trades': 8,
                'position_size': 0.06,
                'volatility_threshold': 2.5,
                'volume_threshold': 2.2,
                'trend_threshold': 0.05
            },
            'QLD': {
                'base_entry_threshold': 0.003,
                'base_exit_threshold': 0.002,
                'base_stop_loss': 0.004,
                'atr_multiplier': 2.0,
                'min_hold_time': 2,
                'max_daily_trades': 15,
                'position_size': 0.10,
                'volatility_threshold': 1.7,
                'volume_threshold': 1.5,
                'trend_threshold': 0.035
            },
            'PSQ': {
                'base_entry_threshold': 0.003,
                'base_exit_threshold': 0.002,
                'base_stop_loss': 0.004,
                'atr_multiplier': 1.8,
                'min_hold_time': 4,
                'max_daily_trades': 15,
                'position_size': 0.15,
                'volatility_threshold': 1.4,
                'volume_threshold': 1.4,
                'trend_threshold': 0.03
            }
        }

        # Fail with an actionable message instead of a bare KeyError(ticker)
        if ticker not in self.stock_profiles:
            raise KeyError(
                f"No trading profile for ticker {ticker!r}; "
                f"supported tickers: {sorted(self.stock_profiles)}"
            )

        self.params = self.stock_profiles[ticker]
        self._initialize_parameters()

    def _initialize_parameters(self):
        """Initialize trading parameters from stock profile and clear state."""
        self.base_entry_threshold = self.params['base_entry_threshold']
        self.base_exit_threshold = self.params['base_exit_threshold']
        self.base_stop_loss = self.params['base_stop_loss']
        self.atr_multiplier = self.params['atr_multiplier']
        self.min_hold_time = self.params['min_hold_time']
        self.volatility_threshold = self.params['volatility_threshold']

        # Position / bookkeeping state
        self.position = None
        self.entry_price = None
        self.entry_time = None
        self.current_atr = None
        self.daily_trades = 0
        self.last_trade_date = None

    def calculate_atr(self, market_data: pd.DataFrame) -> float:
        """Calculate Average True Range over the last 14 rows.

        Args:
            market_data: Frame with ``high``, ``low`` and ``close`` columns.

        Returns:
            The 14-row simple mean of the true range at the last row; NaN when
            fewer than 14 rows are available.
        """
        # Work positionally (fresh RangeIndex) so the frame's index is irrelevant
        high = pd.Series(market_data['high'].values)
        low = pd.Series(market_data['low'].values)
        prev_close = pd.Series(market_data['close'].values).shift()

        true_range = pd.concat(
            [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
            axis=1,
        ).max(axis=1)

        return true_range.rolling(window=14).mean().iloc[-1]

    def calculate_thresholds(self, current_volatility: float) -> Tuple[float, float, float]:
        """Calculate adaptive thresholds based on current volatility.

        Each base threshold is scaled by ``1 + current_volatility * atr_multiplier``.

        Returns:
            ``(entry_threshold, exit_threshold, stop_loss)``.
        """
        scale = 1 + current_volatility * self.atr_multiplier

        return (
            self.base_entry_threshold * scale,
            self.base_exit_threshold * scale,
            self.base_stop_loss * scale,
        )

    def check_trend_condition(self, market_data: pd.DataFrame) -> bool:
        """Check if trend conditions are favorable.

        Returns True when the relative gap between the latest ``MA5`` and
        ``SMA_20`` values is at or below the profile's ``trend_threshold``
        (1.5x that threshold for TSLA and NVDA).
        """
        ma5 = market_data['MA5'].iloc[-1]
        ma20 = market_data['SMA_20'].iloc[-1]

        trend_strength = abs((ma5 - ma20) / ma20)

        limit = self.params['trend_threshold']
        if self.ticker in ['TSLA', 'NVDA']:
            limit = limit * 1.5
        return trend_strength <= limit

    def check_volume_condition(self, market_data: pd.DataFrame) -> bool:
        """Check if volume conditions are favorable.

        Returns True when the latest volume divided by its 20-row rolling mean
        is at least the profile's ``volume_threshold``. With fewer than 20 rows
        the mean is NaN and the result is False.
        """
        current_volume = market_data['volume'].iloc[-1]
        avg_volume = market_data['volume'].rolling(20).mean().iloc[-1]
        volume_ratio = current_volume / avg_volume

        return volume_ratio >= self.params['volume_threshold']

    def generate_signals(self, ticker: str, current_price: float,
                        predicted_price: float, timestamp: pd.Timestamp,
                        market_data: pd.DataFrame,
                        order_book_data: Optional[Dict] = None) -> Dict:
        """Generate trading signals with stock-specific adaptations.

        Args:
            ticker: Symbol written into the returned signal.
            current_price: Latest traded price.
            predicted_price: Model forecast used to decide entries.
            timestamp: Time of the observation; its date drives the daily
                trade counter and it is compared against the entry time.
            market_data: Frame providing ``high``, ``low``, ``close``,
                ``volume``, ``MA5`` and ``SMA_20`` columns.
            order_book_data: Accepted but not used.

        Returns:
            A signal dict (see ``_create_signal``). ``action`` is one of
            ``enter_long``, ``enter_short``, ``exit_long``, ``exit_short`` or
            ``None`` when nothing should be done.
        """
        # Reset daily trades if new day
        current_date = timestamp.date()
        if self.last_trade_date != current_date:
            self.daily_trades = 0
            self.last_trade_date = current_date

        # The daily cap limits NEW entries only. An open position must still be
        # able to hit its stop-loss / take-profit after the cap is reached;
        # returning early here regardless of position would leave it unmanaged
        # until the next trading day.
        if self.position is None and self.daily_trades >= self.params['max_daily_trades']:
            return self._create_signal(ticker, timestamp, current_price)

        # Check minimum hold time (minutes since entry)
        if self.position and self.entry_time:
            hold_time = (timestamp - self.entry_time).total_seconds() / 60
            if hold_time < self.params['min_hold_time']:
                return self._create_signal(ticker, timestamp, current_price)

        try:
            self.current_atr = self.calculate_atr(market_data)
            entry_threshold, exit_threshold, stop_loss = self.calculate_thresholds(self.current_atr)
        except Exception as e:
            logger.error(f"Error calculating thresholds: {e}")
            return self._create_signal(ticker, timestamp, current_price)

        price_change_pct = (predicted_price - current_price) / current_price
        signal = self._create_signal(ticker, timestamp, current_price)

        if self.position is None:
            trend_ok = self.check_trend_condition(market_data)
            volume_ok = self.check_volume_condition(market_data)

            if abs(price_change_pct) > entry_threshold and trend_ok and volume_ok:
                direction = 'long' if price_change_pct > 0 else 'short'
                self.position = direction
                self.entry_price = current_price
                self.entry_time = timestamp
                signal['action'] = f'enter_{direction}'
                signal['position'] = direction
                self.daily_trades += 1
        elif self.position in ('long', 'short'):
            # Favourable move is positive for both directions
            move = (current_price - self.entry_price) / self.entry_price
            if self.position == 'short':
                move = -move

            if move <= -stop_loss or move >= exit_threshold:
                closed = self.position
                self.position = None
                self.entry_price = None
                self.entry_time = None
                signal['action'] = f'exit_{closed}'
                signal['position'] = closed

        return signal

    def _create_signal(self, ticker: str, timestamp: pd.Timestamp, price: float) -> Dict:
        """Create a base signal dictionary with no action.

        The ``thresholds`` entry reports the profile's base thresholds, not
        the volatility-adjusted ones used for the decision.
        """
        return {
            'ticker': ticker,
            'timestamp': timestamp,
            'price': price,
            'action': None,
            'position': None,
            'thresholds': {
                'entry': self.base_entry_threshold,
                'exit': self.base_exit_threshold,
                'stop_loss': self.base_stop_loss
            },
            'market_conditions': {
                'atr': self.current_atr,
                'daily_trades': self.daily_trades
            }
        }

class MultiTickerMonitor:
    """Keeps per-ticker price state and the actionable signals produced for it.

    NOTE(review): one ``signal_generator`` instance serves every tracked
    ticker, so any position state held by that generator is shared between
    tickers — confirm this is intended by the caller.
    """

    def __init__(self, signal_generator: AdaptiveSignalGenerator):
        self.tracked_tickers = {}
        self.signal_generator = signal_generator

    def add_ticker(self, ticker: str, initial_price: float):
        """Start tracking ``ticker`` with an empty signal list.

        ``last_update`` is initialised to the current wall-clock time.
        """
        record = {
            'current_price': initial_price,
            'signals': [],
            'last_update': datetime.now()
        }
        self.tracked_tickers[ticker] = record

    def update_ticker(self, ticker: str, current_price: float,
                     predicted_price: float, timestamp: pd.Timestamp,
                     market_data: pd.DataFrame,
                     order_book_data: Dict = None) -> Optional[Dict]:
        """Refresh a ticker's price state and ask the generator for a signal.

        Unknown tickers are registered on first use. Returns the signal when
        it carries an action (it is also stored for the ticker); otherwise
        returns ``None``.
        """
        if ticker not in self.tracked_tickers:
            self.add_ticker(ticker, current_price)

        record = self.tracked_tickers[ticker]
        record['current_price'] = current_price
        record['last_update'] = timestamp

        signal = self.signal_generator.generate_signals(
            ticker=ticker,
            current_price=current_price,
            predicted_price=predicted_price,
            timestamp=timestamp,
            market_data=market_data,
            order_book_data=order_book_data
        )

        # Guard clause: nothing to record when there is no actionable signal
        if not signal or not signal['action']:
            return None

        self.tracked_tickers[ticker]['signals'].append(signal)
        return signal

    def get_signals(self, ticker: str) -> List[Dict]:
        """Return the stored signals for ``ticker`` (empty list if untracked)."""
        if ticker not in self.tracked_tickers:
            return []
        return self.tracked_tickers[ticker].get('signals', [])

class PolygonDataFetcher:
    """Handles data retrieval from the Polygon.io API with robust error handling and rate limiting."""

    BASE_URL = "https://api.polygon.io/v2"

    def __init__(self, api_key: str):
        """Create a session that sends ``api_key`` as a Bearer token."""
        self.api_key = api_key
        self.session = requests.Session()
        self.session.headers.update({"Authorization": f"Bearer {api_key}"})
        self.validation_metrics = []

    def _make_request(self, endpoint: str, params: Optional[Dict] = None, retries: int = 3) -> Dict:
        """Make an API request with retry logic and rate limiting.

        Args:
            endpoint: Path relative to ``BASE_URL``, or an absolute
                ``http(s)://`` URL (used for the pagination links the API
                returns), which is requested as-is.
            params: Optional query parameters.
            retries: Maximum number of attempts.

        Returns:
            The decoded JSON body.

        Raises:
            Exception: The last request error once all attempts are used, or
                a generic ``Exception`` when every attempt was rate limited.
        """
        if endpoint.startswith(("http://", "https://")):
            url = endpoint
        else:
            url = f"{self.BASE_URL}/{endpoint}"

        for attempt in range(retries):
            try:
                response = self.session.get(url, params=params, timeout=30)

                if response.status_code == 429:  # Rate limit exceeded
                    wait_time = min(60 * (attempt + 1), 300)  # Max 5 minutes
                    logger.warning(f"Rate limit exceeded. Waiting {wait_time} seconds...")
                    time.sleep(wait_time)
                    continue

                response.raise_for_status()
                data = response.json()

                logger.debug(f"API response keys: {data.keys()}")
                if 'results' in data and len(data['results']) > 0:
                    logger.debug(f"Sample result: {data['results'][0]}")

                return data

            except Exception as e:
                # Retry boundary: log, back off, and re-raise on the final attempt
                if attempt == retries - 1:
                    logger.error(f"Request failed after {retries} attempts: {e}")
                    raise
                logger.warning(f"Request failed (Attempt {attempt + 1}/{retries}): {e}. Retrying in 30 seconds...")
                time.sleep(30)

        # Only reached when every attempt ended in a 429 response
        raise Exception(f"Failed after {retries} attempts")

    def fetch_stock_data(self, symbol: str, start_date: str, end_date: str,
                        timespan: str = "minute", multiplier: int = 1) -> pd.DataFrame:
        """
        Fetch aggregate stock data for a specific ticker and date range.

        Pages are followed through the ``next_url`` link in each response
        until the API stops returning one.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            timespan: Timespan of the aggregates ('minute', 'hour', 'day')
            multiplier: The size of the timespan multiplier

        Returns:
            DataFrame containing aggregated stock data, indexed by timestamp
            with columns renamed to open/high/low/close/volume/transactions.
            An empty DataFrame is returned when no rows come back or when the
            rows lack the close ('c') field.
        """
        endpoint = f"aggs/ticker/{symbol}/range/{multiplier}/{timespan}/{start_date}/{end_date}"
        params = {"adjusted": "true", "sort": "asc", "limit": 50000}
        all_results = []

        with tqdm(desc=f"Fetching {symbol}", unit='rows') as pbar:
            data = self._make_request(endpoint, params)
            while True:
                page = data.get('results') or []

                if page and 'c' not in page[0]:
                    logger.error(f"Data for {symbol} does not contain 'c' (close) field.")
                    return pd.DataFrame()

                all_results.extend(page)
                pbar.update(len(page))

                # Follow the server-provided link rather than re-issuing the
                # original request: the date range is part of the endpoint
                # path, so changing query parameters would fetch the same page
                # again and never terminate.
                next_url = data.get('next_url')
                if not next_url:
                    break
                # The link already encodes the query; no extra params are sent.
                data = self._make_request(next_url)

        # Create DataFrame and process data
        df = pd.DataFrame(all_results)

        if 't' not in df.columns:
            logger.error("'t' column not found in fetched data.")
            return pd.DataFrame()

        # Remove duplicate timestamps, keeping the most recently fetched row
        deduped = df.drop_duplicates(subset=['t'], keep='last')
        duplicates_removed = len(df) - len(deduped)
        if duplicates_removed > 0:
            logger.info(f"Removed {duplicates_removed} duplicate timestamps for {symbol}.")
        df = deduped

        # Rename columns for clarity
        df = df.rename(columns={
            't': 'timestamp',
            'o': 'open',
            'h': 'high',
            'l': 'low',
            'c': 'close',
            'v': 'volume',
            'n': 'transactions'
        })

        # Convert millisecond epoch timestamps to datetime and set as index
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        df.set_index('timestamp', inplace=True)

        return df

class ImprovedHFTDataset(Dataset):
    """Enhanced Dataset for High-Frequency Trading data preparation.

    Each item is a window of ``sequence_length`` consecutive feature rows and
    the label attached to the window's last row (the next-step percentage
    change of ``close``, or the next ``close`` itself).
    """

    def __init__(self, data, sequence_length, predict_pct_change=True, price_scaler=None, feature_scaler=None):
        """Build scaled features and next-step labels from ``data``.

        Args:
            data: DataFrame with at least open/high/low/close/volume columns.
                It is modified in place: the technical-indicator columns
                computed here are added to it.
            sequence_length: Number of rows per input window.
            predict_pct_change: When True the label is the next-step
                percentage change of ``close``; otherwise the next ``close``.
            price_scaler: Optional scaler for ``close``; fitted here when it
                has no ``mean_`` attribute yet, otherwise only applied.
            feature_scaler: Optional scaler for the other features; same
                fit-or-apply rule.

        Raises:
            ValueError: If ``data`` is not a DataFrame or lacks OHLCV columns.
        """
        self.sequence_length = sequence_length
        self.predict_pct_change = predict_pct_change

        # Ensure data is a DataFrame
        if not isinstance(data, pd.DataFrame):
            raise ValueError("Input data must be a pandas DataFrame")

        # Calculate additional technical features
        self._add_technical_features(data)

        # Separate price data from other features
        price_cols = ['close']

        # Define expanded feature columns
        feature_cols = [
            'volume', 'SMA_20', 'EMA_20', 'RSI', 'MA5',
            'Bollinger_High', 'Bollinger_Low', 'MACD', 'MACD_Signal',
            'volatility_atr', 'trend_adx', 'momentum_roc', 'momentum_kama',
            'volume_cmf', 'volume_em', 'volume_sma_ratio', 'price_distance_from_ma',
            'volatility_bbw', 'momentum_stoch', 'trend_cci', 'trend_ichimoku_a'
        ]

        # Only use the feature columns that are actually present
        available_features = [col for col in feature_cols if col in data.columns]

        # Initialize scalers if not provided
        self.price_scaler = price_scaler if price_scaler is not None else StandardScaler()
        self.feature_scaler = feature_scaler if feature_scaler is not None else StandardScaler()

        # Prepare features and targets
        price_data = data[price_cols].values
        feature_data = data[available_features].values

        # Scale price data (fit only when the scaler has not been fitted yet)
        if not hasattr(self.price_scaler, 'mean_'):
            self.normalized_prices = self.price_scaler.fit_transform(price_data)
        else:
            self.normalized_prices = self.price_scaler.transform(price_data)

        # Scale feature data (same fit-or-apply rule)
        if not hasattr(self.feature_scaler, 'mean_'):
            self.normalized_features = self.feature_scaler.fit_transform(feature_data)
        else:
            self.normalized_features = self.feature_scaler.transform(feature_data)

        # Combine normalized data: column 0 is the scaled close
        self.features = np.hstack((self.normalized_prices, self.normalized_features))

        # Create target values; the last row has no "next" value and is dropped
        if self.predict_pct_change:
            # Predict percentage change in price
            targets = data['close'].pct_change(1).shift(-1).values
        else:
            # Predict actual next price (unscaled)
            targets = data['close'].shift(-1).values
        self.labels = targets[:-1]
        self.features = self.features[:-1]  # Align with labels

    def _add_technical_features(self, data):
        """Add advanced technical indicators to the dataset (in place).

        Indicators are computed on a copy and only columns that do not yet
        exist in ``data`` are written back, so pre-existing columns are left
        untouched.

        Raises:
            ValueError: If any OHLCV column is missing.
        """
        # Check if data has required columns
        if not all(col in data.columns for col in ['open', 'high', 'low', 'close', 'volume']):
            raise ValueError("Data must contain OHLCV columns")

        # Work on a copy; results are merged back at the end
        df = data.copy()

        # Volatility indicators
        df['volatility_atr'] = ta.volatility.average_true_range(df['high'], df['low'], df['close'])
        # NOTE(review): the column is named "bbw" but is filled from
        # bollinger_pband — confirm which Bollinger measure is intended.
        df['volatility_bbw'] = ta.volatility.bollinger_pband(df['close'])

        # Trend indicators
        df['trend_adx'] = ta.trend.adx(df['high'], df['low'], df['close'])
        df['trend_cci'] = ta.trend.cci(df['high'], df['low'], df['close'])
        df['trend_ichimoku_a'] = ta.trend.ichimoku_a(df['high'], df['low'])

        # Momentum indicators
        df['momentum_roc'] = ta.momentum.roc(df['close'])
        df['momentum_kama'] = ta.momentum.kama(df['close'])
        df['momentum_stoch'] = ta.momentum.stoch(df['high'], df['low'], df['close'])

        # Volume indicators
        df['volume_cmf'] = ta.volume.chaikin_money_flow(df['high'], df['low'], df['close'], df['volume'])
        df['volume_em'] = ta.volume.ease_of_movement(df['high'], df['low'], df['close'], df['volume'])
        df['volume_sma_ratio'] = df['volume'] / df['volume'].rolling(20).mean()

        # Price relative to moving average. Use the caller's SMA_20 when it
        # exists; otherwise fall back to a local 20-row mean instead of
        # failing with a KeyError after all the other work is done.
        if 'SMA_20' in df.columns:
            sma_20 = df['SMA_20']
        else:
            sma_20 = df['close'].rolling(20).mean()
        df['price_distance_from_ma'] = (df['close'] - sma_20) / sma_20

        # Ratios can divide by zero; treat the resulting infinities as missing
        # so they are filled below instead of breaking the scaler later.
        df.replace([np.inf, -np.inf], np.nan, inplace=True)

        # Fill NaN values: backfill warm-up rows, then zero whatever remains.
        # (DataFrame.bfill replaces the deprecated fillna(method='bfill').)
        # NOTE(review): backfilling copies later values into earlier rows.
        df.bfill(inplace=True)
        df.fillna(0, inplace=True)

        # Update the original data with new features
        for col in df.columns:
            if col not in data.columns:
                data[col] = df[col]

    def __len__(self):
        """Return the total number of sequences in the dataset."""
        return max(0, len(self.features) - self.sequence_length)

    def __getitem__(self, idx):
        """Get a single sequence and its corresponding label.

        Returns:
            ``(x, y)`` where ``x`` is a float32 tensor of shape
            ``(sequence_length, n_features)`` and ``y`` a float32 tensor of
            shape ``(1,)`` holding the label of the window's last row.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Without this check an out-of-range index silently yields a
            # truncated window instead of stopping iteration.
            raise IndexError("ImprovedHFTDataset index out of range")

        x = self.features[idx:idx + self.sequence_length]
        y = self.labels[idx + self.sequence_length - 1]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32).view(-1)

class DirectionalMSELoss(nn.Module):
    """MSE loss plus a penalty for predicting the wrong sign.

    ``loss = magnitude_weight * MSE + direction_weight * wrong_sign_fraction``
    """

    def __init__(self, direction_weight=1.5, magnitude_weight=1.0):
        """Store the weights of the direction and magnitude terms."""
        super(DirectionalMSELoss, self).__init__()
        self.direction_weight = direction_weight
        self.magnitude_weight = magnitude_weight
        self.mse = nn.MSELoss()

    def forward(self, y_pred, y_true):
        """Return the weighted sum of MSE and the sign-mismatch rate.

        Args:
            y_pred: Predictions, shape ``(batch,)`` or ``(batch, 1)``.
            y_true: Targets, shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Scalar loss tensor.
        """
        # Align shapes BEFORE any element-wise operation. Comparing a
        # (batch, 1) tensor with a (batch,) tensor would broadcast to
        # (batch, batch) and average over every prediction/target pair.
        if len(y_pred.shape) == 1:
            y_pred = y_pred.unsqueeze(1)
        if len(y_true.shape) == 1:
            y_true = y_true.unsqueeze(1)

        # Standard MSE for magnitude accuracy
        magnitude_loss = self.mse(y_pred, y_true)

        # Calculate sign agreement loss for direction accuracy
        pred_sign = torch.sign(y_pred)
        true_sign = torch.sign(y_true)

        # Binary direction loss (1 if direction is wrong, 0 if correct).
        # NOTE: the comparison yields a boolean tensor, so this term carries no
        # gradient; it shifts the reported loss value but only the MSE term
        # drives the parameter updates.
        direction_loss = torch.mean((pred_sign != true_sign).float())

        # Combined loss with weighting
        total_loss = (self.magnitude_weight * magnitude_loss) + (self.direction_weight * direction_loss)

        return total_loss


class EnhancedLSTMModel(nn.Module):
    """Bidirectional LSTM whose outputs are pooled by learned softmax weights
    over the time axis and then passed through a two-layer head."""

    def __init__(self, input_size, hidden_size, num_layers, output_size=1, dropout=0.2):
        """Build the recurrent encoder, the attention scorer and the head.

        Args:
            input_size: Number of features per time step.
            hidden_size: LSTM hidden width per direction.
            num_layers: Number of stacked LSTM layers.
            output_size: Width of the final output.
            dropout: Dropout probability for the head, and between LSTM
                layers when there is more than one layer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Both directions are concatenated, doubling the encoder output width
        encoder_width = hidden_size * 2
        # Inter-layer LSTM dropout only applies when layers are stacked
        inter_layer_dropout = dropout if num_layers > 1 else 0

        # Bidirectional recurrent encoder (inputs are batch-first)
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=True,
        )

        # One score per time step
        self.attention = nn.Linear(encoder_width, 1)

        # Prediction head
        self.fc1 = nn.Linear(encoder_width, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch-first sequence tensor to ``(batch, output_size)``."""
        n_samples = x.size(0)

        # Explicit zero initial states: (layers * directions, batch, hidden)
        state_shape = (self.num_layers * 2, n_samples, self.hidden_size)
        initial_hidden = torch.zeros(*state_shape).to(x.device)
        initial_cell = torch.zeros(*state_shape).to(x.device)

        # Encode the sequence
        sequence_out, _ = self.lstm(x, (initial_hidden, initial_cell))

        # Normalise the per-step scores over the time axis and pool
        step_scores = self.attention(sequence_out)
        step_weights = F.softmax(step_scores, dim=1)
        pooled = (step_weights * sequence_out).sum(dim=1)

        # Head: linear -> dropout -> ReLU -> linear
        hidden = self.relu(self.dropout(self.fc1(pooled)))
        return self.fc2(hidden)

def train_enhanced_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device,
                        patience=10, min_delta=0.001, learning_rate_scheduler=None):
    """
    Train the LSTM model with comprehensive tracking and early stopping.

    After training, the weights from the epoch with the best validation loss
    are restored into ``model`` (if any epoch improved on the initial
    ``inf`` baseline).

    Args:
        model: The neural network model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        criterion: Loss function
        optimizer: Optimization algorithm
        num_epochs: Maximum number of training epochs
        device: Computation device (CPU/GPU)
        patience: Number of epochs to wait for improvement before stopping
        min_delta: Minimum change in validation loss to be considered as improvement
        learning_rate_scheduler: Optional scheduler for adjusting learning rate.
            ``ReduceLROnPlateau`` is stepped with the validation loss, any
            other scheduler is stepped without arguments.

    Returns:
        Tuple ``(model, history)`` where ``history`` is a dict with the keys
        ``'train_loss'`` and ``'val_loss'`` (per-epoch averages),
        ``'learning_rates'`` (filled only when a scheduler is given) and
        ``'best_epoch'`` (1-based, 0 if no epoch improved).
    """
    model.to(device)

    # Initialize training history tracking
    history = {
        'train_loss': [],
        'val_loss': [],
        'learning_rates': [],
        'best_epoch': 0
    }

    # Early stopping variables
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(1, num_epochs + 1):
        # Training phase
        model.train()
        train_losses = []

        # Progress bar for training
        train_progress = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]")

        for batch_x, batch_y in train_progress:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)

            # Backward pass and optimization
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            # Record loss
            batch_loss = loss.item()
            train_losses.append(batch_loss)
            train_progress.set_postfix({"loss": f"{batch_loss:.6f}"})

        # Calculate average training loss
        avg_train_loss = np.mean(train_losses)
        history['train_loss'].append(avg_train_loss)

        # Record current learning rate
        if learning_rate_scheduler:
            current_lr = optimizer.param_groups[0]['lr']
            history['learning_rates'].append(current_lr)

        # Validation phase
        model.eval()
        val_losses = []

        with torch.no_grad():
            val_progress = tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [Valid]")

            for batch_x, batch_y in val_progress:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)

                # Forward pass
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)

                # Record loss
                batch_loss = loss.item()
                val_losses.append(batch_loss)
                val_progress.set_postfix({"loss": f"{batch_loss:.6f}"})

        # Calculate average validation loss
        avg_val_loss = np.mean(val_losses)
        history['val_loss'].append(avg_val_loss)

        # Print epoch summary
        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train Loss = {avg_train_loss:.6f}, "
              f"Val Loss = {avg_val_loss:.6f}")

        # Check for improvement
        if avg_val_loss < best_val_loss - min_delta:
            # state_dict() returns references to the live parameter tensors
            # and dict.copy() is shallow, so the optimizer would keep
            # overwriting a plain copy in place. Clone every tensor to get a
            # real snapshot of this epoch's weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            best_val_loss = avg_val_loss
            patience_counter = 0
            history['best_epoch'] = epoch
            print(f"* Validation loss improved to {best_val_loss:.6f}")
        else:
            patience_counter += 1
            print(f"* Validation loss did not improve. Patience: {patience_counter}/{patience}")

            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

        # Step the learning rate scheduler if provided
        if learning_rate_scheduler:
            if isinstance(learning_rate_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                learning_rate_scheduler.step(avg_val_loss)
            else:
                learning_rate_scheduler.step()

    # Load the best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model from epoch {history['best_epoch']}")

    return model, history

def evaluate_model(model: nn.Module, 
                  test_loader: DataLoader, 
                  device: torch.device) -> Tuple[np.ndarray, np.ndarray]:
    """
    Evaluate the trained model on test data.

    The model is switched to eval mode and run without gradient tracking.
    Targets and predictions are both flattened per batch so the two returned
    arrays always have the same one-dimensional layout, even when the loader
    yields targets shaped ``(batch, 1)``.

    Args:
        model: Trained LSTM model (expected to already live on ``device``)
        test_loader: DataLoader (or any iterable) yielding ``(inputs, targets)``
        device: Computation device the inputs are moved to

    Returns:
        Tuple of true and predicted values as 1-D arrays; both are empty when
        the loader yields no batches.
    """
    model.eval()
    true_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            outputs = model(batch_x.to(device))
            pred_chunks.append(outputs.detach().cpu().numpy().ravel())
            # .cpu() keeps this working if the loader hands out targets that
            # are already on an accelerator.
            true_chunks.append(batch_y.detach().cpu().numpy().ravel())

    if not pred_chunks:
        return np.array([]), np.array([])

    return np.concatenate(true_chunks), np.concatenate(pred_chunks)

def backtest_trades(signals: List[Dict], df: pd.DataFrame) -> List[Dict]:
    """
    Execute trades based on generated signals and record trade details.

    Signals are processed in the order given. An ``enter_long`` /
    ``enter_short`` signal opens a position if none is open; further entry
    signals are ignored until the position is closed. Any ``exit_long`` /
    ``exit_short`` signal closes the open position (the exit side is not
    matched against the entry side). A position still open after the last
    signal is closed at the last ``close`` price of ``df``.

    The caller's ``df`` is not modified.

    Args:
        signals: List of trading signals; each is read for the keys
            ``'timestamp'``, ``'price'`` and ``'action'``
        df: DataFrame with market data; its index and ``'close'`` column are
            used only to close a position left open at the end

    Returns:
        List of executed trades with details (``entry_time``, ``entry_price``,
        ``exit_time``, ``exit_price``, ``position``, ``exit_reason`` and
        ``duration`` in minutes)

    Raises:
        IndexError: if a position is still open at the end and ``df`` is empty.
    """
    # Convert into a local index rather than assigning to df.index, which
    # would silently change the DataFrame owned by the caller.
    if isinstance(df.index, pd.DatetimeIndex):
        time_index = df.index
    else:
        time_index = pd.to_datetime(df.index)

    trades = []
    entry_signal = None

    for signal in signals:
        signal_time = pd.to_datetime(signal['timestamp'])
        action = signal['action']

        if action in ('enter_long', 'enter_short'):
            if entry_signal is None:
                entry_signal = {
                    'timestamp': signal_time,
                    'price': signal['price'],
                    'action': action
                }
        elif action in ('exit_long', 'exit_short') and entry_signal is not None:
            held_minutes = (signal_time - entry_signal['timestamp']).total_seconds() / 60
            trades.append({
                'entry_time': entry_signal['timestamp'],
                'entry_price': entry_signal['price'],
                'exit_time': signal_time,
                'exit_price': signal['price'],
                'position': entry_signal['action'],
                # NOTE(review): every signal-driven exit is labelled
                # 'stop_loss'. The previous expression tested
                # startswith('exit_'), which is always true in this branch, so
                # 'take_profit' was never produced. The signal fields read
                # here do not say why the exit happened.
                'exit_reason': 'stop_loss',
                'duration': held_minutes
            })
            entry_signal = None

    # Close any open positions at the end
    if entry_signal is not None:
        last_time = time_index[-1]
        last_price = df['close'].iloc[-1]
        held_minutes = (last_time - entry_signal['timestamp']).total_seconds() / 60

        trades.append({
            'entry_time': entry_signal['timestamp'],
            'entry_price': entry_signal['price'],
            'exit_time': last_time,
            'exit_price': last_price,
            'position': entry_signal['action'],
            'exit_reason': 'end_of_data',
            'duration': held_minutes
        })

    return trades

def plot_candlestick_analysis(df: pd.DataFrame, signals: Optional[List[Dict]] = None, ticker: str = '') -> None:
    """
    Plot candlestick chart with trade signals.

    Candles are drawn on the upper axes and volume on the lower one. Signal
    markers are placed at the row position of their timestamp in ``df``,
    because the chart is drawn with ``show_nontrading=False``. Signals whose
    timestamp is not present in the index are skipped and counted in a log
    warning.

    Args:
        df: DataFrame with OHLCV data (columns ``open``, ``high``, ``low``,
            ``close``, ``volume``)
        signals: Optional list of trading signals; each is read for the keys
            ``'timestamp'``, ``'price'`` and ``'action'``
        ticker: Stock ticker symbol

    Raises:
        Exception: any error raised while plotting is logged and re-raised.
    """
    try:
        df_plot = df[['open', 'high', 'low', 'close', 'volume']].copy()
        df_plot.index.name = 'Date'

        # Create subplots
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 12), sharex=True)
        fig.subplots_adjust(hspace=0)

        # Plot candlesticks and volume
        mc = mpf.make_marketcolors(up='g', down='r', inherit=True)
        s = mpf.make_mpf_style(marketcolors=mc)
        mpf.plot(df_plot, type='candle', style=s, ax=ax1, volume=ax2, show_nontrading=False)

        # Add signals if provided
        if signals:
            # With show_nontrading=False the x axis counts rows (0..N-1), not
            # dates, so each marker must be placed at its row position.
            row_of = {stamp: row for row, stamp in enumerate(df_plot.index)}

            marker_styles = {
                'Buy Signal': ('^', 'g'),
                'Sell Signal': ('v', 'r'),
                'Exit Signal': ('o', 'k'),
            }
            points = {label: ([], []) for label in marker_styles}
            skipped = 0

            for signal in signals:
                action = signal['action']
                if 'enter_long' in action:
                    label = 'Buy Signal'
                elif 'enter_short' in action:
                    label = 'Sell Signal'
                elif 'exit_long' in action or 'exit_short' in action:
                    label = 'Exit Signal'
                else:
                    continue

                row = row_of.get(pd.Timestamp(signal['timestamp']))
                if row is None:
                    skipped += 1
                    continue

                xs, ys = points[label]
                xs.append(row)
                ys.append(signal['price'])

            # One scatter call per category keeps the artist count small and
            # gives exactly one legend entry per label.
            plotted_any = False
            for label, (marker, color) in marker_styles.items():
                xs, ys = points[label]
                if xs:
                    ax1.scatter(xs, ys, marker=marker, color=color, s=100, label=label)
                    plotted_any = True

            if skipped:
                logger.warning(
                    f"plot_candlestick_analysis: skipped {skipped} signal(s) "
                    f"whose timestamp is not in the data index"
                )
            if plotted_any:
                ax1.legend()

        # Address the axes explicitly: the pyplot helpers act on the current
        # axes, which after plt.subplots() is the volume panel.
        ax1.set_title(f"Candlestick Chart with Trade Signals for {ticker}")
        ax2.set_xlabel('Date')
        ax1.set_ylabel('Price')
        plt.tight_layout()
        plt.show()

    except Exception as e:
        logger.error(f"Error in plot_candlestick_analysis: {e}")
        raise

def plot_trading_metrics(metrics: Dict[str, float], ticker: str) -> None:
    """
    Draw a 2x2 grid of bar charts summarising trading metrics.

    Each panel shows one group of metric keys; a key missing from
    ``metrics`` is plotted as 0.

    Args:
        metrics: Dictionary of trading metrics
        ticker: Stock ticker symbol
    """
    plt.style.use('default')
    figure = plt.figure(figsize=(20, 12))
    layout = gridspec.GridSpec(2, 2)

    # One entry per panel, in drawing order:
    # (grid cell, title, y-axis label or None, metric keys, bar colour(s))
    panels = (
        (
            (0, 0),
            'Trading Rates Comparison',
            'Percentage (%)',
            ['win_rate', 'loss_rate', 'net_win_rate', 'net_loss_rate'],
            ['green', 'red', 'lightgreen', 'lightcoral'],
        ),
        (
            (0, 1),
            'Profit Metrics',
            'Amount ($)',
            ['average_profit_per_trade', 'daily_profit', 'profit_per_minute'],
            'blue',
        ),
        (
            (1, 0),
            'Risk Metrics',
            None,
            ['maximum_drawdown', 'sharpe_ratio', 'profit_factor'],
            'purple',
        ),
        (
            (1, 1),
            'Trade Analysis',
            None,
            ['total_trades', 'average_trade_duration'],
            'orange',
        ),
    )

    for cell, panel_title, y_label, metric_keys, bar_color in panels:
        axis = figure.add_subplot(layout[cell])
        heights = [metrics.get(key, 0) for key in metric_keys]
        axis.bar(metric_keys, heights, color=bar_color)
        axis.set_title(panel_title)
        if y_label is not None:
            axis.set_ylabel(y_label)
        # Metric names are long; tilt them so they do not overlap.
        plt.setp(axis.xaxis.get_majorticklabels(), rotation=45)

    plt.suptitle(f'Trading Performance Metrics for {ticker}', fontsize=16)
    plt.tight_layout()
    plt.show()

def plot_learning_curves(training_history: TrainingHistory, ticker: str) -> None:
    """
    Show training and validation loss per epoch on a log-scaled y axis.

    Args:
        training_history: Object exposing ``loss_history`` and
            ``validation_loss_history`` sequences
        ticker: Stock ticker symbol for plot title
    """
    figure = plt.figure(figsize=(10, 6))
    axis = figure.gca()

    train_losses = training_history.loss_history
    # Epochs are numbered from 1, one per recorded training loss.
    epoch_numbers = range(1, len(train_losses) + 1)

    axis.plot(epoch_numbers, train_losses, 'b-', label='Training Loss')
    axis.plot(epoch_numbers, training_history.validation_loss_history, 'r-', label='Validation Loss')

    axis.set_title(f'Learning Curves for {ticker}')
    axis.set_xlabel('Epoch')
    axis.set_ylabel('Loss')
    axis.legend()
    axis.grid(True)
    # Log scale keeps late, small improvements visible next to early large losses.
    axis.set_yscale('log')
    figure.tight_layout()
    plt.show()

def plot_actual_vs_predicted(timestamps, y_true, y_pred, ticker: str) -> None:
    """
    Overlay actual and predicted prices on one time axis.

    The root-mean-squared error between the two series is shown in the title.

    Args:
        timestamps: DatetimeIndex or array of timestamps
        y_true: Array of actual prices
        y_pred: Array of predicted prices
        ticker: Stock ticker symbol for plot title
    """
    figure = plt.figure(figsize=(12, 6))
    axis = figure.gca()

    series = (
        (y_true, 'b-', 'Actual Price'),
        (y_pred, 'r-', 'Predicted Price'),
    )
    for values, line_format, legend_label in series:
        axis.plot(timestamps, values, line_format, label=legend_label)

    # RMSE = square root of the mean squared error.
    squared_error = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(squared_error)

    axis.set_title(f'Actual vs Predicted Prices for {ticker} (RMSE: {rmse:.4f})')
    axis.set_xlabel('Time')
    axis.set_ylabel('Price')
    axis.legend()
    axis.grid(True)
    figure.tight_layout()
    plt.show()


def plot_model_prediction_quality(y_true, y_pred, timestamps, ticker):
    """
    Create a comprehensive visualization of model prediction quality.

    Five panels are drawn: actual vs. predicted over time, error over time,
    error histogram, actual-vs-predicted scatter with the identity line, and
    the actual series coloured by whether the predicted step direction
    matched the actual one.

    Args:
        y_true: Array of actual values
        y_pred: Array of predicted values
        timestamps: Array of corresponding timestamps (same length as the
            value arrays, sliceable)
        ticker: Stock ticker symbol
    """
    from matplotlib.lines import Line2D

    # Accept lists or column vectors as well as flat arrays; the arithmetic
    # below needs aligned 1-D data.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()

    # Create figure with multiple subplots
    fig = plt.figure(figsize=(20, 16))
    gs = gridspec.GridSpec(3, 2, height_ratios=[2, 1, 1])

    # 1. Time Series Plot
    ax1 = fig.add_subplot(gs[0, :])
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    correlation = np.corrcoef(y_true, y_pred)[0, 1]

    ax1.plot(timestamps, y_true, 'b-', label='Actual Price', linewidth=1.5)
    ax1.plot(timestamps, y_pred, 'r-', label='Predicted Price', linewidth=1.5)

    ax1.set_title(f'Actual vs Predicted Prices for {ticker}\nRMSE: {rmse:.4f}, MAE: {mae:.4f}, Correlation: {correlation:.4f}')
    ax1.set_xlabel('Time')
    ax1.set_ylabel('Price')
    ax1.legend()
    ax1.grid(True)

    # 2. Prediction Error Over Time
    ax2 = fig.add_subplot(gs[1, 0])
    errors = y_pred - y_true
    ax2.plot(timestamps, errors, 'g-')
    ax2.axhline(y=0, color='r', linestyle='-')
    ax2.set_title('Prediction Error Over Time')
    ax2.set_xlabel('Time')
    ax2.set_ylabel('Error')
    ax2.grid(True)

    # 3. Error Distribution
    ax3 = fig.add_subplot(gs[1, 1])
    ax3.hist(errors, bins=50, alpha=0.75)
    ax3.axvline(x=0, color='r', linestyle='-')
    ax3.set_title('Error Distribution')
    ax3.set_xlabel('Error')
    ax3.set_ylabel('Frequency')

    # 4. Actual vs Predicted Scatter Plot
    ax4 = fig.add_subplot(gs[2, 0])
    ax4.scatter(y_true, y_pred, alpha=0.5)

    # Add perfect prediction line
    min_val = min(np.min(y_true), np.min(y_pred))
    max_val = max(np.max(y_true), np.max(y_pred))
    ax4.plot([min_val, max_val], [min_val, max_val], 'r--')

    ax4.set_title('Actual vs Predicted (Scatter)')
    ax4.set_xlabel('Actual Price')
    ax4.set_ylabel('Predicted Price')
    ax4.grid(True)

    # 5. Direction Accuracy
    ax5 = fig.add_subplot(gs[2, 1])

    # Step i is the move from point i to point i + 1.
    true_direction = np.sign(np.diff(y_true))
    pred_direction = np.sign(np.diff(y_pred))
    correct_direction = (true_direction == pred_direction)

    # With fewer than two points there are no steps; avoid the mean of an
    # empty array.
    if correct_direction.size:
        direction_accuracy = np.mean(correct_direction) * 100
    else:
        direction_accuracy = float('nan')

    # Draw each step between its own two timestamps, coloured by whether its
    # direction was predicted correctly. x and y use the same slice so the
    # segment sits exactly on the actual series.
    for i, is_correct in enumerate(correct_direction):
        ax5.plot(timestamps[i:i + 2], y_true[i:i + 2],
                 color='green' if is_correct else 'red', linewidth=1.5)

    ax5.set_title(f'Direction Prediction Accuracy: {direction_accuracy:.2f}%')
    ax5.set_xlabel('Time')
    ax5.set_ylabel('Price')

    # Create custom legend
    legend_elements = [
        Line2D([0], [0], color='green', lw=2, label='Correct Direction'),
        Line2D([0], [0], color='red', lw=2, label='Incorrect Direction')
    ]
    ax5.legend(handles=legend_elements)
    ax5.grid(True)

    plt.tight_layout()
    plt.show()

def main():
    """Main function for the high-frequency trading system.

    For every configured ticker this fetches market data, adds technical
    indicators, trains an ``EnhancedLSTMModel`` on a chronological
    train/validation/test split, plots the results, generates and backtests
    trading signals on the test period, and saves the model, scalers and
    training history under ``structured_models/<ticker>/``.

    A failure while processing one ticker is logged and does not stop the
    remaining tickers.
    """

    # Load environment variables from .env file
    load_dotenv()

    # Configuration
    TICKERS = ['MSFT', 'GOOGL', 'TSLA', 'NVDA', 'TQQQ', 'SQQQ', 'QLD', 'PSQ']
    END_DATE = '2025-03-21'

    # Retrieve API key from environment variables
    API_KEY = os.getenv("POLYGON_API_KEY")
    if not API_KEY:
        logger.error("No API key found. Please check your .env file.")
        print("ERROR: Polygon API key not found. Please check your .env file.")
        return

    # Enhanced model hyperparameters
    MODEL_PARAMS = {
        'sequence_length': 60,
        'hidden_size': 128,
        'num_layers': 2,
        'batch_size': 64,
        'num_epochs': 100,
        'learning_rate': 0.0005,
        'dropout': 0.3,
        'weight_decay': 1e-5
    }

    # Initialize components
    fetcher = PolygonDataFetcher(API_KEY)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    for ticker in TICKERS:
        try:
            logger.info(f"Processing {ticker}")
            print(f"\nProcessing {ticker}...")

            # Initialize stock-specific components
            signal_generator = AdaptiveSignalGenerator(ticker)
            monitor = MultiTickerMonitor(signal_generator)

            # Define date range
            end_date = datetime.strptime(END_DATE, '%Y-%m-%d')
            start_date = end_date - timedelta(days=70)
            start_date_str = start_date.strftime('%Y-%m-%d')
            end_date_str = end_date.strftime('%Y-%m-%d')

            # Fetch and prepare market data
            logger.info(f"Fetching market data from {start_date_str} to {end_date_str}")
            stock_df = fetcher.fetch_stock_data(ticker, start_date_str, end_date_str)

            if stock_df.empty:
                logger.warning(f"No data fetched for {ticker}. Skipping.")
                continue

            # Calculate technical indicators
            stock_df['MA5'] = ta.trend.sma_indicator(stock_df['close'], window=5)
            stock_df['SMA_20'] = ta.trend.sma_indicator(stock_df['close'], window=20)
            stock_df['EMA_20'] = ta.trend.ema_indicator(stock_df['close'], window=20)
            stock_df['RSI'] = ta.momentum.rsi(stock_df['close'], window=14)

            macd_indicator = MACD(close=stock_df['close'])
            stock_df['MACD'] = macd_indicator.macd()
            stock_df['MACD_Signal'] = macd_indicator.macd_signal()

            bb_indicator = BollingerBands(close=stock_df['close'])
            stock_df['Bollinger_High'] = bb_indicator.bollinger_hband()
            stock_df['Bollinger_Low'] = bb_indicator.bollinger_lband()

            # Remove NaN values
            stock_df.dropna(inplace=True)

            # Prepare datasets using improved dataset class
            dataset = ImprovedHFTDataset(
                data=stock_df,
                sequence_length=MODEL_PARAMS['sequence_length'],
                predict_pct_change=True  # Predict percentage changes instead of absolute prices
            )

            # Split data
            n_samples = len(dataset)
            train_size = int(0.7 * n_samples)
            val_size = int(0.15 * n_samples)
            test_size = n_samples - train_size - val_size

            if min(train_size, val_size, test_size) <= 0:
                logger.warning(f"Not enough samples for {ticker} ({n_samples}). Skipping.")
                continue

            # Split chronologically instead of randomly. The evaluation below
            # maps the test predictions onto the trailing rows of stock_df
            # for timestamps and base prices, which only holds when the test
            # set is the last contiguous block of samples in their original
            # order. A time-ordered split also keeps future bars out of the
            # training set.
            # NOTE(review): assumes dataset[i] is ordered like stock_df rows —
            # confirm against ImprovedHFTDataset.
            val_end = train_size + val_size
            train_dataset = torch.utils.data.Subset(dataset, range(0, train_size))
            val_dataset = torch.utils.data.Subset(dataset, range(train_size, val_end))
            test_dataset = torch.utils.data.Subset(dataset, range(val_end, n_samples))

            # Create data loaders
            train_loader = DataLoader(
                train_dataset,
                batch_size=MODEL_PARAMS['batch_size'],
                shuffle=True
            )
            val_loader = DataLoader(
                val_dataset,
                batch_size=MODEL_PARAMS['batch_size']
            )
            test_loader = DataLoader(
                test_dataset,
                batch_size=MODEL_PARAMS['batch_size']
            )

            # Initialize enhanced model
            input_size = dataset.features.shape[1]  # Get input size from dataset
            model = EnhancedLSTMModel(
                input_size=input_size,
                hidden_size=MODEL_PARAMS['hidden_size'],
                num_layers=MODEL_PARAMS['num_layers'],
                dropout=MODEL_PARAMS['dropout']
            )

            # Initialize training components
            criterion = DirectionalMSELoss(direction_weight=2.0, magnitude_weight=1.0)
            optimizer = torch.optim.Adam(
                model.parameters(), 
                lr=MODEL_PARAMS['learning_rate'],
                weight_decay=MODEL_PARAMS['weight_decay']
            )

            # Learning rate scheduler. The `verbose` flag is not passed: it is
            # deprecated in recent PyTorch releases, and the learning rate of
            # every epoch is already recorded in the training history.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, 
                mode='min', 
                factor=0.5, 
                patience=5, 
                min_lr=1e-6
            )

            # Train enhanced model
            logger.info("Starting model training...")
            model, training_history = train_enhanced_model(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                criterion=criterion,
                optimizer=optimizer,
                num_epochs=MODEL_PARAMS['num_epochs'],
                device=device,
                learning_rate_scheduler=scheduler
            )

            # Plot learning curves
            logger.info("Plotting learning curves...")
            plt.figure(figsize=(12, 6))
            plt.plot(range(1, len(training_history['train_loss'])+1), training_history['train_loss'], 'b-', label='Training Loss')
            plt.plot(range(1, len(training_history['val_loss'])+1), training_history['val_loss'], 'r-', label='Validation Loss')
            plt.axvline(x=training_history['best_epoch'], color='g', linestyle='--', label=f'Best Epoch ({training_history["best_epoch"]})')
            plt.title(f'Learning Curves for {ticker}')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)
            plt.yscale('log')  # Use log scale for better visualization
            plt.tight_layout()
            plt.show()

            # Evaluate model
            logger.info("Evaluating model...")
            model.eval()
            y_true = []
            y_pred = []

            with torch.no_grad():
                for batch_x, batch_y in test_loader:
                    batch_x = batch_x.to(device)
                    outputs = model(batch_x).cpu().numpy()
                    y_pred.extend(outputs.flatten())
                    y_true.extend(batch_y.numpy().flatten())

            # Map the test predictions onto the trailing rows of stock_df.
            # NOTE(review): presumes sample i targets row i + sequence_length —
            # verify against ImprovedHFTDataset.
            test_indices = np.arange(MODEL_PARAMS['sequence_length'], len(stock_df))
            test_timestamps = stock_df.index[test_indices[-len(y_pred):]]

            # If predicting percentage changes, convert back to prices for plotting
            if dataset.predict_pct_change:
                # Base price = close of the row just before each target row
                base_prices = stock_df['close'].iloc[test_indices[-len(y_pred):] - 1].values

                # Convert percentage changes to actual prices
                y_pred_prices = base_prices * (1 + np.array(y_pred))
                y_true_prices = base_prices * (1 + np.array(y_true))
            else:
                y_pred_prices = np.array(y_pred)
                y_true_prices = np.array(y_true)

            # Ensure arrays are properly shaped and aligned
            y_true_prices = np.array(y_true_prices).flatten()
            y_pred_prices = np.array(y_pred_prices).flatten()

            # Ensure same length (take the smaller length)
            min_length = min(len(y_true_prices), len(y_pred_prices), len(test_timestamps))
            y_true_prices = y_true_prices[:min_length]
            y_pred_prices = y_pred_prices[:min_length]
            test_timestamps_aligned = test_timestamps[:min_length]

            # Calculate RMSE with aligned arrays
            rmse = np.sqrt(mean_squared_error(y_true_prices, y_pred_prices))

            # Continue with plotting using the aligned arrays
            plt.figure(figsize=(14, 7))
            plt.plot(test_timestamps_aligned, y_true_prices, 'b-', label='Actual Price')
            plt.plot(test_timestamps_aligned, y_pred_prices, 'r-', label='Predicted Price')
            plt.title(f'Actual vs Predicted Prices for {ticker} (RMSE: {rmse:.4f})')
            plt.xlabel('Time')
            plt.ylabel('Price')
            plt.legend()
            plt.grid(True)
            plt.tight_layout()
            plt.show()

            # Pass aligned arrays to the prediction quality visualization
            logger.info("Generating comprehensive prediction quality visualization...")
            plot_model_prediction_quality(y_true_prices, y_pred_prices, test_timestamps_aligned, ticker)

            # Generate trading signals using the improved predictions
            all_signals = []

            print(f"Generating signals for {len(test_timestamps_aligned)} timestamps...")

            for idx, (timestamp, actual, predicted) in enumerate(zip(test_timestamps_aligned, y_true_prices, y_pred_prices)):
                # Find the original index in the stock_df that corresponds to this timestamp
                try:
                    orig_idx = stock_df.index.get_loc(timestamp)
                    # Up to 21 rows ending at the current bar
                    market_data_window = stock_df.iloc[max(0, orig_idx-20):orig_idx+1].copy()

                    if len(market_data_window) == 0:
                        continue

                    # Debug information for monitoring
                    if idx % 1000 == 0:
                        print(f"Processing index {idx}: Current price = {actual:.2f}, Predicted price = {predicted:.2f}")

                    # Generate signal
                    signal = monitor.update_ticker(
                        ticker=ticker,
                        current_price=actual,
                        predicted_price=predicted,
                        timestamp=timestamp,
                        market_data=market_data_window
                    )

                    if signal and signal['action']:
                        print(f"Generated signal at {timestamp}: {signal['action']}")
                        all_signals.append(signal)

                except KeyError:
                    # Skip if timestamp is not found in the index
                    continue

            print(f"Generated {len(all_signals)} signals")

            # Perform backtesting
            trades = backtest_trades(all_signals, stock_df)
            print(f"Generated {len(trades)} trades")

            # Calculate performance metrics
            model_evaluator = ModelEvaluator()
            trade_metrics = model_evaluator.calculate_trade_performance_metrics(
                trades=trades,
                initial_capital=10000.0
            )

            # Print performance summary
            print("\nTrading Performance Summary:")
            print(f"Total Trades: {trade_metrics['total_trades']}")
            print(f"Win Rate: {trade_metrics['win_rate']:.2f}%")
            print(f"Total Profit: ${trade_metrics['total_profit']:.2f}")
            print(f"Maximum Drawdown: ${trade_metrics['maximum_drawdown']:.2f}")
            print(f"Sharpe Ratio: {trade_metrics['sharpe_ratio']:.2f}")

            # Generate visualizations
            plot_trading_metrics(trade_metrics, ticker)
            plot_candlestick_analysis(stock_df, signals=all_signals, ticker=ticker)

            # Save model and artifacts
            save_path = f'structured_models/{ticker}/'
            os.makedirs(save_path, exist_ok=True)

            model_filename = f'{save_path}enhanced_lstm_model_{ticker}_{end_date_str}.pth'
            scaler_filename = f'{save_path}enhanced_scaler_{ticker}_{end_date_str}.pkl'
            history_filename = f'{save_path}enhanced_training_history_{ticker}_{end_date_str}.pkl'

            torch.save(model.state_dict(), model_filename)

            # Save price and feature scalers
            with open(scaler_filename, 'wb') as f:
                pickle.dump({
                    'price_scaler': dataset.price_scaler,
                    'feature_scaler': dataset.feature_scaler
                }, f)

            with open(history_filename, 'wb') as f:
                pickle.dump(training_history, f)

            print(f"\nSaved model to {model_filename}")
            print(f"Saved scalers to {scaler_filename}")
            print(f"Saved training history to {history_filename}")

            logger.info(f"Completed processing for {ticker}")

        except Exception as e:
            # Per-ticker boundary: record the failure and move on to the next ticker.
            logger.error(f"Error processing {ticker}: {e}")
            traceback.print_exc()

        finally:
            # Clean up memory
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    logger.info("Completed processing all tickers")

# Run the full pipeline only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
