# Rainbow DQN Implementation with NoisyNets for Real Time Stock Prediction

In [1]:
from collections import deque
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import os

Load Data

In [3]:
def collect_stock_data(data_file_path="data/train_data.csv"):
    """Load the pre-downloaded price history from a CSV file.

    Args:
        data_file_path: Path of the CSV to read. The file must contain a
            "Date" column, which is parsed and used as the index.

    Returns:
        The loaded ``pd.DataFrame`` indexed by "Date".

    Raises:
        SystemExit: With exit status 1 when the file is missing or cannot
            be read, so shells and schedulers see the failure (the
            site-provided ``exit()`` used previously exited with status 0).
    """
    print("Loading pre-downloaded stock data...")
    try:
        stock_data = pd.read_csv(data_file_path, index_col="Date", parse_dates=True)
    except FileNotFoundError:
        print(f"Error: Data file not found at {data_file_path}")
        print("Please run the data download script first.")
        raise SystemExit(1) from None
    except Exception as e:
        # Any other read/parse failure is fatal for the pipeline as well.
        print(f"Error loading data from file: {e}")
        raise SystemExit(1) from e
    print("Data loaded successfully.")
    return stock_data

In [4]:
def calculate_technical_indicators(df):
    """Append return, trend, momentum and volatility columns to ``df``.

    The frame is modified in place and also returned. Columns added, in
    order: Returns, SMA_20, SMA_50, RSI, MACD, Signal_Line, BB_middle,
    BB_upper, BB_lower, Volatility. All of them are derived from the
    "Close" column; leading rows of the rolling indicators are NaN until
    their window is full.
    """
    close = df["Close"]

    # Simple one-period percentage return.
    df["Returns"] = close.pct_change()

    # Simple moving averages over 20 and 50 rows.
    for window in (20, 50):
        df[f"SMA_{window}"] = close.rolling(window=window).mean()

    # RSI from 14-row simple averages of up-moves and down-moves.
    delta = close.diff()
    up_moves = delta.where(delta > 0, 0)
    down_moves = -delta.where(delta < 0, 0)
    avg_gain = up_moves.rolling(window=14).mean()
    avg_loss = down_moves.rolling(window=14).mean()
    relative_strength = avg_gain / avg_loss
    df["RSI"] = 100 - (100 / (1 + relative_strength))

    # MACD: 12-span EMA minus 26-span EMA, with a 9-span EMA signal line.
    fast_ema = close.ewm(span=12, adjust=False).mean()
    slow_ema = close.ewm(span=26, adjust=False).mean()
    df["MACD"] = fast_ema - slow_ema
    df["Signal_Line"] = df["MACD"].ewm(span=9, adjust=False).mean()

    # Bollinger Bands: 20-row mean +/- two rolling standard deviations.
    band_window = close.rolling(window=20)
    band_mid = band_window.mean()
    band_width = 2 * band_window.std()
    df["BB_middle"] = band_mid
    df["BB_upper"] = band_mid + band_width
    df["BB_lower"] = band_mid - band_width

    # Volatility: 20-row rolling standard deviation of the returns.
    df["Volatility"] = df["Returns"].rolling(window=20).std()

    return df

