From ab5e6753496cd55b57b42237795a4b390e7d8751 Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Thu, 9 Feb 2023 19:55:52 +0800
Subject: [PATCH] Compare PPO with spinning up (#579)

* [wip] compare PPO

* PPO matching

* Revert unnecessary changes

* Minor

* Minor
---
 .../training/algorithms/base/ac_ppo_base.py | 45 +++++++++++---
 maro/rl/training/replay_memory.py           | 62 ++++---------------
 tests/rl/gym_wrapper/common.py              |  2 +-
 tests/rl/gym_wrapper/env_sampler.py         |  3 +-
 tests/rl/tasks/ac/__init__.py               |  4 +-
 tests/rl/tasks/ppo/__init__.py              |  2 +
 tests/rl/tasks/ppo/config.yml               |  4 +-
 7 files changed, 56 insertions(+), 66 deletions(-)

diff --git a/maro/rl/training/algorithms/base/ac_ppo_base.py b/maro/rl/training/algorithms/base/ac_ppo_base.py
index 3227437be..544ea93b1 100644
--- a/maro/rl/training/algorithms/base/ac_ppo_base.py
+++ b/maro/rl/training/algorithms/base/ac_ppo_base.py
@@ -202,6 +202,8 @@ def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
         # Preprocess advantages
         states = ndarray_to_tensor(batch.states, device=self._device)  # s
         actions = ndarray_to_tensor(batch.actions, device=self._device)  # a
+        terminals = ndarray_to_tensor(batch.terminals, device=self._device)
+        next_states = ndarray_to_tensor(batch.next_states, device=self._device)
         if self._is_discrete_action:
             actions = actions.long()
 
@@ -209,11 +211,34 @@
         self._v_critic_net.eval()
         self._policy.eval()
         values = self._v_critic_net.v_values(states).detach().cpu().numpy()
-        values = np.concatenate([values, np.zeros(1)])
-        rewards = np.concatenate([batch.rewards, np.zeros(1)])
-        deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1]  # r + gamma * v(s') - v(s)
-        batch.returns = discount_cumsum(rewards, self._reward_discount)[:-1]
-        batch.advantages = discount_cumsum(deltas, self._reward_discount * self._lam)
+
+        batch.returns = np.zeros(batch.size, dtype=np.float32)
+        batch.advantages = np.zeros(batch.size, dtype=np.float32)
+        i = 0
+        while i < batch.size:
+            j = i
+            while j < batch.size - 1 and not terminals[j]:
+                j += 1
+            last_val = (
+                0.0
+                if terminals[j]
+                else self._v_critic_net.v_values(
+                    next_states[j].unsqueeze(dim=0),
+                )
+                .detach()
+                .cpu()
+                .numpy()
+                .item()
+            )
+
+            cur_values = np.append(values[i : j + 1], last_val)
+            cur_rewards = np.append(batch.rewards[i : j + 1], last_val)
+            # delta = r + gamma * v(s') - v(s)
+            cur_deltas = cur_rewards[:-1] + self._reward_discount * cur_values[1:] - cur_values[:-1]
+            batch.returns[i : j + 1] = discount_cumsum(cur_rewards, self._reward_discount)[:-1]
+            batch.advantages[i : j + 1] = discount_cumsum(cur_deltas, self._reward_discount * self._lam)
+
+            i = j + 1
 
         if self._clip_ratio is not None:
             batch.old_logps = self._policy.get_states_actions_logps(states, actions).detach().cpu().numpy()
@@ -291,21 +316,23 @@ def train_step(self) -> None:
         assert isinstance(self._ops, ACBasedOps)
 
         batch = self._get_batch()
-        for _ in range(self._params.grad_iters):
-            self._ops.update_critic(batch)
 
         for _ in range(self._params.grad_iters):
             early_stop = self._ops.update_actor(batch)
             if early_stop:
                 break
 
+        for _ in range(self._params.grad_iters):
+            self._ops.update_critic(batch)
+
     async def train_step_as_task(self) -> None:
         assert isinstance(self._ops, RemoteOps)
 
         batch = self._get_batch()
-        for _ in range(self._params.grad_iters):
-            self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
 
         for _ in range(self._params.grad_iters):
             if self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch)):  # early stop
                 break
+
+        for _ in range(self._params.grad_iters):
+            self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
diff --git a/maro/rl/training/replay_memory.py b/maro/rl/training/replay_memory.py
index 3e4f573e0..e93847572 100644
--- a/maro/rl/training/replay_memory.py
+++ b/maro/rl/training/replay_memory.py
@@ -35,29 +35,18 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
         raise NotImplementedError
 
     @abstractmethod
-    def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
+    def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
         """Generate a list of indexes that can be used to retrieve items from the replay memory.
 
         Args:
             batch_size (int, default=None): The required batch size. If it is None, all indexes where an experience
                 item is present are returned.
-            forbid_last (bool, default=False): Whether the latest element is allowed to be sampled.
-                If this is true, the last index will always be excluded from the result.
 
         Returns:
             indexes (np.ndarray): The list of indexes.
         """
         raise NotImplementedError
 
-    @abstractmethod
-    def get_last_index(self) -> int:
-        """Get the index of the latest element in the memory.
-
-        Returns:
-            index (int): The index of the latest element in the memory.
-        """
-        raise NotImplementedError
-
 
 class RandomIndexScheduler(AbsIndexScheduler):
     """Index scheduler that returns random indexes when sampling.
@@ -93,14 +82,11 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
         self._size = min(self._size + batch_size, self._capacity)
         return indexes
 
-    def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
+    def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
         assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}"
         assert self._size > 0, "Cannot sample from an empty memory."
         return np.random.choice(self._size, size=batch_size, replace=True)
 
-    def get_last_index(self) -> int:
-        raise NotImplementedError
-
 
 class FIFOIndexScheduler(AbsIndexScheduler):
     """First-in-first-out index scheduler.
@@ -135,19 +121,15 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
             self._head = (self._head + overwrite) % self._capacity
             return self.get_put_indexes(batch_size)
 
-    def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
-        tmp = self._tail if not forbid_last else (self._tail - 1) % self._capacity
+    def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
         indexes = (
-            np.arange(self._head, tmp)
-            if tmp > self._head
-            else np.concatenate([np.arange(self._head, self._capacity), np.arange(tmp)])
+            np.arange(self._head, self._tail)
+            if self._tail > self._head
+            else np.concatenate([np.arange(self._head, self._capacity), np.arange(self._tail)])
         )
-        self._head = tmp
+        self._head = self._tail
         return indexes
 
-    def get_last_index(self) -> int:
-        return (self._tail - 1) % self._capacity
-
 
 class AbsReplayMemory(object, metaclass=ABCMeta):
     """Abstract replay memory class with basic interfaces.
@@ -176,9 +158,9 @@ def _get_put_indexes(self, batch_size: int) -> np.ndarray:
         """Please refer to the doc string in AbsIndexScheduler."""
         return self._idx_scheduler.get_put_indexes(batch_size)
 
-    def _get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
+    def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
         """Please refer to the doc string in AbsIndexScheduler."""
-        return self._idx_scheduler.get_sample_indexes(batch_size, forbid_last)
+        return self._idx_scheduler.get_sample_indexes(batch_size)
 
 
 class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
@@ -273,7 +255,7 @@ def sample(self, batch_size: int = None) -> TransitionBatch:
         Returns:
             batch (TransitionBatch): The sampled batch.
         """
-        indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
+        indexes = self._get_sample_indexes(batch_size)
         return self.sample_by_indexes(indexes)
 
     def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
@@ -298,10 +280,6 @@ def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
             old_logps=self._old_logps[indexes],
         )
 
-    @abstractmethod
-    def _get_forbid_last(self) -> bool:
-        raise NotImplementedError
-
 
 class RandomReplayMemory(ReplayMemory):
     def __init__(
@@ -318,15 +296,11 @@ def __init__(
             RandomIndexScheduler(capacity, random_overwrite),
         )
         self._random_overwrite = random_overwrite
-        self._scheduler = RandomIndexScheduler(capacity, random_overwrite)
 
     @property
     def random_overwrite(self) -> bool:
        return self._random_overwrite
 
-    def _get_forbid_last(self) -> bool:
-        return False
-
 
 class FIFOReplayMemory(ReplayMemory):
     def __init__(
@@ -342,9 +316,6 @@ def __init__(
             FIFOIndexScheduler(capacity),
         )
 
-    def _get_forbid_last(self) -> bool:
-        return not self._terminals[self._idx_scheduler.get_last_index()]
-
 
 class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
     """In-memory experience storage facility for a multi trainer.
@@ -446,7 +417,7 @@ def sample(self, batch_size: int = None) -> MultiTransitionBatch:
         Returns:
             batch (MultiTransitionBatch): The sampled batch.
         """
-        indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
+        indexes = self._get_sample_indexes(batch_size)
         return self.sample_by_indexes(indexes)
 
     def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
@@ -470,10 +441,6 @@ def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
             next_agent_states=[state[indexes] for state in self._next_agent_states],
         )
 
-    @abstractmethod
-    def _get_forbid_last(self) -> bool:
-        raise NotImplementedError
-
 
 class RandomMultiReplayMemory(MultiReplayMemory):
     def __init__(
@@ -492,15 +459,11 @@ def __init__(
             agent_states_dims,
         )
         self._random_overwrite = random_overwrite
-        self._scheduler = RandomIndexScheduler(capacity, random_overwrite)
 
     @property
     def random_overwrite(self) -> bool:
         return self._random_overwrite
 
-    def _get_forbid_last(self) -> bool:
-        return False
-
 
 class FIFOMultiReplayMemory(MultiReplayMemory):
     def __init__(
@@ -517,6 +480,3 @@ def __init__(
             FIFOIndexScheduler(capacity),
             agent_states_dims,
         )
-
-    def _get_forbid_last(self) -> bool:
-        return not self._terminals[self._idx_scheduler.get_last_index()]
diff --git a/tests/rl/gym_wrapper/common.py b/tests/rl/gym_wrapper/common.py
index 0ea43bbb1..6792a8ca8 100644
--- a/tests/rl/gym_wrapper/common.py
+++ b/tests/rl/gym_wrapper/common.py
@@ -10,7 +10,7 @@
 env_conf = {
     "topology": "Walker2d-v4",  # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4
     "start_tick": 0,
-    "durations": 5000,
+    "durations": 1000,
     "options": {
         "random_seed": None,
     },
diff --git a/tests/rl/gym_wrapper/env_sampler.py b/tests/rl/gym_wrapper/env_sampler.py
index 591a1b29a..73ac48351 100644
--- a/tests/rl/gym_wrapper/env_sampler.py
+++ b/tests/rl/gym_wrapper/env_sampler.py
@@ -75,6 +75,7 @@ def post_collect(self, info_list: list, ep: int) -> None:
         self.metrics.update(cur)
         # clear validation metrics
         self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
+        self._sample_rewards.clear()
 
     def post_evaluate(self, info_list: list, ep: int) -> None:
         cur = {
@@ -83,5 +84,5 @@ def post_evaluate(self, info_list: list, ep: int) -> None:
             "val/avg_reward": np.mean([r for _, r in self._eval_rewards]),
             "val/avg_n_steps": np.mean([n for n, _ in self._eval_rewards]),
         }
-        self._eval_rewards.clear()
         self.metrics.update(cur)
+        self._eval_rewards.clear()
diff --git a/tests/rl/tasks/ac/__init__.py b/tests/rl/tasks/ac/__init__.py
index 31d4f8b1c..24cc961fc 100644
--- a/tests/rl/tasks/ac/__init__.py
+++ b/tests/rl/tasks/ac/__init__.py
@@ -26,11 +26,11 @@
 from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
 
 actor_net_conf = {
-    "hidden_dims": [64, 64],
+    "hidden_dims": [64, 32],
     "activation": torch.nn.Tanh,
 }
 critic_net_conf = {
-    "hidden_dims": [64, 64],
+    "hidden_dims": [64, 32],
     "activation": torch.nn.Tanh,
 }
 actor_learning_rate = 3e-4
diff --git a/tests/rl/tasks/ppo/__init__.py b/tests/rl/tasks/ppo/__init__.py
index 534623d1a..01207562b 100644
--- a/tests/rl/tasks/ppo/__init__.py
+++ b/tests/rl/tasks/ppo/__init__.py
@@ -25,6 +25,8 @@ def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer:
     return PPOTrainer(
         name=name,
         reward_discount=0.99,
+        replay_memory_capacity=4000,
+        batch_size=4000,
         params=PPOParams(
             get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
             grad_iters=80,
diff --git a/tests/rl/tasks/ppo/config.yml b/tests/rl/tasks/ppo/config.yml
index 130fdf10d..312d1274c 100644
--- a/tests/rl/tasks/ppo/config.yml
+++ b/tests/rl/tasks/ppo/config.yml
@@ -6,10 +6,10 @@ scenario_path: "tests/rl/tasks/ppo"
 log_path: "tests/rl/log/ppo"
 main:
   num_episodes: 1000
-  num_steps: null
+  num_steps: 4000
  eval_schedule: 5
   num_eval_episodes: 10
-  min_n_sample: 5000
+  min_n_sample: 1
 logging:
   stdout: INFO
   file: DEBUG
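
Note on the change to preprocess_batch: the patch replaces the single-sequence GAE/return calculation with a per-trajectory version. The batch is split on `terminals`, and a segment that is truncated rather than terminal is bootstrapped with the critic's value of its last `next_state`, so the `forbid_last` machinery in the replay memory is no longer needed. The sketch below is a minimal, self-contained illustration of that logic under stated assumptions, not the MARO API itself: `gae_per_segment` and `bootstrap_values` are hypothetical names, `bootstrap_values[j]` stands in for the `self._v_critic_net.v_values(next_states[j])` call in the patch, and `discount_cumsum` is re-implemented with `scipy.signal.lfilter` (the same trick the Spinning Up helper of that name uses); the default `gamma`/`lam` values are placeholders, not MARO's configuration.

    import numpy as np
    import scipy.signal


    def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray:
        # Discounted cumulative sum of x, computed back-to-front with a linear filter.
        return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]


    def gae_per_segment(rewards, values, terminals, bootstrap_values, gamma=0.99, lam=0.97):
        # rewards, values, terminals, bootstrap_values: 1-D arrays covering one or more
        # trajectories laid out back to back, as in a MARO TransitionBatch.
        size = len(rewards)
        returns = np.zeros(size, dtype=np.float32)
        advantages = np.zeros(size, dtype=np.float32)
        i = 0
        while i < size:
            # Advance j to the last index of the current trajectory segment.
            j = i
            while j < size - 1 and not terminals[j]:
                j += 1
            # Terminal segment: no bootstrap. Truncated segment: bootstrap with V(next_state).
            last_val = 0.0 if terminals[j] else float(bootstrap_values[j])
            seg_values = np.append(values[i : j + 1], last_val)
            seg_rewards = np.append(rewards[i : j + 1], last_val)
            # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            deltas = seg_rewards[:-1] + gamma * seg_values[1:] - seg_values[:-1]
            returns[i : j + 1] = discount_cumsum(seg_rewards, gamma)[:-1]
            advantages[i : j + 1] = discount_cumsum(deltas, gamma * lam)
            i = j + 1
        return returns, advantages

The remaining changes line up with the same comparison: train_step now runs all actor updates before the critic updates, and the test configuration (num_steps: 4000, batch_size=4000, 1000-step episodes) matches Spinning Up's usual MuJoCo setup of 4000-step epochs with a 1000-step episode cap.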