In [3]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [4]:
# Pickled trajectory dataset of logged transitions, read with pd.read_pickle at training time.
trajectory_file_path = "rl_trajectories.pkl"

In [5]:
# This is our Deep Learning model. It's a simple Multi-Layer Perceptron (MLP)
# that takes a state and outputs a Q-value for each possible action.

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, state):
        return self.network(state)


In [2]:
# This is our Conservative Q-Learning (CQL) agent

class CQLAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, cql_alpha=5.0):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.cql_alpha = cql_alpha

        # Main Q-Network
        self.q_network = QNetwork(state_dim, action_dim)

        # Target Q-Network for stability
        self.target_q_network = QNetwork(state_dim, action_dim)
        self.target_q_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

    def train(self, batch):
        states = torch.FloatTensor(np.array(batch['state'].tolist()))
        actions = torch.LongTensor(batch['action'].tolist()).unsqueeze(1)
        rewards = torch.FloatTensor(batch['reward'].tolist()).unsqueeze(1)
        next_states = torch.FloatTensor(np.array(batch['next_state'].tolist()))
        terminals = torch.FloatTensor(batch['terminal'].tolist()).unsqueeze(1)

        # Standard Q-Learning Loss (Bellman Error)
        # Get Q-values for the actions that were actually taken in the dataset
        q_values = self.q_network(states).gather(1, actions)

        # Get the value of the next state from the target network
        with torch.no_grad():
            next_q_values = self.target_q_network(next_states).max(1)[0].unsqueeze(1)
            target_q_values = rewards + (1 - terminals) * self.gamma * next_q_values

        q_loss = nn.MSELoss()(q_values, target_q_values)

        # CQL Conservative Loss
        # It penalizes Q-values for actions that were NOT in the dataset
        # forcing the model to be "conservative."

        all_q_values = self.q_network(states)

        logsumexp_q = torch.logsumexp(all_q_values, dim=1, keepdim=True)

        dataset_q_values = q_values

        # The CQL loss encourages the Q-values of actions in the dataset to be high,
        # while pushing down the Q-values of other actions.
        cql_loss = (logsumexp_q - dataset_q_values).mean()

        # The final loss is a combination of the standard Q-loss and the CQL penalty
        total_loss = q_loss + self.cql_alpha * cql_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return total_loss.item()

    def update_target_network(self, tau=0.005):
        for target_param, param in zip(self.target_q_network.parameters(), self.q_network.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)



In [6]:
# The Training Loop
print("Starting Model Training")

# Hyperparameters
BATCH_SIZE = 64
EPOCHS = 1000

dataset = pd.read_pickle(trajectory_file_path)

state_dim = len(dataset['state'].iloc[0])
action_dim = dataset['action'].max() + 1

print(f"Detected State Dimension: {state_dim}")
print(f"Detected Action Dimension: {action_dim}")

agent = CQLAgent(state_dim=state_dim, action_dim=action_dim)

for epoch in range(EPOCHS):
    batch = dataset.sample(n=BATCH_SIZE)

    loss = agent.train(batch)

    agent.update_target_network()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {loss:.4f}")

print("\n--- Training Complete ---")

Starting Model Training
Detected State Dimension: 10
Detected Action Dimension: 5
Epoch 100/1000, Loss: 36.0201
Epoch 200/1000, Loss: 28.9910
Epoch 300/1000, Loss: 28.0143
Epoch 400/1000, Loss: 30.0036
Epoch 500/1000, Loss: 28.8302
Epoch 600/1000, Loss: 26.0545
Epoch 700/1000, Loss: 26.4429
Epoch 800/1000, Loss: 22.3608
Epoch 900/1000, Loss: 28.2073
Epoch 1000/1000, Loss: 28.5363

--- Training Complete ---


In [7]:
# Persist only the online Q-network's learned weights (its state_dict); the
# target network and optimizer state are not written. Reloading requires
# constructing a QNetwork with the same state_dim/action_dim first.
model_save_path = "cql_fincoach_model.pth"
q_network_weights = agent.q_network.state_dict()
torch.save(q_network_weights, model_save_path)

print(f"\nModel saved successfully to '{model_save_path}'")


Model saved successfully to 'cql_fincoach_model.pth'