In [5]:
class TradingEnvironment:
    """Single-asset, long-only trading simulator over a price/indicator frame.

    Each row of ``data`` is one time step. The agent chooses between
    hold (0), buy with the whole balance (1) and sell the whole position (2).
    Trades execute at the current row's "Close"; the portfolio is then
    marked to the next row's "Close" and the step return, optionally
    reduced by a CVaR penalty, is handed back as the reward.
    """

    def __init__(self, data: pd.DataFrame, initial_balance=100000):
        """Validate ``data``, copy it, and reset the episode state.

        Args:
            data: Frame containing "Close" plus every column listed in
                ``feature_columns``; it must not contain any NaN.
            initial_balance: Cash available at the start of each episode.

        Raises:
            ValueError: If ``data`` contains NaNs or lacks a required column.
        """
        if data.isnull().values.any():

            nan_rows = data[data.isnull().any(axis=1)]
            print(
                "Warning: DataFrame passed to TradingEnvironment contains NaN values."
            )
            print("Ensure .dropna() was called after calculating indicators.")
            print("First few rows with NaNs:\n", nan_rows.head())

            raise ValueError(
                "DataFrame contains NaNs. Please clean before passing to Environment."
            )

        self.data = data.copy()
        self.initial_balance = initial_balance

        self.feature_columns = [
            "Returns",
            "SMA_20",
            "SMA_50",
            "RSI",
            "MACD",
            "Signal_Line",
            "BB_upper",
            "BB_lower",
            "Volatility",
        ]
        self.price_column = "Close"

        # --- Verify all required columns are present ---
        required_cols_for_state = self.feature_columns + [self.price_column]
        missing_cols = [
            col for col in required_cols_for_state if col not in self.data.columns
        ]
        if missing_cols:
            raise ValueError(f"Missing required columns in input data: {missing_cols}")

        # Three agent-status features plus one value per indicator column.
        self.state_size = 3 + len(self.feature_columns)
        print(f"Environment initialized. State size: {self.state_size}")

        self.reset()

    def reset(self):
        """Reset cash, position and step counter; return the initial state."""
        self.balance = self.initial_balance
        self.current_step = 0  # Start from the first row (assuming NaNs are dropped)
        self.position = 0  # Number of shares held
        self.portfolio_value = self.initial_balance
        self.returns_history = []
        return self._get_state()

    def _get_state(self):
        """Build the float32 observation vector for ``current_step``.

        Layout: [position value / portfolio value, balance / initial balance,
        portfolio value / initial balance] followed by one scaled value per
        entry of ``feature_columns``. NaN/inf entries are replaced by 0.

        Raises:
            RuntimeError: If the assembled vector length differs from
                ``state_size``.
        """
        if self.current_step >= len(self.data):
            # Should not happen if step logic is correct, but safeguard
            print(
                f"Warning: _get_state called at step {self.current_step} >= data length {len(self.data)}. Using last valid data."
            )
            self.current_step = len(self.data) - 1

        current_data = self.data.iloc[self.current_step]
        current_price = current_data[self.price_column]

        # --- 1. Agent status features (normalized) ---
        position_value = self.position * current_price
        normalized_position_value = (
            position_value / self.portfolio_value
            if self.portfolio_value > 1e-6
            else 0.0
        )  # Avoid div by zero
        normalized_balance = self.balance / self.initial_balance
        normalized_portfolio_value = self.portfolio_value / self.initial_balance

        agent_state = [
            normalized_position_value,
            normalized_balance,
            normalized_portfolio_value,
        ]

        # --- 2. Technical indicator features (scaled/normalized) ---
        tech_state = []
        # Use a small epsilon to prevent division by zero if price is exactly zero
        epsilon = 1e-8
        safe_price = current_price if abs(current_price) > epsilon else epsilon

        for col in self.feature_columns:
            value = current_data[col]
            if pd.isna(value):
                # This shouldn't happen if data is pre-cleaned, but as a fallback
                print(
                    f"Warning: NaN found in column '{col}' at step {self.current_step}. Replacing with 0."
                )
                value = 0.0

            if col in [
                "SMA_20",
                "SMA_50",
                "BB_upper",
                "BB_lower",
                "MACD",
                "Signal_Line",
            ]:
                # Expressed as a relative offset from the current price.
                # NOTE(review): MACD and Signal_Line are already price
                # differences, so this maps them to roughly (value/price - 1);
                # value / price may be the intended scaling. Left unchanged so
                # existing trained models keep seeing the same inputs.
                scaled_value = (value - current_price) / safe_price
            elif col == "RSI":
                # Scale RSI from [0, 100] to [0.0, 1.0]
                scaled_value = value / 100.0
            else:
                # "Returns", "Volatility" and any other column pass through.
                scaled_value = value

            tech_state.append(scaled_value)

        # --- Combine states ---
        state = agent_state + tech_state
        state_np = np.array(state, dtype=np.float32)
        if np.any(np.isnan(state_np)) or np.any(np.isinf(state_np)):
            state_np = np.nan_to_num(state_np, nan=0.0, posinf=0.0, neginf=0.0)

        if len(state_np) != self.state_size:
            raise RuntimeError(
                f"Internal Error: State size mismatch. Expected {self.state_size}, got {len(state_np)}"
            )

        return state_np

    def step(self, action):
        """Apply ``action`` at the current row and advance one row.

        Args:
            action: 0 = hold, 1 = buy (only when flat; spends as much of the
                balance as whole shares allow), 2 = sell the whole position.

        Returns:
            ``(next_state, reward, done)``. ``done`` becomes True once the
            step counter reaches the last row. Calling ``step`` again after
            that returns the last state with reward 0.0 and ``done=True``.
        """
        last_index = len(self.data) - 1

        # A transition needs the current AND the next price, so the last row
        # that can be traded from is last_index - 1. (The previous guard used
        # len - 2 and ended the episode one bar early, never marking the
        # portfolio to the final price.)
        if self.current_step >= last_index:
            return self._get_state(), 0.0, True

        current_data = self.data.iloc[self.current_step]
        next_data = self.data.iloc[self.current_step + 1]

        price = current_data[self.price_column]
        next_price = next_data[self.price_column]

        if pd.isna(price):
            # Price is NaN: cannot trade reliably. Forced hold, zero reward,
            # portfolio value left untouched.
            self.current_step += 1
            done = self.current_step >= last_index
            return self._get_state(), 0.0, done
        if pd.isna(next_price):
            # Next price is NaN: trade at the current price and value the
            # portfolio at the current price as a conservative fallback.
            next_price = price

        initial_portfolio_value = self.portfolio_value

        if action == 1:  # Buy
            # Only when flat and able to afford at least one share.
            if self.position == 0 and self.balance > price:
                shares_to_buy = self.balance // price
                if shares_to_buy > 0:
                    cost = shares_to_buy * price
                    self.balance -= cost
                    self.position = shares_to_buy
            # If already holding or not enough balance, treat as hold

        elif action == 2:  # Sell
            if self.position > 0:
                revenue = self.position * price
                self.balance += revenue
                self.position = 0
            # If not holding, treat as hold

        # Mark the portfolio to next_price (potentially the fallback price)
        portfolio_value = self.balance + self.position * next_price

        # Step return relative to the value at the start of this step
        if initial_portfolio_value > 1e-6:  # Avoid division by zero
            returns = (
                portfolio_value - initial_portfolio_value
            ) / initial_portfolio_value
        else:
            returns = 0.0

        self.returns_history.append(returns)
        self.portfolio_value = portfolio_value

        # Advance step *before* getting the next state
        self.current_step += 1

        # CVaR reward adjustment: once enough history exists, subtract a
        # fraction of the tail loss measured over the most recent returns.
        reward = returns
        alpha = 0.05
        min_history_for_cvar = 20
        if len(self.returns_history) >= min_history_for_cvar:
            cvar_penalty_factor = 0.1
            calculated_cvar = self._calculate_cvar(
                self.returns_history[-min_history_for_cvar:], alpha=alpha
            )
            reward = returns - cvar_penalty_factor * abs(calculated_cvar)

        # Done once current_step (already incremented) sits on the last row.
        done = self.current_step >= last_index

        next_state = self._get_state()  # State for the *new* current_step

        return next_state, reward, done

    def _calculate_cvar(self, returns, alpha=0.05):
        """Return the mean of ``returns`` at or below their ``alpha`` quantile.

        Returns 0.0 for empty input or when the mean is NaN.
        """
        if not isinstance(returns, np.ndarray):
            returns = np.array(returns)
        if len(returns) == 0:
            return 0.0
        var = np.percentile(returns, alpha * 100)
        cvar = returns[returns <= var].mean()
        return cvar if not np.isnan(cvar) else 0.0

