In [None]:
# Deep Reinforcement Learning Trading System
# Complete Implementation with PPO, A2C, and DQN agents

import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import TimeSeriesSplit
import warnings
warnings.filterwarnings('ignore')

# Deep Learning libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from collections import deque
import random
import json
import os
from datetime import datetime

# Seed every random number generator used in this notebook (standard library,
# NumPy and PyTorch) with the same constant so repeated runs are comparable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Plot appearance: apply the matplotlib style sheet first, then the seaborn
# colour palette (the palette call must come after the style sheet).
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Banner: title line followed by a 60-character rule.
print("Deep Reinforcement Learning Trading System", "=" * 60, sep="\n")

# ============================================================================
# 1. DATA COLLECTION AND PREPROCESSING (From Assignment 1)
# ============================================================================

class DataPipeline:
    """Fetch price history, engineer technical indicators and scale them.

    The attributes are filled progressively:

    * ``fetch_data`` populates ``raw_data`` (symbol -> raw DataFrame).
    * ``process_all_assets`` populates ``processed_data`` (symbol -> dict with
      ``'normalized'``, ``'clean'`` and ``'raw'`` frames) and ``scalers``
      (symbol -> dict with ``'price_scaler'`` and ``'tech_scaler'``).
    """

    # Raw market columns; every other column is treated as a feature.
    PRICE_COLUMNS = ['Open', 'High', 'Low', 'Close', 'Volume']
    # Tickers used when the caller does not supply any.
    DEFAULT_ASSETS = ('AAPL', 'MSFT', 'GOOGL', 'TSLA')
    # Histories with fewer rows than this are rejected by ``fetch_data``.
    MIN_RECORDS = 1500

    def __init__(self, assets=None, period='6y'):
        """Create an empty pipeline.

        Args:
            assets: Iterable of ticker symbols. ``None`` selects
                ``DEFAULT_ASSETS``.
            period: History length passed to ``Ticker.history(period=...)``.
        """
        # Always store a fresh list: a mutable default argument (or the
        # caller's own list) must not be shared between pipeline instances.
        self.assets = list(self.DEFAULT_ASSETS if assets is None else assets)
        self.period = period
        self.raw_data = {}
        self.processed_data = {}
        self.scalers = {}

    def fetch_data(self):
        """Fetch historical data for all assets.

        Symbols whose download fails, is empty, or has fewer than
        ``MIN_RECORDS`` rows are reported on stdout and left out of
        ``raw_data``; one bad symbol does not abort the others.
        """
        print("Fetching historical data...")
        for symbol in self.assets:
            try:
                ticker = yf.Ticker(symbol)
                df = ticker.history(period=self.period)
                if not df.empty and len(df) >= self.MIN_RECORDS:
                    self.raw_data[symbol] = df
                    print(f"✓ {symbol}: {len(df)} records")
                else:
                    print(f"✗ Insufficient data for {symbol}")
            except Exception as e:
                # Deliberate best-effort boundary: report and continue with
                # the remaining symbols.
                print(f"✗ Error fetching {symbol}: {str(e)}")

    def calculate_technical_indicators(self, df):
        """Return a copy of ``df`` with technical-indicator columns appended.

        Args:
            df: DataFrame with ``Open``, ``High``, ``Low``, ``Close`` and
                ``Volume`` columns. It is not modified.

        Returns:
            A new DataFrame containing the input columns followed by the
            engineered columns. Rolling windows leave NaN in the leading rows,
            and any +/-inf produced by a zero denominator (for example a zero
            ATR or a flat high/low range) is converted to NaN so that a later
            ``dropna()`` removes those rows.
        """
        data = df.copy()
        close = data['Close']

        # Moving averages
        data['SMA_20'] = close.rolling(window=20).mean()
        data['EMA_20'] = close.ewm(span=20).mean()
        data['SMA_50'] = close.rolling(window=50).mean()

        # RSI: simple rolling means of gains and losses over 14 periods
        def calculate_rsi(prices, window=14):
            delta = prices.diff()
            gain = (delta.where(delta > 0, 0)).rolling(window=window).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=window).mean()
            rs = gain / loss
            return 100 - (100 / (1 + rs))

        data['RSI'] = calculate_rsi(close)

        # MACD
        ema_12 = close.ewm(span=12).mean()
        ema_26 = close.ewm(span=26).mean()
        data['MACD'] = ema_12 - ema_26
        data['MACD_signal'] = data['MACD'].ewm(span=9).mean()
        data['MACD_hist'] = data['MACD'] - data['MACD_signal']

        # Bollinger Bands (the middle band is the 20-period SMA computed above)
        data['BB_middle'] = data['SMA_20']
        bb_std = close.rolling(window=20).std()
        data['BB_upper'] = data['BB_middle'] + (bb_std * 2)
        data['BB_lower'] = data['BB_middle'] - (bb_std * 2)
        data['BB_width'] = data['BB_upper'] - data['BB_lower']
        data['BB_position'] = (close - data['BB_lower']) / (data['BB_upper'] - data['BB_lower'])

        # ATR: 14-period mean of the true range
        prev_close = close.shift()
        high_low = data['High'] - data['Low']
        high_close = (data['High'] - prev_close).abs()
        low_close = (data['Low'] - prev_close).abs()
        ranges = pd.concat([high_low, high_close, low_close], axis=1)
        true_range = ranges.max(axis=1)
        data['ATR'] = true_range.rolling(14).mean()

        # Stochastic Oscillator
        low_min = data['Low'].rolling(window=14).min()
        high_max = data['High'].rolling(window=14).max()
        data['STOCH_K'] = 100 * (close - low_min) / (high_max - low_min)
        data['STOCH_D'] = data['STOCH_K'].rolling(window=3).mean()

        # Williams %R
        data['Williams_R'] = -100 * (high_max - close) / (high_max - low_min)

        # Volume indicators
        data['Volume_SMA'] = data['Volume'].rolling(window=20).mean()
        data['Volume_ratio'] = data['Volume'] / data['Volume_SMA']

        # Price-based features
        data['Price_change'] = close.pct_change()
        data['High_Low_ratio'] = data['High'] / data['Low']
        data['Close_Open_ratio'] = close / data['Open']

        # Additional DRL-specific features
        data['momentum_5'] = close.pct_change(5)
        data['momentum_10'] = close.pct_change(10)
        data['vol_regime'] = data['ATR'].rolling(20).mean() / data['ATR'].rolling(60).mean()
        data['trend_strength'] = (data['MACD'] - data['MACD_signal']).abs()
        data['mean_reversion'] = (close - data['SMA_20']) / data['ATR']
        data['risk_adj_return'] = data['Price_change'] / data['ATR']
        # Same ratio as Volume_ratio; kept as its own column so the feature
        # set stays unchanged.
        data['relative_volume'] = data['Volume'] / data['Volume'].rolling(20).mean()

        # Divisions by zero above yield +/-inf, which dropna() does not remove.
        # Mark them as missing so they are dropped together with warm-up rows.
        return data.replace([np.inf, -np.inf], np.nan)

    def preprocess_data(self, df):
        """Drop unusable rows and scale the remaining data.

        Price columns (``PRICE_COLUMNS``) are min-max scaled; all other
        columns are standardized.

        Args:
            df: Frame as returned by ``calculate_technical_indicators``.

        Returns:
            Tuple ``(df_normalized, df_clean, price_scaler, tech_scaler)``
            where ``df_clean`` is the unscaled frame after row removal and
            ``df_normalized`` is its scaled counterpart (same index/columns).

        Raises:
            ValueError: If no row survives the NaN/inf filtering.

        NOTE(review): both scalers are fit on every row passed in. A caller
        that splits the result chronologically afterwards leaks statistics of
        the later rows into the earlier ones.
        """
        # Treat +/-inf like NaN: the scalers cannot handle infinite values.
        df_clean = df.replace([np.inf, -np.inf], np.nan).dropna()
        if df_clean.empty:
            raise ValueError("No rows left after removing NaN/inf values")

        # Separate price and technical indicators
        price_cols = list(self.PRICE_COLUMNS)
        tech_cols = [col for col in df_clean.columns if col not in price_cols]

        # Scale features
        price_scaler = MinMaxScaler()
        tech_scaler = StandardScaler()

        df_normalized = df_clean.copy()
        df_normalized[price_cols] = price_scaler.fit_transform(df_clean[price_cols])

        if tech_cols:
            df_normalized[tech_cols] = tech_scaler.fit_transform(df_clean[tech_cols])

        return df_normalized, df_clean, price_scaler, tech_scaler

    def process_all_assets(self):
        """Engineer and scale features for every symbol in ``raw_data``.

        Fills ``processed_data[symbol]`` with the ``'normalized'``, ``'clean'``
        and ``'raw'`` frames and ``scalers[symbol]`` with the fitted scalers.
        """
        print("Processing assets with feature engineering...")

        for symbol, df in self.raw_data.items():
            print(f"Processing {symbol}...")

            # Calculate technical indicators
            engineered_df = self.calculate_technical_indicators(df)

            # Preprocess data
            normalized, clean, p_scaler, t_scaler = self.preprocess_data(engineered_df)

            self.processed_data[symbol] = {
                'normalized': normalized,
                'clean': clean,
                'raw': df
            }

            self.scalers[symbol] = {
                'price_scaler': p_scaler,
                'tech_scaler': t_scaler
            }

            print(f"✓ {symbol}: {len(clean)} samples, {len(clean.columns)} features")

