In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from pymoo.algorithms.soo.nonconvex.pso import PSO
from pymoo.core.population import Population
from pymoo.core.problem import Problem
import gymnasium as gym

In [12]:
import random

seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True, warn_only=True)

In [13]:
from collections import deque


class ReplayBuffer:
    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, action_log_prob, reward, next_state):
        self.buffer.append((state, action, action_log_prob, reward, next_state))

    def sample(self, batch_size: int, random_state=None):
        rng = np.random.default_rng(random_state)
        indices = rng.choice(len(self.buffer), batch_size, replace=False)
        states, actions, action_log_probs, rewards, next_states = zip(
            *[self.buffer[i] for i in indices]
        )
        return (
            np.asarray(states),
            np.asarray(actions),
            action_log_probs,
            np.asarray(rewards),
            np.asarray(next_states),
        )

    def __len__(self):
        return len(self.buffer)

In [14]:
class Actor(nn.Module):
    def __init__(self, state_dim: int, action_dim: int):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, 64)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(64, 32)
        self.l3 = nn.Linear(32, action_dim)

    def forward(self, state):
        x = self.l1(state)
        x = self.relu(x)
        x = self.l2(x)
        x = self.relu(x)
        action_probs = F.softmax(self.l3(x), dim=-1)
        return action_probs


class Critic(nn.Module):
    def __init__(self, state_dim: int):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(state_dim, 128)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(128, 64)
        self.l3 = nn.Linear(64, 1)

    def forward(self, state):
        x = self.l1(state)
        x = self.relu(x)
        x = self.l2(x)
        x = self.relu(x)
        value = self.l3(x)
        return value


class ActorCritic:
    def __init__(self, state_dim: int, action_dim: int, device: torch.device):
        self.actor = Actor(state_dim, action_dim).to(device)
        self.critic = Critic(state_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=0.001)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=0.001)
        self.device = device

    def forward(self, state):
        return self.actor(state), self.critic(state)

    def select_action(self, state):
        state = torch.FloatTensor(state)
        action_probs = self.actor(state)
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob

    def update(
        self,
        replay_buffer: ReplayBuffer,
        gamma: float = 0.99,
        batch_size: int = 128,
        random_state=None,
    ):
        states, _, action_log_probs, rewards, next_states = replay_buffer.sample(
            batch_size=batch_size, random_state=random_state
        )

        states = torch.FloatTensor(states, device=self.device)
        action_log_probs = torch.stack(action_log_probs).to(self.device).unsqueeze(1)
        # action_log_probs = torch.FloatTensor(action_log_probs, device=self.device)
        rewards = torch.FloatTensor(rewards, device=self.device).unsqueeze(1)
        next_states = torch.FloatTensor(next_states, device=self.device)

        # --- Critic Update ---
        current_state_values = self.critic(states)
        # Get the next state value prediction from the critic
        # For the target, we don't want gradients to flow through next_state_value
        # if the episode is done, next_state_value is 0 (no future rewards)
        next_state_values = self.critic(next_states)
        # Calculate the TD target (R + gamma * V(s'))
        target_values = rewards + gamma * next_state_values

        # Critic loss: Mean Squared Error between predicted value and TD target
        critic_loss = F.mse_loss(current_state_values, target_values)
        # Perform backpropagation for the critic
        self.critic_optimizer.zero_grad()  # Clear previous gradients
        critic_loss.backward()  # Compute gradients
        self.critic_optimizer.step()  # Update critic network parameters

        # --- Actor Update ---
        # Calculate the Advantage (TD Error): TD_target - V(s)
        # It's crucial to detach target_values and current_state_values here to prevent
        # gradients from flowing back into the Critic network during the Actor's update.
        # The Actor's update should only depend on the value estimate, not train the critic.
        advantages = (target_values - current_state_values).detach()

        # Actor loss: Negative log-probability weighted by the advantage
        # We want to maximize expected reward, so we minimize negative expected reward.
        actor_loss = -action_log_probs * advantages  # .unsqueeze(0)
        # Perform backpropagation for the actor
        self.actor_optimizer.zero_grad()  # Clear previous gradients
        # see the following thread for the reason behind .mean().backward()
        # https://discuss.pytorch.org/t/loss-backward-raises-error-grad-can-be-implicitly-created-only-for-scalar-outputs/12152
        print(action_log_probs.size())
        print(advantages.size())
        mean_loss = actor_loss.sum(dim=0) / batch_size
        mean_loss.backward()  # Compute gradients # FIXME
        self.actor_optimizer.step()  # Update actor network parameters

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("LunarLander-v3")
action_dim = env.action_space.n
observation_dim = env.observation_space.shape[0]

action_dim, observation_dim

(np.int64(4), 8)

In [16]:
agents = [
    ActorCritic(state_dim=observation_dim, action_dim=action_dim, device=device)
    for _ in range(25)
]