In [6]:
# Assuming NoisyLinear class is defined as before:
class NoisyLinear(nn.Module):
    """Linear layer with learnable, factorised Gaussian noise on its parameters.

    In training mode the effective weight is ``weight_mu + weight_sigma *
    weight_epsilon`` (likewise for the bias) and the noise is re-drawn on
    every forward pass. In eval mode only the ``*_mu`` parameters are used.
    """

    def __init__(self, in_features, out_features, std_init=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init

        weight_shape = (out_features, in_features)

        # Learnable mean / scale of the weights; the noise is a buffer so it
        # follows the module across devices but is not trained.
        self.weight_mu = nn.Parameter(torch.empty(*weight_shape))
        self.weight_sigma = nn.Parameter(torch.empty(*weight_shape))
        self.register_buffer("weight_epsilon", torch.empty(*weight_shape))

        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer("bias_epsilon", torch.empty(out_features))

        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Initialise the means uniformly and the sigmas to constants."""
        bound = 1 / np.sqrt(self.in_features)
        weight_sigma_value = self.std_init / np.sqrt(self.in_features)
        bias_sigma_value = self.std_init / np.sqrt(self.out_features)

        self.weight_mu.data.uniform_(-bound, bound)
        self.weight_sigma.data.fill_(weight_sigma_value)
        self.bias_mu.data.uniform_(-bound, bound)
        self.bias_sigma.data.fill_(bias_sigma_value)

    def reset_noise(self):
        """Draw fresh factorised noise for the weights and the bias."""
        noise_in = self._scale_noise(self.in_features)
        noise_out = self._scale_noise(self.out_features)
        # Weight noise is the outer product of the two vectors; the bias
        # reuses the output-side vector.
        self.weight_epsilon.copy_(noise_out.outer(noise_in))
        self.bias_epsilon.copy_(noise_out)

    def _scale_noise(self, size):
        """Return ``sign(x) * sqrt(|x|)`` for ``size`` standard-normal draws."""
        # Sampled on the same device as the parameters.
        raw = torch.randn(size, device=self.weight_mu.device)
        magnitude = raw.abs().sqrt_()
        return raw.sign().mul_(magnitude)

    def forward(self, x):
        """Apply the layer to ``x``; noisy when training, mean-only in eval."""
        # Keep the noise buffers on the input's device.
        input_device = x.device
        if self.weight_epsilon.device != input_device:
            self.weight_epsilon = self.weight_epsilon.to(input_device)
            self.bias_epsilon = self.bias_epsilon.to(input_device)

        if not self.training:
            # Evaluation: deterministic mean parameters only.
            return nn.functional.linear(x, self.weight_mu, self.bias_mu)

        # Training: re-sample the noise, then perturb weights and bias.
        self.reset_noise()
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return nn.functional.linear(x, noisy_weight, noisy_bias)

In [7]:
class RainbowDQN(nn.Module):
    """Dueling, distributional Q-network built from noisy linear layers.

    ``forward`` returns, for every action, a probability distribution over
    ``num_atoms`` fixed return values spanning ``[v_min, v_max]``;
    ``get_q_values`` reduces those distributions to expected values.
    """

    def __init__(
        self,
        state_size,
        action_size,
        num_atoms=51,
        v_min=-10,
        v_max=10,
        hidden_size=128,
        device=None,
    ):
        super().__init__()

        # Fall back to CUDA when available, otherwise CPU.
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Initializing RainbowDQN on device: {self.device}")

        self.state_size = state_size
        self.action_size = action_size
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max

        # Atom locations of the categorical return distribution.
        self.support = torch.linspace(v_min, v_max, num_atoms).to(self.device)

        stream_width = hidden_size // 2

        # Shared trunk: a plain linear layer followed by a noisy one.
        trunk = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            NoisyLinear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        self.feature_layer = trunk.to(self.device)

        # Value stream: one set of atom logits per state.
        value_head = nn.Sequential(
            NoisyLinear(hidden_size, stream_width),
            nn.ReLU(),
            NoisyLinear(stream_width, num_atoms),
        )
        self.value_stream = value_head.to(self.device)

        # Advantage stream: one set of atom logits per action.
        advantage_head = nn.Sequential(
            NoisyLinear(hidden_size, stream_width),
            nn.ReLU(),
            NoisyLinear(stream_width, action_size * num_atoms),
        )
        self.advantage_stream = advantage_head.to(self.device)

        # Final move of the whole module once every layer exists.
        self.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-action atom probabilities, shape [batch, actions, atoms]."""
        if x.device != self.device:
            x = x.to(self.device)

        n_rows = x.size(0)
        hidden = self.feature_layer(x)

        # Value logits broadcast over actions; advantage logits per action.
        value = self.value_stream(hidden).view(n_rows, 1, self.num_atoms)
        advantage = self.advantage_stream(hidden).view(
            n_rows, self.action_size, self.num_atoms
        )

        # Dueling combination on the logits: V + (A - mean over actions of A).
        centred_advantage = advantage - advantage.mean(dim=1, keepdim=True)
        logits = value + centred_advantage

        # Normalise over the atom dimension.
        return torch.softmax(logits, dim=2)

    def reset_noise(self):
        """Re-sample the noise of every NoisyLinear sub-module."""
        noisy_layers = (m for m in self.modules() if isinstance(m, NoisyLinear))
        for layer in noisy_layers:
            layer.reset_noise()

    def get_q_values(self, state: torch.Tensor) -> torch.Tensor:
        """Return expected values per action, shape [batch, actions]."""
        if self.support.device != self.device:
            self.support = self.support.to(self.device)

        atom_probs = self.forward(state)
        # Expectation of each distribution over the support.
        return (atom_probs * self.support).sum(dim=2)

In [8]:
class PrioritizedReplayBuffer:
    """Ring-buffer replay memory with proportional prioritisation and n-step returns.

    Raw transitions are first collected in a sliding window of ``n_step``
    entries. Once the window is full, each ``push`` stores one aggregated
    transition (first state/action, discounted reward sum, last next-state).
    Sampling probability is proportional to ``priority ** alpha`` and the
    returned importance weights use the exponent ``beta``.
    """

    def __init__(self, capacity, alpha=0.6, beta=0.4, n_step=3, gamma=0.99):
        """
        Args:
            capacity: Maximum number of stored n-step transitions.
            alpha: Exponent applied to priorities when sampling.
            beta: Exponent of the importance-sampling weights; read from
                ``self.beta`` at every ``sample`` call.
            n_step: Number of raw transitions folded into one stored entry.
            gamma: Discount factor used when summing the n-step reward.
        """
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.n_step = n_step
        self.gamma = gamma

        self.buffer = []
        self.n_step_buffer = deque(maxlen=n_step)
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.position = 0
        self._log_frequency = 5
        # Counter used by the periodic logging in push(). It was previously
        # never initialised, so the first stored transition raised
        # AttributeError.
        self._push_count = 0

    def _get_n_step_info(self):
        """Fold the current n-step window into a single transition.

        Walks the window from newest to oldest, discounting the accumulated
        reward. A terminal transition inside the window cuts the sum there
        and supplies the next-state / done flag instead.
        """
        reward, next_state, done = self.n_step_buffer[-1][-3:]

        for transition in reversed(list(self.n_step_buffer)[:-1]):
            r, n_s, d = transition[-3:]
            reward = r + self.gamma * reward * (1 - d)
            if d:
                next_state, done = n_s, d

        state, action = self.n_step_buffer[0][:2]
        return state, action, reward, next_state, done

    def push(self, state, action, reward, next_state, done):
        """Add a raw transition; store an n-step transition once the window is full.

        New entries get the current maximum priority (1.0 for the very first
        entry) so that they are sampled at least once.
        """
        self.n_step_buffer.append((state, action, reward, next_state, done))

        if len(self.n_step_buffer) < self.n_step:
            return

        state, action, reward, next_state, done = self._get_n_step_info()

        # --- LOGGING ---
        # Log the calculated N-step reward periodically to avoid flooding the console
        self._push_count += 1
        if self._push_count % self._log_frequency == 0:
            # Get the sequence of immediate rewards that went into this calculation
            immediate_rewards = [t[2] for t in self.n_step_buffer]
            print(
                f"[Buffer Push Step {self._push_count}] Storing N-Step Transition: "
                f"Immediate Rewards: {[f'{r:.3f}' for r in immediate_rewards]} -> "
                f"N-Step Reward: {reward:.4f} | N-Step Done: {done}"
            )
        # --- END LOGGING ---

        max_priority = self.priorities.max() if self.buffer else 1.0

        entry = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(entry)
        else:
            # Buffer full: overwrite the oldest slot.
            self.buffer[self.position] = entry

        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions (with replacement).

        Returns:
            ``(samples, indices, weights)`` where ``weights`` are the
            importance-sampling weights normalised by their maximum, or
            ``None`` when the buffer is empty.
        """
        if not self.buffer:
            return None

        probs = self.priorities[: len(self.buffer)] ** self.alpha
        probs /= probs.sum()

        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights /= weights.max()

        return samples, indices, weights

    def update_priorities(self, indices, priorities):
        """Overwrite the priority of each entry in ``indices``."""
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        """Number of stored n-step transitions."""
        return len(self.buffer)

In [9]:
class RainbowDQNAgent:
    """Agent combining a distributional dueling noisy network with PER and n-step returns.

    Holds a policy network, a periodically synchronised target network, an
    Adam optimiser and a prioritised n-step replay buffer. Exploration comes
    from the noisy layers, so action selection is purely greedy.
    """

    def __init__(
        self,
        state_size,
        action_size,
        device="cuda" if torch.cuda.is_available() else "cpu",
        # Hyperparameters
        v_min=-50.0,
        v_max=50.0,
        num_atoms=101,
        hidden_size=128,
        buffer_capacity=100000,
        batch_size=32,
        gamma=0.99,
        n_step=3,
        per_alpha=0.6,
        per_beta=0.4,
        target_update=1000,
        learning_rate=1e-3,
    ):
        """
        Args:
            state_size: Length of the observation vector.
            action_size: Number of discrete actions.
            device: Torch device (or device string) for networks and batches.
            v_min, v_max, num_atoms: Support of the categorical distribution.
            hidden_size: Width of the network's shared layers.
            buffer_capacity: Replay buffer capacity.
            batch_size: Number of transitions per optimisation step.
            gamma: Per-step discount factor.
            n_step: Length of the n-step return.
            per_alpha, per_beta: Prioritisation exponent and initial
                importance-sampling exponent.
            target_update: Optimisation steps between target-network syncs.
            learning_rate: Adam learning rate.
        """
        self.device = torch.device(device)
        self.action_size = action_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.n_step = n_step
        self.target_update = target_update
        self.steps_done = 0

        # Distributional RL parameters
        self.v_min = v_min
        self.v_max = v_max
        self.num_atoms = num_atoms
        self.support = torch.linspace(v_min, v_max, num_atoms).to(self.device)
        self.delta_z = (v_max - v_min) / (num_atoms - 1)

        # Policy and target networks start with identical weights.
        self.policy_net = RainbowDQN(
            state_size, action_size, num_atoms, v_min, v_max, hidden_size, self.device
        ).to(self.device)
        self.target_net = RainbowDQN(
            state_size, action_size, num_atoms, v_min, v_max, hidden_size, self.device
        ).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Replay buffer with N-step returns
        self.memory = PrioritizedReplayBuffer(
            buffer_capacity, per_alpha, per_beta, n_step, gamma
        )

        # Beta is annealed linearly from beta_start to 1.0 over beta_frames
        # optimisation steps (see optimize_model).
        self.beta_start = per_beta
        self.beta_frames = 100000

    def select_action(self, state):
        """Return the index of the action with the highest expected value.

        Args:
            state: 1-D observation as a NumPy array or tensor.

        Raises:
            ValueError: If the observation length differs from the network's
                ``state_size``.
        """
        with torch.no_grad():
            # Ensure state is a float tensor with a batch dimension on the correct device
            if not isinstance(state, torch.Tensor):
                s = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
            else:
                s = state.float().unsqueeze(0).to(self.device)

            if s.shape[1] != self.policy_net.state_size:
                raise ValueError(
                    f"State shape mismatch in select_action. Expected {self.policy_net.state_size}, got {s.shape[1]}"
                )

            q_values = self.policy_net.get_q_values(s)  # Shape: [1, action_size]
            return q_values.argmax(1).item()

    def _categorical_projection(self, next_dist_target: torch.Tensor, rewards: torch.Tensor,
                                dones: torch.Tensor, next_action: torch.Tensor) -> torch.Tensor:
        """Project the n-step Bellman-updated distribution back onto the support.

        Args:
            next_dist_target: Next-state atom probabilities, [B, A, N].
            rewards: N-step rewards, [B].
            dones: Boolean terminal flags, [B].
            next_action: Chosen next action per row, [B].

        Returns:
            Target distribution [B, N]; each row sums to 1 whenever the
            corresponding input distribution does.
        """
        batch_size = next_dist_target.size(0)
        last_atom = self.num_atoms - 1

        rewards = rewards.to(self.device).view(batch_size, 1)  # [B, 1]
        dones = dones.to(self.device)
        next_action = next_action.to(self.device)
        not_done = (~dones.bool()).float().view(batch_size, 1)  # 1.0 if not done

        # Distribution of the selected next action: [B, N]
        gather_index = next_action.view(batch_size, 1, 1).expand(-1, -1, self.num_atoms)
        next_dist_best_action = next_dist_target.gather(1, gather_index).squeeze(1)

        # Tz = R + gamma^n * z * (1 - done), clipped to [v_min, v_max]. For
        # terminal rows every atom collapses onto the reward, so the
        # projection below yields the terminal distribution with no special
        # case.
        Tz = rewards + (self.gamma ** self.n_step) * self.support.view(1, -1) * not_done
        Tz = Tz.clamp(self.v_min, self.v_max)

        # Fractional atom index; the clamp guards against rounding slightly
        # past the last atom.
        b = ((Tz - self.v_min) / self.delta_z).clamp(0, last_atom)
        lower = b.floor().long()
        upper = b.ceil().long()

        # When b falls exactly on an atom, floor == ceil and both
        # interpolation weights would be zero, silently dropping that
        # probability mass (this always happens for atoms clipped to
        # v_min / v_max). Move one of the two indices apart so the weights
        # sum to one again.
        lower[(upper > 0) & (lower == upper)] -= 1
        upper[(lower < last_atom) & (lower == upper)] += 1

        # Split each atom's probability between its two neighbours in
        # proportion to proximity.
        target_dist = torch.zeros(batch_size, self.num_atoms, device=self.device)
        target_dist.scatter_add_(1, lower, next_dist_best_action * (upper.float() - b))
        target_dist.scatter_add_(1, upper, next_dist_best_action * (b - lower.float()))

        return target_dist

    def optimize_model(self):
        """Run one optimisation step on a prioritised batch.

        Returns:
            The weighted loss as a float, or ``None`` when the buffer holds
            fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return None  # Not enough samples yet

        # Anneal the importance-sampling exponent towards 1.0 and hand it to
        # the buffer, which reads self.beta when computing the weights.
        # (Previously the value was computed and then discarded.)
        beta = min(
            1.0,
            self.beta_start
            + self.steps_done * (1.0 - self.beta_start) / self.beta_frames,
        )
        self.memory.beta = beta

        transitions, indices, weights = self.memory.sample(self.batch_size)
        batch = list(zip(*transitions))

        states_np = np.array(batch[0], dtype=np.float32)
        actions_np = np.array(batch[1], dtype=np.int64)
        rewards_np = np.array(batch[2], dtype=np.float32)  # N-step rewards
        next_states_np = np.array(batch[3], dtype=np.float32)  # State S_{t+N}
        dones_np = np.array(batch[4], dtype=bool)  # Done flags for S_{t+N}

        state_batch = torch.from_numpy(states_np).to(self.device)
        action_batch = torch.from_numpy(actions_np).to(self.device)  # [B]
        reward_batch = torch.from_numpy(rewards_np).to(self.device)  # [B]
        next_state_batch = torch.from_numpy(next_states_np).to(self.device)
        done_batch = torch.from_numpy(dones_np).to(self.device)  # [B]
        weights = (
            torch.from_numpy(np.array(weights, dtype=np.float32))
            .to(self.device)
            .unsqueeze(1)
        )  # [B, 1]

        # --- Target calculation ---
        with torch.no_grad():
            next_dist_target = self.target_net(next_state_batch)  # [B, A, N]
            next_q_target = (next_dist_target * self.support).sum(dim=2)  # [B, A]
            # NOTE(review): the next action is both chosen and evaluated by
            # the target network. Double Q-learning would choose it with
            # policy_net instead; left unchanged here.
            next_action = next_q_target.argmax(dim=1)  # [B]

            target_dist = self._categorical_projection(
                next_dist_target,
                reward_batch,
                done_batch,
                next_action,
            )
        current_dist = self.policy_net(state_batch)  # [B, A, N]

        action_index = action_batch.view(-1, 1, 1).expand(
            -1, -1, self.num_atoms
        )  # [B, 1, N]
        current_dist_taken_action = current_dist.gather(1, action_index).squeeze(
            1
        )  # [B, N]

        # Cross-entropy between target and current distributions; the small
        # epsilon keeps log() finite.
        loss = -(target_dist * torch.log(current_dist_taken_action + 1e-8)).sum(
            dim=1
        )  # [B]

        # Importance-weighted mean loss
        weighted_loss = (loss * weights.squeeze(1)).mean()

        # --- Update priorities in the PER buffer ---
        new_priorities = loss.abs().detach().cpu().numpy() + 1e-6
        self.memory.update_priorities(indices, new_priorities)

        # --- Optimize policy network ---
        self.optimizer.zero_grad()
        weighted_loss.backward()
        # Clip gradients to prevent explosions
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10.0)
        self.optimizer.step()

        # --- Reset noise in the NoisyLinear layers of both networks ---
        self.policy_net.reset_noise()
        self.target_net.reset_noise()

        # --- Periodic hard update of the target network ---
        self.steps_done += 1
        if self.steps_done % self.target_update == 0:
            print(f"--- Updating target network at step {self.steps_done} ---")
            self.target_net.load_state_dict(self.policy_net.state_dict())

        return weighted_loss.item()

