In [12]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from pymoo.algorithms.soo.nonconvex.pso import PSO
from pymoo.core.population import Population
from pymoo.core.problem import Problem
import gymnasium as gym
from pymoo.optimize import minimize

In [13]:
from collections import deque

class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state) transitions."""

    def __init__(self, capacity: int):
        # deque(maxlen=...) discards the oldest entry once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state):
        """Append a single transition to the buffer."""
        transition = (state, action, reward, next_state)
        self.buffer.append(transition)

    def sample(self, batch_size: int, random_state=None):
        """Draw `batch_size` distinct transitions uniformly at random.

        `random_state` is forwarded to `np.random.default_rng`. Returns four
        arrays (states, actions, rewards, next_states) whose rows are aligned.
        """
        generator = np.random.default_rng(random_state)
        picked = generator.choice(len(self.buffer), batch_size, replace=False)
        batch = [self.buffer[idx] for idx in picked]
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states = zip(*batch)
        return (
            np.asarray(states),
            np.asarray(actions),
            np.asarray(rewards),
            np.asarray(next_states),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [14]:
class Actor(nn.Module):
  """Policy network: maps a state to a probability distribution over actions."""

  def __init__(self, state_dim: int, action_dim: int):
    super().__init__()
    # Layer widths: state_dim -> 64 -> 32 -> action_dim.
    self.l1 = nn.Linear(state_dim, 64)
    self.relu = nn.ReLU()
    self.l2 = nn.Linear(64, 32)
    self.l3 = nn.Linear(32, action_dim)

  def forward(self, state):
    """Return action probabilities (softmax over the last dimension)."""
    hidden = self.relu(self.l1(state))
    hidden = self.relu(self.l2(hidden))
    logits = self.l3(hidden)
    return F.softmax(logits, dim=-1)

class Critic(nn.Module):
  """State-value network: maps a state to a single scalar estimate."""

  def __init__(self, state_dim: int):
    super().__init__()
    # Layer widths: state_dim -> 128 -> 64 -> 1.
    self.l1 = nn.Linear(state_dim, 128)
    self.relu = nn.ReLU()
    self.l2 = nn.Linear(128, 64)
    self.l3 = nn.Linear(64, 1)

  def forward(self, state):
    """Return the value estimate; the last dimension of the output has size 1."""
    hidden = self.relu(self.l1(state))
    hidden = self.relu(self.l2(hidden))
    return self.l3(hidden)

class ActorCritic:
  """Pairs an actor (policy) and a critic (state-value) network on one device.

  Only the critic is given an optimizer here; the actor's parameters are not
  trained by this class and are expected to be set from outside (for example
  with `vector_to_parameters`).
  """

  def __init__(self, state_dim: int, action_dim: int, device: torch.device):
    """Build both networks on `device` and an Adam optimizer for the critic."""
    self.actor = Actor(state_dim, action_dim).to(device)
    self.critic = Critic(state_dim).to(device)
    self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=0.001)
    self.device = device

  def forward(self, state):
    """Return (action probabilities, state value) for `state`."""
    return self.actor(state), self.critic(state)

  def select_action(self, state):
    """Sample an action from the actor's categorical policy.

    Returns:
      (action, log_prob): the sampled action as a Python int and the
      log-probability tensor of that action.
    """
    # Create the input on the same device as the networks; a CPU tensor fed
    # to a CUDA-resident actor raises a device-mismatch error.
    state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
    action_probs = self.actor(state)
    dist = torch.distributions.Categorical(action_probs)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    return action.item(), log_prob

  def update_critic(self, replay_buffer: ReplayBuffer, gamma: float = 0.99, batch_size: int = 128, random_state=None):
    """Run one TD(0) regression step on the critic from a sampled batch.

    Args:
      replay_buffer: source of (state, action, reward, next_state) batches.
      gamma: discount factor applied to the next-state value.
      batch_size: number of transitions to sample.
      random_state: forwarded to `replay_buffer.sample`.
    """
    # Actions are not needed to fit a state-value function.
    states, _, rewards, next_states = replay_buffer.sample(batch_size=batch_size, random_state=random_state)

    states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
    rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
    next_states = torch.as_tensor(next_states, dtype=torch.float32, device=self.device)

    current_state_values = self.critic(states)
    # The bootstrapped target is treated as a constant: without no_grad the
    # loss would also push V(next_state) toward V(state) - reward.
    # NOTE(review): sampled transitions carry no terminal flag, so terminal
    # next states are bootstrapped as well.
    with torch.no_grad():
      next_state_values = self.critic(next_states)
      target_values = rewards + gamma * next_state_values

    loss = F.mse_loss(current_state_values, target_values)
    self.critic_optimizer.zero_grad()
    loss.backward()
    self.critic_optimizer.step()


In [15]:
# Prefer CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("LunarLander-v3")
# Number of actions (`action_space.n`) and length of the observation vector;
# these size the actor/critic networks built in a later cell.
action_dim = env.action_space.n
observation_dim = env.observation_space.shape[0]

# Bare last expression so the notebook displays both values.
action_dim, observation_dim

(np.int64(4), 8)

In [16]:
# One independently initialised ActorCritic per swarm member; the length of
# this list (25) is later used as the PSO population size.
agents = [
    ActorCritic(state_dim=observation_dim, action_dim=action_dim, device=device)
    for _ in range(25)
]

In [17]:
def run_episode(agent: ActorCritic, env: gym.Env, replay_buffer: ReplayBuffer, seed=42):
    """Roll out one episode with the agent's actor and return its fitness.

    Every transition (observation, action, reward, next observation) is pushed
    to `replay_buffer` as a side effect.

    Args:
        agent: supplies `select_action(observation)`.
        env: Gymnasium environment; it is reset at the start of the call.
        replay_buffer: receives one entry per environment step.
        seed: forwarded unchanged to `env.reset`. The default of 42 keeps the
            previous behaviour of starting every episode from the same seeded
            reset; pass a different value (or None) to vary the start.

    Returns:
        (total_reward, steps): the summed reward and the number of steps taken.
    """
    total_reward = 0.0
    steps = 0
    observation, _ = env.reset(seed=seed)  # initial state
    terminated = False
    truncated = False
    while not (terminated or truncated):
        # The log-probability is discarded here, so skip autograd bookkeeping
        # for the policy forward pass.
        with torch.no_grad():
            action, _ = agent.select_action(observation)
        new_observation, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        replay_buffer.push(observation, action, reward, new_observation)
        observation = new_observation
        steps += 1
    return total_reward, steps

class TheProblem(Problem):
    """pymoo problem whose objective is the negated episode return of an actor.

    Each row of the decision matrix is a flattened actor parameter vector. It
    is loaded into one of `agents` and scored by running one episode in `env`.
    """

    def __init__(self, env, agents: list[ActorCritic], replay_buffer: ReplayBuffer, device: torch.device, *args, **kwargs):
        """Store the evaluation resources; remaining arguments go to `Problem`."""
        super().__init__(*args, **kwargs)
        self.env = env
        self.agents = agents
        self.replay_buffer = replay_buffer
        self.device = device

    def _evaluate(self, X, out, *args, **kwargs):
        """Evaluate every row of X (a set of solutions, not a single one).

        Writes `out["F"]` (negated total reward per row, so that minimising
        maximises reward) and `out["steps"]` (episode length per row).

        Raises:
            ValueError: if there are rows to evaluate but no agents.
        """
        n_agents = len(self.agents)
        if len(X) > 0 and n_agents == 0:
            raise ValueError("TheProblem needs at least one agent to evaluate solutions")

        # Named `fitness` rather than `F` to avoid shadowing the notebook-level
        # `torch.nn.functional as F` alias.
        fitness = []
        steps_list = []
        for i, x in enumerate(X):
            # Reuse agents cyclically so every row gets a result even when
            # more rows than agents are submitted; pairing by zip would drop
            # the surplus rows and return too few objective values.
            agent = self.agents[i % n_agents]
            vector_to_parameters(torch.FloatTensor(x).to(self.device), agent.actor.parameters())
            total_reward, steps = run_episode(agent=agent, env=self.env, replay_buffer=self.replay_buffer)
            fitness.append(-total_reward)
            steps_list.append(steps)
        out["F"] = fitness
        out["steps"] = steps_list

In [18]:
# Flatten each actor's parameters into one row; the resulting 2-D array is
# used below as the initial PSO population.
vector_encoded_actors = np.asarray([
  parameters_to_vector(agent.actor.parameters()).detach().cpu().numpy()
  for agent in agents
])

# Swarm size equals the number of agents and the swarm starts from the actors'
# current weights instead of a random sampling.
# `adaptive` and `pertube_best` (pymoo's spelling) are both switched off.
pso = PSO(
  pop_size=len(vector_encoded_actors),
  sampling=Population.new(X=vector_encoded_actors),
  adaptive=False,
  pertube_best=False,
  seed=0,
)

In [19]:
# Shared buffer that every episode run during evaluation writes into.
replay_buffer = ReplayBuffer(capacity=10_000)

# n_var is the length of one flattened actor vector; xl/xu bound every
# variable to [-1, 1].
problem = TheProblem(env=env, agents=agents, replay_buffer=replay_buffer, n_var=vector_encoded_actors[0].shape[0], xl=-1.0, xu=1.0, device=device)

In [27]:
# res = minimize(problem,
#                 pso,
#                 ('n_gen', 200),
#                 seed=1,
#                 verbose=True)

In [21]:
def evaluate_population(X, agents, env, replay_buffer, device):
    """Score each parameter vector in X with one episode of its paired agent.

    Row i of X is loaded into `agents[i].actor` before that agent's episode;
    pairing stops at the shorter of X and agents. Returns two arrays: the
    negated total rewards and the episode step counts.
    """
    neg_returns = []
    episode_lengths = []
    for params, agent in zip(X, agents):
        flat_params = torch.FloatTensor(params).to(device)
        vector_to_parameters(flat_params, agent.actor.parameters())
        episode_reward, n_steps = run_episode(agent=agent, env=env, replay_buffer=replay_buffer)
        neg_returns.append(-episode_reward)
        episode_lengths.append(n_steps)
    return np.asarray(neg_returns), np.asarray(episode_lengths)

In [26]:
# Manual ask/tell loop: the stopping rule is a budget of environment steps
# rather than a number of generations.
pso.setup(problem)
MAX_TIMESTEPS = 10_000
t = 0  # environment steps consumed so far, summed over all evaluated individuals
while t < MAX_TIMESTEPS:
    pop = pso.ask()
    pop = pso.evaluator.eval(problem, pop, algorithm=pso)
    # "steps" is written per individual by TheProblem._evaluate.
    t += np.sum(pop.get("steps"))
    pso.tell(pop)

result = pso.result()

In [23]:
# for epoch in range(10):
#     res = minimize(problem,
#                 pso,
#                 ('n_gen', 10),
#                 seed=1,
#                 verbose=True)
#     updated_vector_encoded_actors = res.pop.get("X")
#     for agent, vector_encoded_actor in zip(agents, updated_vector_encoded_actors):
#         vector_to_parameters(torch.FloatTensor(vector_encoded_actor).to(device), agent.actor.parameters())
#         agent.update_critic(replay_buffer=replay_buffer, batch_size=128)
    