In [17]:
def run_episode(agent: ActorCritic, env: gym.Env, replay_buffer: ReplayBuffer):
    """evaluate fitness of actor's policy on an environment"""
    total_reward = 0.0
    steps = 0
    observation, _ = env.reset(seed=42)  # initial state
    terminated = False
    truncated = False
    while not (terminated or truncated):
        action, log_prob = agent.select_action(observation)
        new_observation, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        replay_buffer.push(observation, action, log_prob, reward, new_observation)
        observation = new_observation
        steps += 1
    return total_reward, steps


class TheProblem(Problem):
    def __init__(
        self,
        env,
        agents: list[ActorCritic],
        replay_buffer: ReplayBuffer,
        device: torch.device,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.env = env
        self.agents = agents
        self.replay_buffer = replay_buffer
        self.device = device

    def _evaluate(self, X, out, *args, **kwargs):
        """X is the set of solutions, not just one solution"""
        F = []
        steps_list = []
        for agent, x in zip(self.agents, X):
            vector_to_parameters(
                torch.FloatTensor(x).to(self.device), agent.actor.parameters()
            )
            total_reward, steps = run_episode(
                agent=agent, env=self.env, replay_buffer=self.replay_buffer
            )
            F.append(-total_reward)
            steps_list.append(steps)
        out["F"] = F
        out["steps"] = steps_list

In [18]:
vector_encoded_actors = np.asarray(
    [
        parameters_to_vector(agent.actor.parameters()).detach().cpu().numpy()
        for agent in agents
    ]
)

pso = PSO(
    pop_size=len(vector_encoded_actors),
    sampling=Population.new(X=vector_encoded_actors),
    adaptive=False,
    pertube_best=False,
    seed=0,
)

In [19]:
replay_buffer = ReplayBuffer(capacity=10_000)

problem = TheProblem(
    env=env,
    agents=agents,
    replay_buffer=replay_buffer,
    n_var=vector_encoded_actors[0].shape[0],
    xl=-2.0,
    xu=2.0,
    device=device,
)

In [None]:
pso.setup(problem, verbose=True)
MAX_TIMESTEPS = 10_000 * 100
EXPLORATION_RATIO = 0.25
L = np.zeros(pso.pop_size)
B = np.zeros(pso.pop_size)
t, e, b = (0, 0, 0)
pop = pso.ask()
pop = pso.evaluator.eval(problem, pop, algorithm=pso)
while t < MAX_TIMESTEPS:
    stage = 1 if (t < MAX_TIMESTEPS * EXPLORATION_RATIO) else 2
    pso.tell(pop)
    index_list = ([e] * pso.pop_size) + list(range(pso.pop_size))
    for i in index_list:
        fitness, steps = pso.evaluator.eval(problem, pop[i], algorithm=pso).get(
            "F", "steps"
        )
        fitness = fitness[0]
        t += steps
        L[i] += steps
        if fitness > B[i]:
            B[i] = fitness
            L[i] = 0
            e = i
        if B[i] > np.max(B):
            if stage == 1:
                b = i
            elif stage == 2:
                pop[b] = pop[i].copy()
        if stage == 1:  # (optimize Pi by via RL)
            actor_params = pop[i].X
            agent = agents[i]
            vector_to_parameters(
                torch.FloatTensor(actor_params).to(device), agent.actor.parameters()
            )
            agent.update(replay_buffer=replay_buffer, batch_size=128)
            pop[i].set("X", parameters_to_vector(agent.actor.parameters()))
        elif stage == 2:  # (optimize Pb via RL)
            actor_params = pop[b].X
            agent = agents[b]
            vector_to_parameters(
                torch.FloatTensor(actor_params).to(device), agent.actor.parameters()
            )
            agent.update(replay_buffer=replay_buffer, batch_size=128)
            pop[b].set("X", parameters_to_vector(agent.actor.parameters()))
    if L[e] > L[b] or stage == 2:
        e = b
    # ask-and-tell is inverted because `ask()` does the PSO update
    # which happens at the end of the while-loop according to Two-Stage ERL (TERL)
    pop = pso.ask()

result = pso.result()

n_gen  |  n_eval  |    f     |    S    |    w    |    c1    |    c2    |     f_avg     |     f_min    
     1 |       25 |        - |       - |  0.9000 |  2.00000 |  2.00000 |  2.074611E+02 |  5.898743E+01
torch.Size([128, 1])
torch.Size([128, 1])
torch.Size([128, 1])
torch.Size([128, 1])


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
result.F

array([58.98743138])

In [None]:
# from pymoo.optimize import minimize
#
# res = minimize(problem,
#                 pso,
#                 ('n_gen', 200),
#                 seed=1,
#                 verbose=True)

In [None]:
# for epoch in range(10):
#     res = minimize(problem,
#                 pso,
#                 ('n_gen', 10),
#                 seed=1,
#                 verbose=True)
#     updated_vector_encoded_actors = res.pop.get("X")
#     for agent, vector_encoded_actor in zip(agents, updated_vector_encoded_actors):
#         vector_to_parameters(torch.FloatTensor(vector_encoded_actor).to(device), agent.actor.parameters())
#         agent.update(replay_buffer=replay_buffer, batch_size=128)