In [10]:
# Example modification in train_agent loop:
def train_agent(env, agent, num_episodes=1000, max_steps_per_episode=10000):
    """
    Run the agent in the environment, optimizing after every environment step.

    Args:
        env: Environment exposing reset(), step(action) returning
            (next_state, reward, done), and a portfolio_value attribute.
        agent: Agent exposing select_action(state), memory.push(...),
            optimize_model() and a steps_done attribute.
        num_episodes: Number of episodes to run.
        max_steps_per_episode: Upper bound on environment steps per episode.

    Returns:
        A tuple (returns_history, losses_history, agent). returns_history
        holds env.portfolio_value read at the end of each episode;
        losses_history holds the mean of the losses reported by
        optimize_model() in that episode (0 when none were reported).
    """
    final_values = []
    mean_losses = []

    for episode_idx in range(num_episodes):
        observation = env.reset()
        reward_total = 0  # Accumulated per episode; not part of the returned data.
        step_losses = []
        finished = False
        step_count = 0

        while True:
            # Stop on a terminal state or once the per-episode step cap is hit.
            if finished or step_count >= max_steps_per_episode:
                break

            chosen_action = agent.select_action(observation)
            following_observation, step_reward, finished = env.step(chosen_action)

            # Hand the transition to the agent's replay memory.
            agent.memory.push(
                observation,
                chosen_action,
                step_reward,
                following_observation,
                finished,
            )
            observation = following_observation
            reward_total += step_reward

            # optimize_model() returns None when it did not produce a loss.
            step_loss = agent.optimize_model()
            if step_loss is not None:
                step_losses.append(step_loss)

            step_count += 1

        final_values.append(env.portfolio_value)

        if step_losses:
            mean_loss = np.mean(step_losses)
        else:
            mean_loss = 0
        mean_losses.append(mean_loss)

        summary_parts = (
            f"Episode {episode_idx + 1}/{num_episodes}",
            f"Steps: {step_count}",
            f"Total Steps: {agent.steps_done}",
            f"Return: {env.portfolio_value:.2f}",
            f"Avg Loss: {mean_loss:.4f}",
        )
        print(", ".join(summary_parts))

    return final_values, mean_losses, agent