# ============================================================================
# 2. TRADING ENVIRONMENT
# ============================================================================

class TradingEnvironment:
    """Single-asset trading simulator with an all-in / all-out action set.

    Actions: ``0`` = hold, ``1`` = buy as many whole shares as the cash
    balance allows, ``2`` = sell every share held. Trades execute at the
    ``Close`` value of the current row. The per-step reward is the change in
    net worth divided by ``initial_balance``.

    The observation is the row's feature columns (every column except
    Open/High/Low/Close/Volume) followed by three portfolio ratios: cash,
    holdings value and total value, each divided by ``initial_balance``.
    """

    # Columns that carry raw market data and are excluded from the state.
    _MARKET_COLUMNS = ('Open', 'High', 'Low', 'Close', 'Volume')

    def __init__(self, data, initial_balance=10000, transaction_cost=0.001):
        """Build an environment over ``data``.

        Args:
            data: DataFrame with a ``Close`` column plus feature columns.
            initial_balance: Starting cash.
            transaction_cost: Proportional fee applied to every buy and sell.
        """
        self.data = data.reset_index(drop=True)
        self.initial_balance = initial_balance
        self.transaction_cost = transaction_cost

        # State features (excluding OHLCV for state representation)
        self.feature_columns = [col for col in data.columns
                                if col not in self._MARKET_COLUMNS]

        # Action space: 0=Hold, 1=Buy, 2=Sell
        self.action_space = 3
        self.observation_space = len(self.feature_columns) + 3  # +3 for portfolio state

        # Cache prices and features as arrays: per-step DataFrame indexing is
        # by far the slowest part of an episode. NaN features become 0.0, as
        # they would when read row by row. The caches are built once, so later
        # in-place edits of ``self.data`` are not reflected.
        self._close = self.data['Close'].to_numpy(dtype=float)
        self._features = np.nan_to_num(
            self.data[self.feature_columns].to_numpy(dtype=float), nan=0.0
        )

        self._reset_portfolio()

    def _reset_portfolio(self):
        """Put every piece of mutable episode state back to its start value."""
        self.current_step = 0
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_shares_sold = 0
        self.total_sales_value = 0
        self.net_worth = self.initial_balance
        self.max_net_worth = self.initial_balance
        self.trades = []
        # Net worth after every step; needed for drawdown and Sharpe ratio.
        self.net_worth_history = [self.initial_balance]

    def reset(self):
        """Reset environment to initial state and return the first observation."""
        self._reset_portfolio()
        return self._get_observation()

    def _get_observation(self):
        """Return the state vector for ``current_step``.

        Past the end of the data an all-zero vector is returned.
        """
        if self.current_step >= len(self._close):
            return np.zeros(self.observation_space)

        current_price = self._close[self.current_step]
        holdings_value = self.shares_held * current_price

        portfolio_state = np.array([
            self.balance / self.initial_balance,  # Normalized balance
            holdings_value / self.initial_balance,  # Normalized holdings value
            (self.balance + holdings_value) / self.initial_balance  # Normalized total value
        ])
        portfolio_state = np.nan_to_num(portfolio_state, nan=0.0)

        return np.concatenate([self._features[self.current_step], portfolio_state])

    def _buy(self, price):
        """Spend as much cash as possible on whole shares at ``price``."""
        # A zero, negative or non-finite price cannot be traded. Without this
        # guard the floor division below yields inf shares and a NaN balance
        # (min-max scaled prices hit exactly 0 at the series minimum).
        if not (np.isfinite(price) and price > 0):
            return

        shares_to_buy = self.balance // (price * (1 + self.transaction_cost))
        if shares_to_buy <= 0:
            return

        cost = shares_to_buy * price * (1 + self.transaction_cost)
        self.balance -= cost
        self.shares_held += shares_to_buy
        self.trades.append({
            'step': self.current_step,
            'action': 'BUY',
            'shares': shares_to_buy,
            'price': price,
            'cost': cost
        })

    def _sell(self, price):
        """Sell every share held at ``price``."""
        if self.shares_held <= 0:
            return

        revenue = self.shares_held * price * (1 - self.transaction_cost)
        self.balance += revenue
        self.total_sales_value += revenue
        self.total_shares_sold += self.shares_held
        self.trades.append({
            'step': self.current_step,
            'action': 'SELL',
            'shares': self.shares_held,
            'price': price,
            'revenue': revenue
        })
        self.shares_held = 0

    def step(self, action):
        """Execute ``action`` and advance one row.

        Returns:
            Tuple ``(observation, reward, done, info)``. Once the last row
            has been reached the call is a no-op returning
            ``(observation, 0, True, {})``.
        """
        if self.current_step >= len(self._close) - 1:
            return self._get_observation(), 0, True, {}

        current_price = self._close[self.current_step]

        # Execute action (anything other than 1 or 2 is a hold)
        if action == 1:
            self._buy(current_price)
        elif action == 2:
            self._sell(current_price)

        # Move to next step
        self.current_step += 1

        # Reward: change in net worth, valued at the new row's close
        new_net_worth = self.balance + self.shares_held * self._close[self.current_step]
        reward = (new_net_worth - self.net_worth) / self.initial_balance

        # Update tracking variables
        self.net_worth = new_net_worth
        self.max_net_worth = max(self.max_net_worth, new_net_worth)
        self.net_worth_history.append(new_net_worth)

        # Check if done
        done = self.current_step >= len(self._close) - 1

        return self._get_observation(), reward, done, {
            'net_worth': self.net_worth,
            'balance': self.balance,
            'shares_held': self.shares_held
        }

    def _win_rate(self):
        """Fraction of sells whose revenue exceeded the buys that funded them."""
        wins = 0
        closed = 0
        open_cost = 0.0
        for trade in self.trades:
            if trade['action'] == 'BUY':
                # Several buys can precede one sell; accumulate their cost.
                open_cost += trade['cost']
            else:
                closed += 1
                if trade['revenue'] > open_cost:
                    wins += 1
                open_cost = 0.0
        return wins / closed if closed else 0

    def get_performance_metrics(self):
        """Summarize the episode so far.

        Returns:
            Dict with ``total_return``, ``num_trades``, ``win_rate``,
            ``sharpe_ratio`` (mean / std of per-step portfolio returns, not
            annualized), ``max_drawdown`` (largest peak-to-trough loss of net
            worth as a fraction of the peak), ``final_balance``,
            ``final_shares`` and ``net_worth``. The same keys are present
            whether or not any trade happened.
        """
        history = np.asarray(self.net_worth_history, dtype=float)
        total_return = (self.net_worth - self.initial_balance) / self.initial_balance

        # Per-step portfolio returns; steps starting from zero worth are skipped.
        previous = history[:-1]
        usable = previous != 0
        step_returns = np.diff(history)[usable] / previous[usable]
        volatility = step_returns.std() if step_returns.size else 0.0
        sharpe_ratio = float(step_returns.mean() / volatility) if volatility > 0 else 0.0

        # Max drawdown: worst drop below the running peak, not just the final one.
        running_peak = np.maximum.accumulate(history)
        has_peak = running_peak > 0
        drawdowns = np.zeros_like(history)
        drawdowns[has_peak] = (running_peak[has_peak] - history[has_peak]) / running_peak[has_peak]
        max_drawdown = float(drawdowns.max())

        return {
            'total_return': total_return,
            'num_trades': len(self.trades),
            'win_rate': self._win_rate(),
            'sharpe_ratio': sharpe_ratio,
            'max_drawdown': max_drawdown,
            'final_balance': self.balance,
            'final_shares': self.shares_held,
            'net_worth': self.net_worth
        }