In [None]:
def simulate_agent_on_data(
    agent: RainbowDQNAgent, eval_data: pd.DataFrame, initial_balance: float
) -> tuple[list, list]:
    """
    Simulates the trained agent's trading on evaluation data without learning.

    No transitions are stored and optimize_model() is never called; the agent
    is only asked for actions, which are applied to a fresh environment.

    Args:
        agent: The trained RainbowDQNAgent instance.
        eval_data: Preprocessed DataFrame containing evaluation data.
        initial_balance: Starting balance for the simulation.

    Returns:
        A tuple containing:
            - portfolio_values (list): initial_balance followed by the
              environment's portfolio value after each step.
            - actions_taken (list): History of actions taken (0, 1, or 2).
    """
    print("Simulating agent on evaluation data...")
    # Create a temporary environment just for this simulation run
    eval_env = TradingEnvironment(eval_data, initial_balance=initial_balance)
    state = eval_env.reset()
    done = False
    portfolio_values = [initial_balance]  # Start with initial balance
    actions_taken = []

    # Put the policy network in evaluation mode for the rollout.
    # NOTE(review): whether this also disables NoisyNet noise depends on the
    # network / select_action implementation, which is not visible here.
    agent.policy_net.eval()
    try:
        while not done:
            # Get action without calculating gradients
            with torch.no_grad():
                action = agent.select_action(state)

            # Step through the environment; the reward is not needed here.
            state, _reward, done = eval_env.step(action)

            # Record results for this step
            portfolio_values.append(eval_env.portfolio_value)
            actions_taken.append(action)
    finally:
        # Always hand the network back in training mode, even if the rollout
        # raised, so a reused agent is never silently left in eval mode.
        agent.policy_net.train()

    print(
        f"Agent simulation complete. Final portfolio value: {portfolio_values[-1]:.2f}"
    )
    return portfolio_values, actions_taken


def simulate_buy_and_hold(
    eval_data: pd.DataFrame, initial_balance: float
) -> tuple[list, float]:
    """
    Simulates buying whole shares at the first close and holding to the end.

    Args:
        eval_data: Preprocessed DataFrame containing evaluation data (needs 'Close').
        initial_balance: Starting balance for the simulation.

    Returns:
        A tuple containing:
            - portfolio_values (list): initial_balance followed by the
              portfolio value (shares plus leftover cash) at each row.
            - total_return_pct (float): Total return in percent, measured
              from initial_balance to the last portfolio value.
    """
    print("Simulating Buy and Hold strategy...")
    if eval_data.empty:
        print("Warning: Cannot simulate Buy and Hold on empty data.")
        return [initial_balance], 0.0

    close_prices = eval_data["Close"]
    entry_price = close_prices.iloc[0]

    # Only whole shares are bought; whatever cannot be invested stays as cash.
    shares_held = initial_balance // entry_price
    residual_cash = initial_balance - (shares_held * entry_price)

    # Mark the position to market on every row, prefixed by the starting balance.
    value_path = [initial_balance]
    value_path.extend(((close_prices * shares_held) + residual_cash).tolist())

    closing_value = value_path[-1]
    gain_pct = ((closing_value - initial_balance) / initial_balance) * 100

    print(
        f"Buy & Hold simulation complete. Final portfolio value: {closing_value:.2f}, Total Return: {gain_pct:.2f}%"
    )
    return value_path, gain_pct


def calculate_performance_metrics(
    portfolio_values: list, risk_free_rate: float = 0.0
) -> dict:
    """
    Calculates key performance metrics from a list of portfolio values.

    Args:
        portfolio_values: List of portfolio values over time.
        risk_free_rate: Annual risk-free rate for Sharpe Ratio calculation
            (spread over 252 periods per year).

    Returns:
        A dictionary containing:
            - 'Sharpe Ratio': annualized (sqrt(252)) Sharpe ratio of the
              step-to-step returns, or 0.0 when it is undefined (fewer than
              two returns, zero volatility, or non-finite returns).
            - 'Max Drawdown': the most negative drawdown in percent
              (0.0 when the series never falls below its running peak).
    """
    metrics = {"Sharpe Ratio": 0.0, "Max Drawdown": 0.0}
    if len(portfolio_values) < 2:
        print("Warning: Need at least 2 portfolio values to calculate metrics.")
        return metrics

    values_series = pd.Series(portfolio_values, dtype=float)
    daily_returns = values_series.pct_change().dropna()

    # Sharpe Ratio
    excess_returns = daily_returns - risk_free_rate / 252
    volatility = excess_returns.std()
    # The sample std is NaN for a single return (and for inf returns), and a
    # NaN compares unequal to 0, so require a finite, strictly positive value
    # instead of merely "!= 0" to keep the documented 0.0 default.
    if np.isfinite(volatility) and volatility > 0:
        sharpe = excess_returns.mean() / volatility
        metrics["Sharpe Ratio"] = float(sharpe * np.sqrt(252))  # Annualized

    # Max Drawdown
    cumulative_max = values_series.cummax()
    # Guard against a zero running peak the same way calculate_drawdown_series
    # does, so both helpers agree on degenerate inputs.
    drawdown = (
        (values_series - cumulative_max) / cumulative_max.replace(0, 1e-9)
    ) * 100  # In percent
    worst_drawdown = drawdown.min()
    if pd.notna(worst_drawdown):
        metrics["Max Drawdown"] = float(worst_drawdown)

    return metrics

In [None]:
def calculate_drawdown_series(portfolio_values: list) -> pd.Series:
    """
    Return the drawdown from the running peak, in percent, for every point.

    Fewer than two values yield an empty float Series. A running peak of
    exactly 0 is replaced by 1e-9 in the denominator to avoid dividing by
    zero, and any remaining NaN entries are reported as 0.
    """
    series = pd.Series(portfolio_values)
    if len(series) < 2:
        return pd.Series(dtype=float)

    running_peak = series.cummax()
    safe_denominator = running_peak.replace(0, 1e-9)
    shortfall = series - running_peak
    drawdown_pct = (shortfall / safe_denominator) * 100
    return drawdown_pct.fillna(0)


def plot_portfolio_value(agent_values, bnh_values, initial_balance, index, save_path):
    """
    Plot the agent's and the Buy & Hold portfolio value over time and save it.

    Both value sequences are expected to carry one leading entry (the starting
    balance) that has no matching position in `index`; that entry is dropped
    before plotting. Nothing is plotted when either sequence is empty or has
    a single entry.
    """
    print("Generating Portfolio Value plot...")
    inputs = (agent_values, bnh_values)
    if any(not series for series in inputs) or any(
        len(series) <= 1 for series in inputs
    ):
        print("Skipping portfolio value plot due to insufficient data.")
        return

    figure, axes = plt.subplots(figsize=(12, 7))

    # (values, legend label, colour) for each curve, drawn in this order.
    curves = (
        (agent_values, "Agent Portfolio", "tab:blue"),
        (bnh_values, "Buy & Hold Portfolio", "tab:orange"),
    )
    for series, legend_label, line_color in curves:
        # Skip the leading starting-balance entry so the curve lines up with index.
        axes.plot(
            index,
            series[1:],
            label=legend_label,
            color=line_color,
            linewidth=1.5,
        )
    axes.axhline(
        y=initial_balance, color="gray", linestyle="--", label="Initial Balance"
    )

    axes.set_xlabel("Date")
    axes.set_ylabel("Portfolio Value")
    axes.set_title("Agent vs. Buy & Hold Portfolio Value (Test Set)")
    axes.legend()
    axes.grid(True, linestyle=":", alpha=0.6)
    axes.ticklabel_format(style="plain", axis="y")
    figure.autofmt_xdate()

    # Saving is best-effort: a failure is reported but never propagated.
    try:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Portfolio Value plot saved to {save_path}")
    except Exception as e:
        print(f"Error saving Portfolio Value plot: {e}")
    plt.close(figure)


def plot_drawdown(agent_values, bnh_values, index, save_path):
    """
    Plot the agent's and the Buy & Hold drawdown (in percent) over time and save it.

    Drawdowns come from calculate_drawdown_series; the first drawdown entry
    (belonging to the leading starting-balance value) is dropped so the
    curves line up with `index`. Nothing is plotted when either value
    sequence is empty or has a single entry.
    """
    print("Generating Drawdown plot...")
    inputs = (agent_values, bnh_values)
    if any(not series for series in inputs) or any(
        len(series) <= 1 for series in inputs
    ):
        print("Skipping drawdown plot due to insufficient data.")
        return

    # (drawdown series, legend label, colour) for each curve, drawn in this order.
    traces = (
        (calculate_drawdown_series(agent_values), "Agent Drawdown", "tab:red"),
        (calculate_drawdown_series(bnh_values), "Buy & Hold Drawdown", "tab:purple"),
    )

    figure, axes = plt.subplots(figsize=(12, 7))
    for drawdown, legend_label, line_color in traces:
        axes.plot(
            index,
            drawdown.iloc[1:],
            label=legend_label,
            color=line_color,
            linewidth=1.5,
        )
    # Zero line marks "at the running peak".
    axes.axhline(y=0, color="black", linestyle="-", linewidth=0.7)

    axes.set_xlabel("Date")
    axes.set_ylabel("Drawdown (%)")
    axes.set_title("Agent vs. Buy & Hold Drawdown (Test Set)")
    axes.yaxis.set_major_formatter(mtick.PercentFormatter())
    axes.legend()
    axes.grid(True, linestyle=":", alpha=0.6)
    figure.autofmt_xdate()

    # Saving is best-effort: a failure is reported but never propagated.
    try:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Drawdown plot saved to {save_path}")
    except Exception as e:
        print(f"Error saving Drawdown plot: {e}")
    plt.close(figure)