# ============================================================================
# 3. DEEP REINFORCEMENT LEARNING AGENTS
# ============================================================================

class DQNAgent:
    """Deep Q-Network agent with a replay buffer and a target network.

    Exploration is epsilon-greedy; ``epsilon`` starts at 1.0 and is multiplied
    by ``epsilon_decay`` after every training batch until it reaches
    ``epsilon_min``.
    """

    def __init__(self, state_size, action_size, learning_rate=0.001, gamma=0.95):
        """Create the online and target networks.

        Args:
            state_size: Length of the observation vector.
            action_size: Number of discrete actions.
            learning_rate: Adam learning rate for the online network.
            gamma: Discount factor applied to the bootstrapped next-state value.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = learning_rate
        self.gamma = gamma

        # Neural Network
        self.q_network = self._build_network()
        self.target_network = self._build_network()
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)

        # The target network only ever produces bootstrap targets, so its
        # Dropout layers must stay disabled.
        self.target_network.eval()
        self.update_target_network()

    def _build_network(self):
        """Build DQN neural network (MLP: state -> one Q-value per action)."""
        model = nn.Sequential(
            nn.Linear(self.state_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, self.action_size)
        )
        return model

    def update_target_network(self):
        """Update target network with main network weights"""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Store experience in replay buffer"""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action with the epsilon-greedy policy.

        Returns:
            The action index as a plain ``int``.
        """
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_size)

        state_tensor = torch.as_tensor(np.asarray(state, dtype=np.float32)).unsqueeze(0)

        # Greedy evaluation must be deterministic: switch Dropout off for the
        # forward pass and restore the previous mode afterwards.
        was_training = self.q_network.training
        self.q_network.eval()
        try:
            with torch.no_grad():
                q_values = self.q_network(state_tensor)
        finally:
            self.q_network.train(was_training)

        return int(q_values.argmax(dim=1).item())

    def replay(self, batch_size=32):
        """Train the online network on one random batch of stored transitions.

        Does nothing while fewer than ``batch_size`` transitions are stored.
        """
        if len(self.memory) < batch_size:
            return

        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack through NumPy first: building a tensor from a list of arrays
        # element by element is very slow.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64)).unsqueeze(1)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32))
        not_done = torch.as_tensor(1.0 - np.asarray(dones, dtype=np.float32))

        # squeeze(1) rather than squeeze(): a batch of one must stay 1-D so it
        # matches the target's shape.
        current_q_values = self.q_network(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(dim=1)[0]
        # Terminal transitions contribute no bootstrapped value.
        target_q_values = rewards + self.gamma * next_q_values * not_done

        loss = F.mse_loss(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

class PPOAgent:
    """Actor-critic agent trained with the clipped PPO surrogate objective.

    The actor outputs a categorical distribution over actions; the critic
    outputs a scalar state value. Each has its own network and optimizer.
    """

    def __init__(self, state_size, action_size, learning_rate=0.0003, gamma=0.99):
        """Create actor, critic and their optimizers.

        Args:
            state_size: Length of the observation vector.
            action_size: Number of discrete actions.
            learning_rate: Adam learning rate used for both networks.
            gamma: Discount factor for the one-step TD target.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.learning_rate = learning_rate
        self.gamma = gamma

        # Actor-Critic networks
        self.actor = self._build_actor()
        self.critic = self._build_critic()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate)

        # PPO hyperparameters
        self.clip_epsilon = 0.2
        self.ppo_epochs = 4
        self.entropy_coef = 0.01

    def _build_actor(self):
        """Build actor network (state -> action probabilities)."""
        return nn.Sequential(
            nn.Linear(self.state_size, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, self.action_size),
            nn.Softmax(dim=-1)
        )

    def _build_critic(self):
        """Build critic network (state -> scalar value)."""
        return nn.Sequential(
            nn.Linear(self.state_size, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    @staticmethod
    def _to_tensor(values, dtype):
        """Convert a tensor, array or (nested) list to a tensor of ``dtype``."""
        if torch.is_tensor(values):
            return values.to(dtype)
        # Go through NumPy: building a tensor from a list of arrays is slow.
        return torch.as_tensor(np.asarray(values), dtype=dtype)

    def act(self, state):
        """Sample an action from the current policy.

        Returns:
            Tuple ``(action, log_prob)`` as a Python ``int`` and ``float``.
        """
        state_tensor = self._to_tensor(state, torch.float32).unsqueeze(0)
        # Only plain numbers are returned, so no autograd graph is needed.
        with torch.no_grad():
            dist = Categorical(self.actor(state_tensor))
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.item(), log_prob.item()

    def evaluate(self, state, action):
        """Evaluate a batch of state-action pairs under the current networks.

        Returns:
            Tuple ``(action_log_probs, state_value, entropy)``; the first and
            last have one entry per row, ``state_value`` keeps the critic's
            trailing dimension of size 1.
        """
        state_tensor = self._to_tensor(state, torch.float32)
        action_tensor = self._to_tensor(action, torch.int64)

        action_probs = self.actor(state_tensor)
        dist = Categorical(action_probs)
        action_log_probs = dist.log_prob(action_tensor)
        entropy = dist.entropy()

        state_value = self.critic(state_tensor)

        return action_log_probs, state_value, entropy

    def update(self, states, actions, rewards, log_probs, values, next_values):
        """Run ``ppo_epochs`` clipped-PPO updates on one rollout.

        All arguments are equally long sequences, one entry per transition.
        ``log_probs`` and ``values`` are the numbers recorded when the actions
        were taken; ``next_values`` are the value estimates of the successor
        states. An empty rollout is a no-op.
        """
        states = self._to_tensor(states, torch.float32)
        if states.shape[0] == 0:
            return

        actions = self._to_tensor(actions, torch.int64)
        rewards = self._to_tensor(rewards, torch.float32)
        old_log_probs = self._to_tensor(log_probs, torch.float32)
        values = self._to_tensor(values, torch.float32)
        next_values = self._to_tensor(next_values, torch.float32)

        # One-step TD target and advantage
        td_target = rewards + self.gamma * next_values
        advantages = td_target - values
        # Normalizing needs at least two samples: the standard deviation of a
        # single value is NaN and would turn every weight into NaN.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # PPO update
        for _ in range(self.ppo_epochs):
            log_probs_new, values_new, entropy = self.evaluate(states, actions)

            # Actor loss: clipped surrogate plus entropy bonus
            ratio = torch.exp(log_probs_new - old_log_probs)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coef * entropy.mean()

            # Critic loss; squeeze(-1) keeps a batch of one 1-D like the target
            critic_loss = F.mse_loss(values_new.squeeze(-1), td_target)

            # Actor and critic have separate graphs, so neither backward pass
            # needs retain_graph.
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

# ============================================================================
# 4. TRAINING FRAMEWORK
# ============================================================================

class TradingTrainer:
    """Train and evaluate one agent type ('DQN' or 'PPO') per symbol.

    Results of every ``train_agent`` call are kept in ``self.results`` keyed
    by symbol.
    """

    # Market columns whose real (unscaled) values the environment trades on.
    _PRICE_COLUMNS = ('Open', 'High', 'Low', 'Close', 'Volume')

    def __init__(self, data_pipeline, agent_type='DQN'):
        """Store the pipeline and the agent type to train.

        Args:
            data_pipeline: Object exposing ``processed_data[symbol]`` dicts
                with a ``'normalized'`` frame and, ideally, a ``'clean'`` one.
            agent_type: ``'DQN'`` or ``'PPO'``.
        """
        self.data_pipeline = data_pipeline
        self.agent_type = agent_type
        self.results = {}

    def _build_env_frame(self, symbol):
        """Return the frame the environments run on for ``symbol``.

        Feature columns come from the normalized frame (network inputs), but
        the market columns are restored from the unscaled ``'clean'`` frame.
        Trading on min-max scaled prices would execute orders at values in
        [0, 1] (including exactly 0), which makes share counts and returns
        meaningless.
        """
        bundle = self.data_pipeline.processed_data[symbol]
        frame = bundle['normalized'].copy()

        clean = bundle.get('clean')
        if clean is not None:
            price_cols = [col for col in self._PRICE_COLUMNS
                          if col in frame.columns and col in clean.columns]
            if price_cols:
                frame[price_cols] = clean[price_cols]
        return frame

    def _create_agent(self, env):
        """Instantiate the configured agent for ``env``'s state/action sizes."""
        if self.agent_type == 'DQN':
            return DQNAgent(env.observation_space, env.action_space)
        if self.agent_type == 'PPO':
            return PPOAgent(env.observation_space, env.action_space)
        raise ValueError(f"Unknown agent type: {self.agent_type}")

    def _run_dqn_episode(self, env, agent):
        """Play one episode, training from replay after every step."""
        state = env.reset()
        episode_reward = 0
        done = False

        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            if len(agent.memory) > 32:
                agent.replay()

        return episode_reward

    def _run_ppo_episode(self, env, agent):
        """Play one episode on-policy, then update the agent on the rollout."""
        state = env.reset()
        episode_reward = 0
        done = False
        states, actions, rewards, log_probs, values = [], [], [], [], []

        while not done:
            action, log_prob = agent.act(state)
            next_state, reward, done, _ = env.step(action)

            # Only the scalar value is stored, so skip building a graph.
            with torch.no_grad():
                value = agent.critic(torch.FloatTensor(state).unsqueeze(0)).item()

            # Store transition
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            log_probs.append(log_prob)
            values.append(value)

            state = next_state
            episode_reward += reward

        if states:
            next_values = values[1:] + [0]  # Bootstrap with 0 for terminal state
            agent.update(states, actions, rewards, log_probs, values, next_values)

        return episode_reward

    def _evaluate(self, agent, test_data):
        """Run the trained agent once over ``test_data``; return the environment."""
        test_env = TradingEnvironment(test_data)
        test_state = test_env.reset()
        test_done = False

        if self.agent_type == 'DQN':
            # Use trained policy without exploration
            agent.epsilon = 0

        while not test_done:
            if self.agent_type == 'DQN':
                test_action = agent.act(test_state)
            else:  # PPO samples from its learned action distribution
                test_action, _ = agent.act(test_state)

            test_state, _, test_done, _ = test_env.step(test_action)

        return test_env

    def train_agent(self, symbol, episodes=1000):
        """Train the configured agent on ``symbol`` and test it out of sample.

        The first 80% of the rows are used for training, the remainder for a
        single evaluation run.

        Args:
            symbol: Key into ``data_pipeline.processed_data``.
            episodes: Number of training episodes.

        Returns:
            Tuple ``(agent, test_performance)`` where ``test_performance`` is
            the metrics dict of the evaluation environment.

        Raises:
            ValueError: If ``agent_type`` is neither ``'DQN'`` nor ``'PPO'``.

        NOTE(review): the normalized features were scaled by the pipeline
        before this chronological split, so statistics of the test period
        influence the training inputs.
        """
        print(f"\nTraining {self.agent_type} agent on {symbol}...")

        # Prepare data
        data = self._build_env_frame(symbol)

        # Split data chronologically
        train_size = int(len(data) * 0.8)
        train_data = data.iloc[:train_size]
        test_data = data.iloc[train_size:]

        # Create environment and agent
        env = TradingEnvironment(train_data)
        agent = self._create_agent(env)
        run_episode = self._run_dqn_episode if self.agent_type == 'DQN' else self._run_ppo_episode

        # Training loop
        episode_rewards = []
        episode_returns = []

        for episode in range(episodes):
            episode_reward = run_episode(env, agent)

            # Update target network for DQN
            if self.agent_type == 'DQN' and episode % 100 == 0:
                agent.update_target_network()

            # Track progress
            episode_rewards.append(episode_reward)
            performance = env.get_performance_metrics()
            episode_returns.append(performance['total_return'])

            if episode % 100 == 0:
                avg_reward = np.mean(episode_rewards[-100:])
                avg_return = np.mean(episode_returns[-100:])
                print(f"Episode {episode}, Avg Reward: {avg_reward:.4f}, Avg Return: {avg_return:.4f}")

        # Test on unseen data
        print(f"Testing {self.agent_type} agent on {symbol}...")
        test_env = self._evaluate(agent, test_data)

        # Store results
        test_performance = test_env.get_performance_metrics()
        self.results[symbol] = {
            'agent_type': self.agent_type,
            'training_rewards': episode_rewards,
            'training_returns': episode_returns,
            'test_performance': test_performance,
            'test_env': test_env,
            'agent': agent
        }

        print(f"✓ {symbol} - Test Return: {test_performance['total_return']:.4f}")
        return agent, test_performance

# ============================================================================
# 5. MAIN EXECUTION
# ============================================================================

def main():
    """Run the full experiment: fetch data, train every agent, print a summary.

    For each agent type ('DQN', then 'PPO') one trainer is created and used to
    train on every pipeline asset that has processed data, 500 episodes each.

    Returns:
        Tuple ``(pipeline, results)`` where ``results`` maps
        ``"<agent_type>_<symbol>"`` to a dict with the trained ``'agent'``, its
        test ``'performance'`` metrics and the ``'trainer'`` that produced it.
    """
    # Data: download and feature-engineer all assets
    pipeline = DataPipeline()
    pipeline.fetch_data()
    pipeline.process_all_assets()

    results = {}

    for agent_type in ('DQN', 'PPO'):
        print(f"\n{'='*60}")
        print(f"Training {agent_type} agents")
        print('='*60)

        trainer = TradingTrainer(pipeline, agent_type)

        for symbol in pipeline.assets:
            # Assets that failed to download or process are skipped.
            if symbol not in pipeline.processed_data:
                continue

            trained_agent, metrics = trainer.train_agent(symbol, episodes=500)
            results[f"{agent_type}_{symbol}"] = dict(
                agent=trained_agent,
                performance=metrics,
                trainer=trainer,
            )

    # Summary of the out-of-sample metrics
    print(f"\n{'='*60}")
    print("FINAL RESULTS SUMMARY")
    print('='*60)

    for key, entry in results.items():
        # Split on the first underscore only: the prefix is the agent type.
        agent_type, symbol = key.split('_', 1)
        metrics = entry['performance']
        print(f"{agent_type} - {symbol}:")
        print(f"  Total Return: {metrics['total_return']:.4f}")
        print(f"  Sharpe Ratio: {metrics['sharpe_ratio']:.4f}")
        print(f"  Number of Trades: {metrics['num_trades']}")
        print(f"  Max Drawdown: {metrics['max_drawdown']:.4f}")
        print()

    return pipeline, results

# Entry point: run the whole experiment and keep the pipeline and the results
# dict as module-level names so they can be inspected after the run.
if __name__ == "__main__":
    pipeline, results = main()
    print("Training completed! Results stored in 'results' variable.")

Deep Reinforcement Learning Trading System
Fetching historical data...
✓ AAPL: 1508 records
✓ MSFT: 1508 records
✓ GOOGL: 1508 records
✓ TSLA: 1508 records
Processing assets with feature engineering...
Processing AAPL...
✓ AAPL: 1436 samples, 35 features
Processing MSFT...
✓ MSFT: 1436 samples, 35 features
Processing GOOGL...
✓ GOOGL: 1436 samples, 35 features
Processing TSLA...
✓ TSLA: 1436 samples, 35 features

Training DQN agents

Training DQN agent on AAPL...
Episode 0, Avg Reward: 3.4951, Avg Return: 3.4951
Episode 100, Avg Reward: nan, Avg Return: nan
Episode 200, Avg Reward: 2.4484, Avg Return: 2.4484
Episode 300, Avg Reward: nan, Avg Return: nan
Episode 400, Avg Reward: 2.3051, Avg Return: 2.3051
Testing DQN agent on AAPL...
✓ AAPL - Test Return: 0.0000

Training DQN agent on MSFT...
Episode 0, Avg Reward: -0.8279, Avg Return: -0.8279
Episode 100, Avg Reward: nan, Avg Return: nan


In [None]:
import streamlit as st
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import yfinance as yf
from datetime import datetime, timedelta
import json
import pickle
import os

# Import our main trading system
from drl_trading_system import DataPipeline, TradingEnvironment, DQNAgent, PPOAgent, TradingTrainer

# Configure Streamlit page: browser tab title/icon, wide layout and an
# initially expanded sidebar.
st.set_page_config(
    page_title="DRL Trading System Dashboard",
    page_icon="📈",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS: injects the style classes (.main-header, .metric-container,
# .performance-card) used by the HTML snippets rendered further down.
# unsafe_allow_html=True is needed so the <style> tag is emitted as HTML.
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .metric-container {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
    }
    .performance-card {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 1.5rem;
        border-radius: 1rem;
        margin: 1rem 0;
        text-align: center;
    }
</style>
""", unsafe_allow_html=True)

# Initialize session state: create each key only if it is missing so values
# already stored in st.session_state are not overwritten.
if 'pipeline' not in st.session_state:
    st.session_state.pipeline = None
if 'results' not in st.session_state:
    st.session_state.results = {}
if 'trained_agents' not in st.session_state:
    st.session_state.trained_agents = {}

# Main header, styled by the .main-header class defined above
st.markdown('<h1 class="main-header">🚀 Deep Reinforcement Learning Trading System</h1>', unsafe_allow_html=True)

# Sidebar: the selected entry is stored in `page` and compared against these
# exact strings by the page sections below.
st.sidebar.title("Navigation")
page = st.sidebar.selectbox("Choose a page", [
    "Overview",
    "Data Pipeline",
    "Agent Training",
    "Performance Analysis",
    "Live Trading Simulation",
    "Model Comparison"
])

# Data loading functions
@st.cache_data
def load_data():
    """Build a fully processed ``DataPipeline``.

    Downloads the price history and runs feature engineering for every
    default asset. The function is wrapped in ``st.cache_data`` so the
    result can be reused instead of being recomputed on every call.

    Returns:
        The populated ``DataPipeline`` instance.
    """
    data_pipeline = DataPipeline()
    # Order matters: processing works on what fetch_data stored.
    data_pipeline.fetch_data()
    data_pipeline.process_all_assets()
    return data_pipeline

def save_results(results, filename="trading_results.pkl"):
    """Pickle ``results`` into ``filename``, replacing any existing file.

    Args:
        results: Any picklable object.
        filename: Destination path; opened in binary write mode.
    """
    with open(filename, 'wb') as handle:
        pickle.Pickler(handle).dump(results)

def load_results(filename="trading_results.pkl"):
    """Load previously pickled results from ``filename``.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The unpickled object, or an empty dict when the file does not exist
        or does not contain a complete pickle (for example an empty or
        truncated file left behind by an interrupted save).
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at files written by this application.
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        return {}
    except (EOFError, pickle.UnpicklingError):
        # Unreadable cache file: treat it the same as "nothing saved yet".
        return {}

# ============================================================================
# OVERVIEW PAGE
# ============================================================================

if page == "Overview":
    st.header("📊 System Overview")

    # One HTML card per column, listed in left-to-right order.
    overview_cards = [
        """
        <div class="performance-card">
            <h3>🎯 Objective</h3>
            <p>Develop and compare DRL agents for automated stock trading using technical indicators and market data.</p>
        </div>
        """,
        """
        <div class="performance-card">
            <h3>🤖 Agents</h3>
            <p>DQN (Deep Q-Network) and PPO (Proximal Policy Optimization) agents with neural network policies.</p>
        </div>
        """,
        """
        <div class="performance-card">
            <h3>📈 Assets</h3>
            <p>AAPL, MSFT, GOOGL, TSLA with 6 years of historical data and 25+ technical indicators.</p>
        </div>
        """,
    ]
    for card_column, card_html in zip(st.columns(3), overview_cards):
        with card_column:
            st.markdown(card_html, unsafe_allow_html=True)

    st.subheader("🔧 System Architecture")

    # Architecture diagram (simplified): three labelled boxes joined by arrows.
    # Each entry: (x0, x1, label x, outline colour, fill colour, label text).
    stage_boxes = [
        (0, 2, 1, "blue", "lightblue", "Data Pipeline<br>Technical Indicators"),
        (3, 5, 4, "green", "lightgreen", "Trading Environment<br>State/Action/Reward"),
        (6, 8, 7, "red", "lightcoral", "DRL Agents<br>DQN/PPO"),
    ]

    fig = go.Figure()
    for left, right, label_x, outline, fill, label in stage_boxes:
        fig.add_shape(
            type="rect", x0=left, y0=0, x1=right, y1=1,
            line=dict(color=outline), fillcolor=fill
        )
        fig.add_annotation(x=label_x, y=0.5, text=label, showarrow=False)

    # Arrows sit in the gaps between consecutive boxes.
    for arrow_x in (2.5, 5.5):
        fig.add_annotation(x=arrow_x, y=0.5, text="→", showarrow=False, font=dict(size=20))

    hidden_axis = dict(showgrid=False, showticklabels=False)
    fig.update_layout(
        title="DRL Trading System Architecture",
        xaxis=dict(range=[-0.5, 8.5], **hidden_axis),
        yaxis=dict(range=[-0.5, 1.5], **hidden_axis),
        height=200
    )

    st.plotly_chart(fig, use_container_width=True)

    st.subheader("📋 Key Features")

    key_features = [
        "🔍 Comprehensive technical indicator calculation (RSI, MACD, Bollinger Bands, etc.)",
        "🧠 Deep Q-Network (DQN) and Proximal Policy Optimization (PPO) agents",
        "🎯 Multi-asset trading with risk management",
        "📊 Real-time performance monitoring and visualization",
        "🔄 Backtesting and forward testing capabilities",
        "💾 Model persistence and result caching"
    ]

    # One st.write call per feature so each renders as its own line.
    for feature_line in key_features:
        st.write(feature_line)

# ============================================================================
# DATA PIPELINE PAGE
# ============================================================================

elif page == "Data Pipeline":
    st.header("📈 Data Pipeline & Feature Engineering")

    # Load data button
    if st.button("🔄 Load Fresh Data"):
        with st.spinner("Loading data..."):
            st.session_state.pipeline = load_data()
        st.success("Data loaded successfully!")

    # Check if data is loaded
    if st.session_state.pipeline is None:
        st.warning("Please load data first using the button above.")
        st.stop()

    pipeline = st.session_state.pipeline

    # Data overview
    st.subheader("📊 Data Overview")

    if pipeline.processed_data:
        # Create summary table
        summary_data = []
        for symbol in pipeline.assets:
            if symbol in pipeline.processed_data:
                data = pipeline.processed_data[symbol]['clean']
                summary_data.append({
                    'Symbol': symbol,
                    'Records': len(data),
                    'Features': len(data.columns),
                    'Start Date': data.index[0].strftime('%Y-%m-%d'),
                    'End Date': data.index[-1].strftime('%Y-%m-%d'),
                    'Latest Price': f"${data['Close'].iloc[-1]:.2f}"
                })

        summary_df = pd.DataFrame(summary_data)
        st.dataframe(summary_df, use_container_width=True)

        # Feature categories
        st.subheader("🔧 Feature Categories")

        col1, col2 = st.columns(2)

        with col1:
            st.markdown("""
            **Trend Indicators:**
            - SMA (20, 50)
            - EMA (20)
            - MACD & Signal

            **Momentum Indicators:**
            - RSI (14)
            - Stochastic (K, D)
            - Williams %R
            - ROC
            """)

        with col2:
            st.markdown("""
            **Volatility Indicators:**
            - ATR (14)
            - Bollinger Bands
            - Volatility Regime

            **Volume Indicators:**
            - Volume SMA
            - Volume Ratio
            - Relative Volume
            """)

        # Interactive chart
        st.subheader("📊 Interactive Price Chart")

        selected_symbol = st.selectbox("Select Symbol", pipeline.assets)

        if selected_symbol in pipeline.processed_data:
            data = pipeline.processed_data[selected_symbol]['clean']

            # Create candlestick chart
            fig = make_subplots(
                rows=3, cols=1,
                subplot_titles=('Price & Moving Averages', 'RSI', 'MACD'),
                vertical_spacing=0.1,
                row_heights=[0.6, 0.2, 0.2]
            )

            # Candlestick chart
            fig.add_trace(
                go.Candlestick(
                    x=data.index,
                    open=data['Open'],
                    high=data['High'],
                    low=data['Low'],
                    close=data['Close'],
                    name='Price'
                ),
                row=1, col=1
            )

            # Moving averages
            fig.add_trace(
                go.Scatter(x=data.index, y=data['SMA_20'], name='SMA 20', line=dict(color='orange')),
                row=1, col=1
            )
            fig.add_trace(
                go.Scatter(x=data.index, y=data['EMA_20'], name='EMA 20', line=dict(color='red')),
                row=1, col=1
            )

            # RSI
            fig.add_trace(
                go.Scatter(x=data.index, y=data['RSI'], name='RSI', line=dict(color='purple')),
                row=2, col=1
            )
            fig.add_hline(y=70, line_dash="dash", line_color="red", row=2, col=1)
            fig.add_hline(y=30, line_dash="dash", line_color="green", row=2, col=1)

            # MACD
            fig.add_trace(
                go.Scatter(x=data.index, y=data['MACD'], name='MACD', line=dict(color='blue')),
                row=3, col=1
            )
            fig.add_trace(
                go.Scatter(x=data.index, y=data['MACD_signal'], name='Signal', line=dict(color='red')),
                row=3, col=1
            )

            fig.update_layout(
                title=f"{selected_symbol} Technical Analysis",
                height=800,
                showlegend=True
            )

            st.plotly_chart(fig, use_container_width=True)

            # Correlation matrix
            st.subheader("🔗 Feature Correlation Matrix")

            # Select key features for correlation
            corr_features = ['Close', 'RSI', 'MACD', 'ATR', 'STOCH_K', 'Volume_ratio', 'BB_position']
            corr_data = data[corr_features].corr()

            fig_corr = px.imshow(
                corr_data,
                labels=dict(x="Features", y="Features", color="Correlation"),
                x=corr_features,
                y=corr_features,
                color_continuous_scale='RdBu',
                aspect="auto"
            )
            fig_corr.update_layout(title="Feature Correlation Matrix")
            st.plotly_chart(fig_corr, use_container_width=True)

# ============================================================================
# AGENT TRAINING PAGE
# ============================================================================

elif page == "Agent Training":
    st.header("🤖 Agent Training & Configuration")

    # Check if data is loaded
    if st.session_state.pipeline is None:
        st.warning("Please load data first from the Data Pipeline page.")
        st.stop()

    pipeline = st.session_state.pipeline

    # Training configuration
    st.subheader("⚙️ Training Configuration")

    col1, col2 = st.columns(2)

    with col1:
        agent_type = st.selectbox("Select Agent Type", ["DQN", "PPO"])
        selected_symbols = st.multiselect("Select Assets", pipeline.assets, default=pipeline.assets[:2])
        episodes = st.slider("Training Episodes", 100, 2000, 500)

    with col2:
        learning_rate = st.number_input("Learning Rate", 0.0001, 0.01, 0.001, format="%.4f")
        batch_size = st.number_input("Batch Size", 16, 128, 32)

        if agent_type == "DQN":
            epsilon_decay = st.number_input("Epsilon Decay", 0.990, 0.999, 0.995, format="%.3f")
        else:  # PPO
            clip_epsilon = st.number_input("Clip Epsilon", 0.1, 0.3, 0.2, format="%.1f")

    # Training button
    if st.button("🚀 Start Training"):
        if not selected_symbols:
            st.error("Please select at least one asset.")
            st.stop()

        progress_bar = st.progress(0)
        status_text = st.empty()

        training_results = {}

        for i, symbol in enumerate(selected_symbols):
            status_text.text(f"Training {agent_type} agent on {symbol}...")

            # Create trainer
            trainer = TradingTrainer(pipeline, agent_type)

            # Train agent
            with st.spinner(f"Training {agent_type} on {symbol}..."):
                agent, performance = trainer.train_agent(symbol, episodes)

                training_results[f"{agent_type}_{symbol}"] = {
                    'agent': agent,
                    'performance': performance,
                    'trainer': trainer
                }

            progress_bar.progress((i + 1) / len(selected_symbols))

        # Store results
        st.session_state.results.update(training_results)
        st.session_state.trained_agents.update(training_results)

        # Save results
        save_results(st.session_state.results)

        status_text.text("Training completed!")
        st.success(f"Successfully trained {agent_type} agents on {len(selected_symbols)} assets!")

        # Display training summary
        st.subheader("📊 Training Summary")

        summary_data = []
        for key, result in training_results.items():
            agent_type_name, symbol = key.split('_', 1)
            performance = result['performance']
            summary_data.append({
                'Agent': agent_type_name,
                'Symbol': symbol,
                'Total Return': f"{performance['total_return']:.2%}",
                'Sharpe Ratio': f"{performance['sharpe_ratio']:.3f}",
                'Trades': performance['num_trades'],
                'Max Drawdown': f"{performance['max_drawdown']:.2%}"
            })

        summary_df = pd.DataFrame(summary_data)
        st.dataframe(summary_df, use_container_width=True)

    # Display existing results
    if st.session_state.results:
        st.subheader("🗂️ Existing Training Results")

        existing_results = []
        for key, result in st.session_state.results.items():
            try:
                agent_type_name, symbol = key.split('_', 1)
                performance = result['performance']
                existing_results.append({
                    'Agent': agent_type_name,
                    'Symbol': symbol,
                    'Total Return': f"{performance['total_return']:.2%}",
                    'Sharpe Ratio': f"{performance['sharpe_ratio']:.3f}",
                    'Trades': performance['num_trades'],
                    'Max Drawdown': f"{performance['max_drawdown']:.2%}"
                })
            except:
                continue

        if existing_results:
            existing_df = pd.DataFrame(existing_results)
            st.dataframe(existing_df, use_container_width=True)

            # Clear results button
            if st.button("🗑️ Clear All Results"):
                st.session_state.results = {}
                st.session_state.trained_agents = {}
                st.success("All results cleared!")

# ============================================================================
# PERFORMANCE ANALYSIS PAGE
# ============================================================================

elif page == "Performance Analysis":
    st.header("📊 Performance Analysis")

    # Check if results exist
    if not st.session_state.results:
        st.warning("No trained agents found. Please train some agents first.")
        st.stop()

    # Results selector
    result_keys = list(st.session_state.results.keys())
    selected_result = st.selectbox("Select Trained Agent", result_keys)

    if selected_result:
        result = st.session_state.results[selected_result]
        performance = result['performance']
        trainer = result['trainer']

        # Performance metrics
        st.subheader("📈 Performance Metrics")

        col1, col2, col3, col4 = st.columns(4)

        with col1:
            st.metric("Total Return", f"{performance['total_return']:.2%}")

        with col2:
            st.metric("Sharpe Ratio", f"{performance['sharpe_ratio']:.3f}")

        with col3:
            st.metric("Number of Trades", performance['num_trades'])

        with col4:
            st.metric("Max Drawdown", f"{performance['max_drawdown']:.2%}")

        # Additional metrics
        col5, col6, col7, col8 = st.columns(4)

        with col5:
            st.metric("Final Balance", f"${performance['final_balance']:.2f}")

        with col6:
            st.metric("Final Shares", performance['final_shares'])

        with col7:
            st.metric("Net Worth", f"${performance['net_worth']:.2f}")

        with col8:
            roi = (performance['net_worth'] - 10000) / 10000
            st.metric("ROI", f"{roi:.2%}")

        # Training progress charts
        if selected_result in trainer.results:
            training_data = trainer.results[selected_result.split('_', 1)[1]]

            st.subheader("📊 Training Progress")

            col1, col2 = st.columns(2)

            with col1:
                # Training rewards
                fig_rewards = go.Figure()
                fig_rewards.add_trace(go.Scatter(
                    y=training_data['training_rewards'],
                    mode='lines',
                    name='Episode Rewards',
                    line=dict(color='blue')
                ))

                # Add moving average
                window = 50
                if len(training_data['training_rewards']) > window:
                    ma_rewards = np.convolve(training_data['training_rewards'],
                                           np.ones(window)/window, mode='valid')
                    fig_rewards.add_trace(go.Scatter(
                        y=ma_rewards,
                        mode='lines',
                        name=f'MA({window})',
                        line=dict(color='red')
                    ))

                fig_rewards.update_layout(
                    title="Training Rewards",
                    xaxis_title="Episode",
                    yaxis_title="Reward"
                )
                st.plotly_chart(fig_rewards, use_container_width=True)

            with col2:
                # Training returns
                fig_returns = go.Figure()
                fig_returns.add_trace(go.Scatter(
                    y=training_data['training_returns'],
                    mode='lines',
                    name='Episode Returns',
                    line=dict(color='green')
                ))

                # Add moving average
                if len(training_data['training_returns']) > window:
                    ma_returns = np.convolve(training_data['training_returns'],
                                           np.ones(window)/window, mode='valid')
                    fig_returns.add_trace(go.Scatter(
                        y=ma_returns,
                        mode='lines',
                        name=f'MA({window})',
                        line=dict(color='orange')
                    ))

                fig_returns.update_layout(
                    title="Training Returns",
                    xaxis_title="Episode",
                    yaxis_title="Return"
                )
                st.plotly_chart(fig_returns, use_container_width=True)

        # Trading actions visualization
        st.subheader("🎯 Trading Actions")

        if 'test_env' in training_data:
            test_env = training_data['test_env']

            if test_env.trades:
                # Create trades DataFrame
                trades_df = pd.DataFrame(test_env.trades)

                # Trading actions pie chart
                action_counts = trades_df['action'].value_counts()

                fig_actions = go.Figure(data=[go.Pie(
                    labels=action_counts.index,
                    values=action_counts.values,
                    hole=0.3
                )])

                fig_actions.update_layout(
                    title="Trading Actions Distribution",
                    showlegend=True
                )
                st.plotly_chart(fig_actions, use_container_width=True)

                # Trades table
                st.subheader("📋 Trade History")
                st.dataframe(trades_df, use_container_width=True)

        # Comparison with buy-and-hold
        st.subheader("🆚 Buy-and-Hold Comparison")

        agent_type, symbol = selected_result.split('_', 1)

        # Get original data
        if st.session_state.pipeline and symbol in st.session_state.pipeline.processed_data:
            original_data = st.session_state.pipeline.processed_data[symbol]['clean']

            # Calculate buy-and-hold return
            train_size = int(len(original_data) * 0.8)
            test_data = original_data[train_size:]

            if len(test_data) > 0:
                buy_hold_return = (test_data['Close'].iloc[-1] - test_data['Close'].iloc[0]) / test_data['Close'].iloc[0]

                comparison_data = {
                    'Strategy': ['DRL Agent', 'Buy & Hold'],
                    'Return': [performance['total_return'], buy_hold_return],
                    'Sharpe': [performance['sharpe_ratio'], 0]  # Simplified
                }

                comparison_df = pd.DataFrame(comparison_data)

                fig_comparison = px.bar(
                    comparison_df,
                    x='Strategy',
                    y='Return',
                    title='DRL Agent vs Buy & Hold Performance',
                    color='Strategy'
                )

                st.plotly_chart(fig_comparison, use_container_width=True)

                # Performance summary
                if performance['total_return'] > buy_hold_return:
                    st.success(f"🎉 DRL Agent outperformed Buy & Hold by {(performance['total_return'] - buy_hold_return):.2%}")
                else:
                    st.info(f"📊 Buy & Hold outperformed DRL Agent by {(buy_hold_return - performance['total_return']):.2%}")

# ============================================================================
# LIVE TRADING SIMULATION PAGE
# ============================================================================

elif page == "Live Trading Simulation":
    st.header("🔴 Live Trading Simulation")

    # Check if agents are trained
    if not st.session_state.trained_agents:
        st.warning("No trained agents found. Please train some agents first.")
        st.stop()

    # Agent selector
    agent_keys = list(st.session_state.trained_agents.keys())
    selected_agent = st.selectbox("Select Trained Agent", agent_keys)

    if selected_agent:
        st.subheader("🎮 Simulation Controls")

        col1, col2, col3 = st.columns(3)

        with col1:
            initial_balance = st.number_input("Initial Balance ($)", 1000, 100000, 10000)

        with col2:
            transaction_cost = st.number_input("Transaction Cost (%)", 0.0, 1.0, 0.1, format="%.2f") / 100

        with col3:
            simulation_days = st.slider("Simulation Days", 30, 365, 90)

        # Start simulation button
        if st.button("🚀 Start Live Simulation"):
            agent_type, symbol = selected_agent.split('_', 1)

            # Get latest data
            try:
                with st.spinner("Fetching latest market data..."):
                    end_date = datetime.now()
                    start_date = end_date - timedelta(days=simulation_days + 100)  # Extra days for indicators

                    ticker = yf.Ticker(symbol)
                    latest_data = ticker.history(start=start_date, end=end_date)

                    if len(latest_data) > 100:
                        # Apply feature engineering
                        pipeline = st.session_state.pipeline
                        processed_latest = pipeline.calculate_technical_indicators(latest_data)

                        # Take last simulation_days for actual simulation
                        sim_data = processed_latest[-simulation_days:].copy()

                        # Normalize using stored scalers
                        if symbol in pipeline.scalers:
                            price_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
                            tech_cols = [col for col in sim_data.columns if col not in price_cols]

                            # Note: In production, you'd want to update scalers with recent data
                            # For demo, we'll use the stored scalers
                            sim_data_norm = sim_data.copy()

                            # Create simulation environment
                            sim_env = TradingEnvironment(sim_data_norm, initial_balance, transaction_cost)

                            # Run simulation
                            agent_result = st.session_state.trained_agents[selected_agent]
                            agent = agent_result['agent']

                            # Simulate trading
                            state = sim_env.reset()
                            done = False

                            portfolio_values = [initial_balance]
                            actions_taken = []

                            while not done:
                                if agent_type == 'DQN':
                                    agent.epsilon = 0  # No exploration in simulation
                                    action = agent.act(state)
                                else:  # PPO
                                    action, _ = agent.act(state)

                                actions_taken.append(action)
                                state, reward, done, info = sim_env.step(action)
                                portfolio_values.append(info['net_worth'])

                            # Display results
                            final_performance = sim_env.get_performance_metrics()

                            st.subheader("📊 Simulation Results")

                            col1, col2, col3, col4 = st.columns(4)

                            with col1:
                                st.metric("Final Return", f"{final_performance['total_return']:.2%}")

                            with col2:
                                st.metric("Total Trades", final_performance['num_trades'])

                            with col3:
                                st.metric("Final Balance", f"${final_performance['final_balance']:.2f}")

                            with col4:
                                st.metric("Net Worth", f"${final_performance['net_worth']:.2f}")

                            # Portfolio value chart
                            fig_portfolio = go.Figure()

                            # Portfolio value
                            fig_portfolio.add_trace(go.Scatter(
                                x=list(range(len(portfolio_values))),
                                y=portfolio_values,
                                mode='lines',
                                name='Portfolio Value',
                                line=dict(color='blue', width=2)
                            ))

                            # Buy and hold comparison
                            buy_hold_values = [initial_balance * (sim_data['Close'].iloc[i] / sim_data['Close'].iloc[0]) for i in range(len(sim_data))]
                            buy_hold_values = [initial_balance] + buy_hold_values

                            fig_portfolio.add_trace(go.Scatter(
                                x=list(range(len(buy_hold_values))),
                                y=buy_hold_values,
                                mode='lines',
                                name='Buy & Hold',
                                line=dict(color='red', width=2, dash='dash')
                            ))

                            fig_portfolio.update_layout(
                                title=f"Portfolio Performance - {symbol}",
                                xaxis_title="Days",
                                yaxis_title="Portfolio Value ($)",
                                hovermode='x unified'
                            )

                            st.plotly_chart(fig_portfolio, use_container_width=True)

                            # Action distribution
                            action_names = {0: 'Hold', 1: 'Buy', 2: 'Sell'}
                            action_counts = pd.Series(actions_taken).value_counts()
                            action_labels = [action_names.get(i, f'Action {i}') for i in action_counts.index]

                            fig_actions = go.Figure(data=[go.Pie(
                                labels=action_labels,
                                values=action_counts.values,
                                hole=0.3
                            )])

                            fig_actions.update_layout(title="Trading Actions Distribution")
                            st.plotly_chart(fig_actions, use_container_width=True)

                            # Trading log
                            if sim_env.trades:
                                st.subheader("📝 Trading Log")
                                trades_df = pd.DataFrame(sim_env.trades)
                                trades_df['date'] = sim_data.index[trades_df['step']].strftime('%Y-%m-%d')
                                st.dataframe(trades_df[['date', 'action', 'shares', 'price']], use_container_width=True)

                    else:
                        st.error("Insufficient recent data for simulation.")

            except Exception as e:
                st.error(f"Error in simulation: {str(e)}")

# ============================================================================
# MODEL COMPARISON PAGE
# ============================================================================
elif page == "Model Comparison":
    st.header("🔍 Model Comparison")

    # Check if results exist
    if not st.session_state.results:
        st.warning("No trained agents found. Please train some agents first.")
        st.stop()

    # Group results by agent type and symbol
    dqn_results = {}
    ppo_results = {}

    for key, result in st.session_state.results.items():
        try:
            agent_type, symbol = key.split('_', 1)
            if agent_type == 'DQN':
                dqn_results[symbol] = result['performance']
            elif agent_type == 'PPO':
                ppo_results[symbol] = result['performance']
        except:
            continue

    # Create comparison data
    comparison_data = []

    # Get common symbols
    common_symbols = set(dqn_results.keys()) & set(ppo_results.keys())

    if not common_symbols:
        st.warning("No common symbols found for comparison. Please train both DQN and PPO agents on the same symbols.")
        st.stop()

    for symbol in common_symbols:
        dqn_perf = dqn_results[symbol]
        ppo_perf = ppo_results[symbol]

        comparison_data.extend([
            {
                'Agent': 'DQN',
                'Symbol': symbol,
                'Return': dqn_perf['total_return'],
                'Sharpe': dqn_perf['sharpe_ratio'],
                'Trades': dqn_perf['num_trades'],
                'Max_Drawdown': dqn_perf['max_drawdown']
            },
            {
                'Agent': 'PPO',
                'Symbol': symbol,
                'Return': ppo_perf['total_return'],
                'Sharpe': ppo_perf['sharpe_ratio'],
                'Trades': ppo_perf['num_trades'],
                'Max_Drawdown': ppo_perf['max_drawdown']
            }
        ])

    # Create DataFrame
    comparison_df = pd.DataFrame(comparison_data)

    # Plot comparison by return
    st.subheader("📊 Return Comparison")
    fig_return = px.bar(
        comparison_df,
        x="Symbol",
        y="Return",
        color="Agent",
        barmode="group",
        text=comparison_df["Return"].apply(lambda x: f"{x:.2%}"),
        title="DQN vs PPO: Total Return by Symbol"
    )
    fig_return.update_layout(yaxis_tickformat='.0%', xaxis_title="Asset", yaxis_title="Total Return")
    st.plotly_chart(fig_return, use_container_width=True)

    # Plot comparison by Sharpe Ratio
    st.subheader("📈 Sharpe Ratio Comparison")
    fig_sharpe = px.bar(
        comparison_df,
        x="Symbol",
        y="Sharpe",
        color="Agent",
        barmode="group",
        text=comparison_df["Sharpe"].round(2),
        title="DQN vs PPO: Sharpe Ratio by Symbol"
    )
    fig_sharpe.update_layout(xaxis_title="Asset", yaxis_title="Sharpe Ratio")
    st.plotly_chart(fig_sharpe, use_container_width=True)

    # Summary statistics
    st.subheader("📊 Summary Statistics")

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("**DQN Performance:**")
        dqn_avg_return = comparison_df[comparison_df['Agent'] == 'DQN']['Return'].mean()
        dqn_avg_sharpe = comparison_df[comparison_df['Agent'] == 'DQN']['Sharpe'].mean()
        st.write(f"Average Return: {dqn_avg_return:.2%}")
        st.write(f"Average Sharpe: {dqn_avg_sharpe:.3f}")

    with col2:
        st.markdown("**PPO Performance:**")
        ppo_avg_return = comparison_df[comparison_df['Agent'] == 'PPO']['Return'].mean()
        ppo_avg_sharpe = comparison_df[comparison_df['Agent'] == 'PPO']['Sharpe'].mean()
        st.write(f"Average Return: {ppo_avg_return:.2%}")
        st.write(f"Average Sharpe: {ppo_avg_sharpe:.3f}")

    # Tabular view
    st.subheader("📋 Detailed Comparison Table")
    display_df = comparison_df.copy()
    display_df["Return"] = display_df["Return"].apply(lambda x: f"{x:.2%}")
    display_df["Sharpe"] = display_df["Sharpe"].round(3)
    display_df["Max_Drawdown"] = display_df["Max_Drawdown"].apply(lambda x: f"{x:.2%}")

    st.dataframe(display_df, use_container_width=True)