def plot_actions(eval_data, actions_taken, save_path):
    """
    Plot the close price with the agent's buy/sell actions overlaid and save it.

    Args:
        eval_data: DataFrame with a 'Close' column; its index supplies the x axis.
        actions_taken: Sequence of action codes, one per step. Code 1 is drawn
            as a buy marker and code 2 as a sell marker; other codes are not
            marked.
        save_path: Destination file for the figure.

    Each marker is placed on the row where the action was decided. If the
    number of actions differs from the number of rows, both are truncated to
    the shorter length (with a warning). Nothing is plotted when the data or
    the action list is empty.
    """
    print("Generating Actions plot...")
    if eval_data.empty or not actions_taken:
        print("Skipping actions plot due to empty data.")
        return

    # Align actions with the rows of eval_data, truncating to the common length.
    if len(actions_taken) != len(eval_data):
        print(
            f"Warning: Length mismatch actions ({len(actions_taken)}) vs data ({len(eval_data)}). Adjusting."
        )
    common_len = min(len(actions_taken), len(eval_data))
    eval_data_plot = eval_data.iloc[:common_len]
    # Numpy array so the action codes can be compared element-wise.
    actions_plot = np.array(actions_taken[:common_len])
    close_prices = eval_data_plot["Close"]

    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(
        eval_data_plot.index,
        close_prices,
        label="Close Price",
        color="black",
        alpha=0.8,
        linewidth=1,
    )

    # (action code, marker, colour, legend label); buys are drawn before sells.
    signal_styles = (
        (1, "^", "green", "Buy Signal"),
        (2, "v", "red", "Sell Signal"),
    )
    for action_code, marker, marker_color, legend_label in signal_styles:
        positions = np.where(actions_plot == action_code)[0]
        if len(positions) == 0:
            continue
        # Select dates and prices by position. A label-based lookup would
        # return extra rows when the index contains duplicate dates, giving
        # x and y different lengths.
        ax.plot(
            eval_data_plot.index[positions],
            close_prices.iloc[positions],
            marker,
            markersize=8,
            color=marker_color,
            label=legend_label,
            alpha=0.9,
            linestyle="None",
        )

    ax.set_xlabel("Date")
    ax.set_ylabel("Price")
    ax.set_title("Agent Trading Actions on Price (Test Set)")
    ax.legend()
    ax.grid(True, linestyle=":", alpha=0.6)
    fig.autofmt_xdate()

    # Saving is best-effort: a failure is reported but never propagated.
    try:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Actions plot saved to {save_path}")
    except Exception as e:
        print(f"Error saving Actions plot: {e}")
    plt.close(fig)


def plot_returns_histogram(agent_values, bnh_values, save_path, bins=50):
    """
    Plot density histograms of step-to-step returns (in percent) for the agent
    and the Buy & Hold baseline, and save the figure.

    Nothing is plotted when either value sequence has at most one entry, or
    when neither sequence produces any return.
    """
    print("Generating Returns Histogram plot...")
    if len(agent_values) <= 1 or len(bnh_values) <= 1:
        print("Skipping returns histogram plot due to insufficient data.")
        return

    def pct_returns(values):
        # Percentage change between consecutive values; the undefined first entry is dropped.
        return pd.Series(values).pct_change().dropna() * 100

    agent_returns = pct_returns(agent_values)
    bnh_returns = pct_returns(bnh_values)

    if agent_returns.empty and bnh_returns.empty:
        print("No returns data to plot histogram.")
        return

    figure, axes = plt.subplots(figsize=(10, 6))

    # (returns, legend label, colour) for each histogram, drawn in this order.
    histograms = (
        (agent_returns, "Agent Daily Returns", "tab:blue"),
        (bnh_returns, "Buy & Hold Daily Returns", "tab:orange"),
    )
    for returns, legend_label, bar_color in histograms:
        axes.hist(
            returns,
            bins=bins,
            alpha=0.7,
            label=legend_label,
            color=bar_color,
            density=True,
        )

    axes.set_xlabel("Daily Return (%)")
    axes.set_ylabel("Density")
    axes.set_title("Distribution of Daily Returns (Test Set)")
    axes.xaxis.set_major_formatter(mtick.PercentFormatter())
    axes.legend()
    axes.grid(True, axis="y", linestyle=":", alpha=0.6)

    # Saving is best-effort: a failure is reported but never propagated.
    try:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Returns Histogram plot saved to {save_path}")
    except Exception as e:
        print(f"Error saving Returns Histogram plot: {e}")
    plt.close(figure)


def print_training_charts(returns_history, losses_history, initial_balance, save_path):
    """
    Plot per-episode training progress and save the figure.

    The left axis shows each episode's final portfolio value converted to a
    percentage return relative to initial_balance; the right axis shows the
    average loss per episode. Nothing is plotted when either history is empty.
    """
    print("Generating training performance plot...")
    if not returns_history or not losses_history:
        print("Skipping training plot as no training history was generated.")
        return

    # Convert final portfolio values into percentage gains over the starting balance.
    return_pct_per_episode = [
        ((episode_value - initial_balance) / initial_balance) * 100
        for episode_value in returns_history
    ]

    return_color = "tab:red"
    loss_color = "tab:blue"

    figure, return_axis = plt.subplots(figsize=(12, 7))
    return_axis.set_xlabel("Episode")
    return_axis.set_ylabel("Cumulative Return (%)", color=return_color)
    return_axis.plot(
        return_pct_per_episode,
        color=return_color,
        label="Cumulative Return %",
        linewidth=1.5,
    )
    return_axis.tick_params(axis="y", labelcolor=return_color)
    return_axis.axhline(y=0, color="gray", linestyle="--", label="0% Return")
    return_axis.legend(loc="upper left")
    return_axis.grid(True, axis="y", linestyle=":", alpha=0.6)

    # Second y axis sharing the same episode axis, used for the loss curve.
    loss_axis = return_axis.twinx()
    loss_axis.set_ylabel("Average Loss", color=loss_color)
    loss_axis.plot(
        losses_history, color=loss_color, alpha=0.7, label="Avg Loss", linewidth=1.5
    )
    loss_axis.tick_params(axis="y", labelcolor=loss_color)
    loss_axis.legend(loc="upper right")
    plt.title("Training Performance: Cumulative Return and Average Loss per Episode")
    figure.tight_layout()

    # Saving is best-effort: a failure is reported but never propagated.
    try:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Training plot saved successfully to {save_path}")
    except Exception as e:
        print(f"Error saving training plot: {e}")
    plt.close(figure)


# --- Top-Level Evaluation Orchestration Function ---
def run_evaluation(
    agent: RainbowDQNAgent, eval_data_path: str, initial_balance: float, output_dir: str
):
    """
    Evaluate a trained agent against a Buy & Hold baseline.

    Loads the evaluation CSV, adds technical indicators and drops incomplete
    rows, simulates both strategies, prints a comparison table and writes the
    evaluation plots. The function returns early (after printing a message)
    when the data cannot be loaded or is empty after preprocessing.

    Args:
        agent: The trained agent instance.
        eval_data_path: Path to the evaluation data CSV file.
        initial_balance: Starting balance for evaluation simulations.
        output_dir: Directory to save the evaluation plots (created if missing).
    """
    rule = "=" * 30
    print(f"\n{rule} STARTING EVALUATION {rule}")

    # --- Load and Prepare Evaluation Data ---
    try:
        raw_frame = pd.read_csv(eval_data_path, index_col="Date", parse_dates=True)
    except FileNotFoundError:
        print(f"Error: Evaluation data file not found at {eval_data_path}")
        print("Evaluation skipped.")
        return
    except Exception as e:
        print(f"Error loading evaluation data: {e}")
        print("Evaluation skipped.")
        return
    else:
        print(f"Loaded evaluation data from {eval_data_path}")

    # Indicators are computed on a copy; rows with any NaN are discarded.
    eval_data = calculate_technical_indicators(raw_frame.copy()).dropna()

    if eval_data.empty:
        print(
            "Error: Evaluation data is empty after preprocessing. Evaluation skipped."
        )
        return

    print(f"Evaluation data range: {eval_data.index.min()} to {eval_data.index.max()}")
    print(f"Evaluation data shape: {eval_data.shape}")

    os.makedirs(output_dir, exist_ok=True)

    # --- Run Simulations ---
    agent_portfolio_values, agent_actions = simulate_agent_on_data(
        agent, eval_data, initial_balance
    )
    bnh_portfolio_values, bnh_total_return_pct = simulate_buy_and_hold(
        eval_data, initial_balance
    )

    # --- Calculate Metrics ---
    agent_final_value = agent_portfolio_values[-1]
    agent_total_return_pct = (
        (agent_final_value - initial_balance) / initial_balance
    ) * 100
    agent_metrics = calculate_performance_metrics(agent_portfolio_values)
    bnh_metrics = calculate_performance_metrics(bnh_portfolio_values)

    # --- Print Comparison Table ---
    table_lines = [
        "| Metric                     | Rainbow DQN Agent | Buy-and-Hold Baseline |",
        "| :------------------------- | :---------------- | :-------------------- |",
        f"| Total Return               | {agent_total_return_pct: >17.2f}% | {bnh_total_return_pct: >21.2f}% |",
        f"| Annualized Sharpe Ratio    | {agent_metrics['Sharpe Ratio']: >17.2f} | {bnh_metrics['Sharpe Ratio']: >21.2f} |",
        f"| Maximum Drawdown (MDD)     | {agent_metrics['Max Drawdown']: >17.2f}% | {bnh_metrics['Max Drawdown']: >21.2f}% |",
        f"| Initial Balance            | {initial_balance: >17,.0f} | {initial_balance: >21,.0f} |",
        f"| Final Portfolio Value      | {agent_final_value: >17,.2f} | {bnh_portfolio_values[-1]: >21,.2f} |",
    ]
    print("\n--- Evaluation Results Comparison ---")
    for table_line in table_lines:
        print(table_line)

    # --- Generate Evaluation Plots ---
    print("\n--- Generating Evaluation Plots ---")
    eval_index = eval_data.index  # Dates of the processed evaluation rows

    def plot_path(file_name):
        # Every evaluation plot is written into output_dir.
        return os.path.join(output_dir, file_name)

    plot_portfolio_value(
        agent_portfolio_values,
        bnh_portfolio_values,
        initial_balance,
        eval_index,
        plot_path("evaluation_portfolio_value.png"),
    )
    plot_drawdown(
        agent_portfolio_values,
        bnh_portfolio_values,
        eval_index,
        plot_path("evaluation_drawdown.png"),
    )
    plot_actions(eval_data, agent_actions, plot_path("evaluation_actions.png"))
    plot_returns_histogram(
        agent_portfolio_values,
        bnh_portfolio_values,
        plot_path("evaluation_returns_histogram.png"),
    )

    print(f"{rule} EVALUATION FINISHED {rule}\n")

In [None]:
# Script entry point: load and preprocess the training data, train the agent,
# save the training plot, then evaluate the trained agent on the held-out file.
if __name__ == '__main__':

    # Seeding is currently disabled; uncomment these lines to fix the RNG seeds.
    # np.random.seed(42)
    # torch.manual_seed(42)
    # random.seed(42)

    # --- Define Paths and Parameters ---
    # NOTE(review): TRAIN_DATA_PATH is not referenced anywhere below.
    # collect_stock_data() reads its own hard-coded relative path
    # "data/train_data.csv", so changing this constant has no effect.
    TRAIN_DATA_PATH = "/home/kartikeya.agrawal_ug25/RL_Final/data/train_data.csv"  # Unused in this block (see note above)
    EVAL_DATA_PATH = "/home/kartikeya.agrawal_ug25/RL_Final/data/eval_data.csv"
    OUTPUT_DIR = "/home/kartikeya.agrawal_ug25/RL_Final/output"  # Directory for plots
    NUM_TRAIN_EPISODES = 500  # Number of episodes for training
    INITIAL_BALANCE = 100000  # Starting balance

    # --- Data Loading and Preprocessing (for Training) ---
    # collect_stock_data() takes no arguments and loads from its own
    # hard-coded path (it does not use TRAIN_DATA_PATH).
    train_stock_data = collect_stock_data()
    # Indicators are added on a copy; rows containing any NaN are then dropped
    # (the rolling-window indicators are undefined for the earliest rows).
    processed_train_data = calculate_technical_indicators(train_stock_data.copy())
    processed_train_data.dropna(inplace=True)

    print("Training Dataset shape:", processed_train_data.shape)
    print(
        "Training Date range:",
        processed_train_data.index.min(),
        "to",
        processed_train_data.index.max(),
    )

    # --- Create Training Environment and Agent ---
    train_env = TradingEnvironment(
        processed_train_data, initial_balance=INITIAL_BALANCE
    )
    state_size = train_env.state_size
    # Three discrete actions (0, 1, 2); the evaluation plots mark 1 as buy and 2 as sell.
    action_size = 3

    # Prefer the GPU when one is available.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device_name}")
    agent = RainbowDQNAgent(state_size, action_size, device=device_name)


    # --- Train the Agent ---
    print(f"Starting training for {NUM_TRAIN_EPISODES} episodes...")
    # train_agent returns the same agent object alongside the per-episode histories.
    returns_history, losses_history, agent = train_agent(
        train_env, agent, num_episodes=NUM_TRAIN_EPISODES
    )

    # --- Save Training Plot ---
    os.makedirs(OUTPUT_DIR, exist_ok=True)  # Ensure output dir exists
    training_plot_save_path = os.path.join(
        OUTPUT_DIR, f"training_performance_{NUM_TRAIN_EPISODES}.png"
    )
    print_training_charts(
        returns_history, losses_history, INITIAL_BALANCE, training_plot_save_path
    )

    # --- Run Full Evaluation ---
    # Pass the *trained* agent, path to eval data, initial balance, and output dir
    run_evaluation(agent, EVAL_DATA_PATH, INITIAL_BALANCE, OUTPUT_DIR)

    print("Script finished.")