Compare PPO with spinning up (#579)
* [wip] compare PPO

* PPO matching

* Revert unnecessary changes

* Minor

* Minor
lihuoran committed Feb 9, 2023
1 parent b05c849 commit ab5e675
Showing 7 changed files with 56 additions and 66 deletions.
45 changes: 36 additions & 9 deletions maro/rl/training/algorithms/base/ac_ppo_base.py
@@ -202,18 +202,43 @@ def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
# Preprocess advantages
states = ndarray_to_tensor(batch.states, device=self._device) # s
actions = ndarray_to_tensor(batch.actions, device=self._device) # a
+ terminals = ndarray_to_tensor(batch.terminals, device=self._device)
+ next_states = ndarray_to_tensor(batch.next_states, device=self._device)
if self._is_discrete_action:
actions = actions.long()

with torch.no_grad():
self._v_critic_net.eval()
self._policy.eval()
values = self._v_critic_net.v_values(states).detach().cpu().numpy()
- values = np.concatenate([values, np.zeros(1)])
- rewards = np.concatenate([batch.rewards, np.zeros(1)])
- deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s)
- batch.returns = discount_cumsum(rewards, self._reward_discount)[:-1]
- batch.advantages = discount_cumsum(deltas, self._reward_discount * self._lam)

+ batch.returns = np.zeros(batch.size, dtype=np.float32)
+ batch.advantages = np.zeros(batch.size, dtype=np.float32)
+ i = 0
+ while i < batch.size:
+     j = i
+     while j < batch.size - 1 and not terminals[j]:
+         j += 1
+     last_val = (
+         0.0
+         if terminals[j]
+         else self._v_critic_net.v_values(
+             next_states[j].unsqueeze(dim=0),
+         )
+         .detach()
+         .cpu()
+         .numpy()
+         .item()
+     )
+
+     cur_values = np.append(values[i : j + 1], last_val)
+     cur_rewards = np.append(batch.rewards[i : j + 1], last_val)
+     # delta = r + gamma * v(s') - v(s)
+     cur_deltas = cur_rewards[:-1] + self._reward_discount * cur_values[1:] - cur_values[:-1]
+     batch.returns[i : j + 1] = discount_cumsum(cur_rewards, self._reward_discount)[:-1]
+     batch.advantages[i : j + 1] = discount_cumsum(cur_deltas, self._reward_discount * self._lam)
+
+     i = j + 1

if self._clip_ratio is not None:
batch.old_logps = self._policy.get_states_actions_logps(states, actions).detach().cpu().numpy()
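The rewritten `preprocess_batch` above computes returns-to-go and GAE advantages per trajectory segment instead of over the whole batch at once: a segment ends at a terminal flag or at the batch boundary, and non-terminal segments are bootstrapped with the critic's value of the last next state, which is how Spinning Up's PPO buffer handles truncated episodes. Below is a minimal standalone sketch of the same arithmetic on a toy segment; the `discount_cumsum` helper follows the Spinning Up implementation and all numbers are made up for illustration, so nothing here is taken verbatim from MARO.

```python
import numpy as np
import scipy.signal


def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray:
    # Right-to-left discounted cumulative sum, as in Spinning Up's core utilities.
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]


def gae_segment(rewards, values, last_val, gamma=0.99, lam=0.97):
    """Returns-to-go and GAE advantages for a single trajectory segment."""
    vals = np.append(values, last_val)   # V(s_0), ..., V(s_T), plus the bootstrap value
    rews = np.append(rewards, last_val)  # the bootstrap value also enters the returns-to-go
    deltas = rews[:-1] + gamma * vals[1:] - vals[:-1]  # r + gamma * V(s') - V(s)
    returns = discount_cumsum(rews, gamma)[:-1]
    advantages = discount_cumsum(deltas, gamma * lam)
    return returns, advantages


# A 3-step segment that is cut off before reaching a terminal state,
# so the critic's estimate of the last next state (0.2 here) is used as last_val.
rets, advs = gae_segment(
    rewards=np.array([1.0, 1.0, 1.0]),
    values=np.array([0.5, 0.4, 0.3]),
    last_val=0.2,
)
print(rets)  # rewards-to-go per step
print(advs)  # GAE(lambda) advantages per step
```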
@@ -291,21 +316,23 @@ def train_step(self) -> None:
assert isinstance(self._ops, ACBasedOps)

batch = self._get_batch()
- for _ in range(self._params.grad_iters):
- self._ops.update_critic(batch)

for _ in range(self._params.grad_iters):
early_stop = self._ops.update_actor(batch)
if early_stop:
break

+ for _ in range(self._params.grad_iters):
+ self._ops.update_critic(batch)

async def train_step_as_task(self) -> None:
assert isinstance(self._ops, RemoteOps)

batch = self._get_batch()
- for _ in range(self._params.grad_iters):
- self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))

for _ in range(self._params.grad_iters):
if self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch)): # early stop
break

+ for _ in range(self._params.grad_iters):
+ self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
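The `train_step` and `train_step_as_task` hunks above swap the update order so that all actor (policy) gradient steps, including the early stop, run before the critic (value-function) steps, matching the order used by Spinning Up's PPO. A compressed sketch of that loop structure, reusing the `update_actor` / `update_critic` names from the diff; the wrapper function and the early-stop comment are illustrative, not code from this repo.

```python
def ppo_update(ops, batch, grad_iters: int) -> None:
    # 1) Actor updates first. update_actor is assumed to return True when an
    #    early-stop condition is hit (Spinning Up stops when the approximate
    #    KL divergence between the new and old policy grows too large).
    for _ in range(grad_iters):
        if ops.update_actor(batch):
            break

    # 2) Critic (value-function) updates afterwards, on the same on-policy batch.
    for _ in range(grad_iters):
        ops.update_critic(batch)
```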
62 changes: 11 additions & 51 deletions maro/rl/training/replay_memory.py
@@ -35,29 +35,18 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
raise NotImplementedError

@abstractmethod
- def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
+ def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
"""Generate a list of indexes that can be used to retrieve items from the replay memory.
Args:
batch_size (int, default=None): The required batch size. If it is None, all indexes where an experience
item is present are returned.
- forbid_last (bool, default=False): Whether the latest element is allowed to be sampled.
- If this is true, the last index will always be excluded from the result.
Returns:
indexes (np.ndarray): The list of indexes.
"""
raise NotImplementedError

- @abstractmethod
- def get_last_index(self) -> int:
- """Get the index of the latest element in the memory.
- Returns:
- index (int): The index of the latest element in the memory.
- """
- raise NotImplementedError


class RandomIndexScheduler(AbsIndexScheduler):
"""Index scheduler that returns random indexes when sampling.
Expand Down Expand Up @@ -93,14 +82,11 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
self._size = min(self._size + batch_size, self._capacity)
return indexes

- def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
+ def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}"
assert self._size > 0, "Cannot sample from an empty memory."
return np.random.choice(self._size, size=batch_size, replace=True)

- def get_last_index(self) -> int:
- raise NotImplementedError


class FIFOIndexScheduler(AbsIndexScheduler):
"""First-in-first-out index scheduler.
Expand Down Expand Up @@ -135,19 +121,15 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
self._head = (self._head + overwrite) % self._capacity
return self.get_put_indexes(batch_size)

- def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
- tmp = self._tail if not forbid_last else (self._tail - 1) % self._capacity
+ def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
indexes = (
- np.arange(self._head, tmp)
- if tmp > self._head
- else np.concatenate([np.arange(self._head, self._capacity), np.arange(tmp)])
+ np.arange(self._head, self._tail)
+ if self._tail > self._head
+ else np.concatenate([np.arange(self._head, self._capacity), np.arange(self._tail)])
)
- self._head = tmp
+ self._head = self._tail
return indexes

- def get_last_index(self) -> int:
- return (self._tail - 1) % self._capacity


class AbsReplayMemory(object, metaclass=ABCMeta):
"""Abstract replay memory class with basic interfaces.
Expand Down Expand Up @@ -176,9 +158,9 @@ def _get_put_indexes(self, batch_size: int) -> np.ndarray:
"""Please refer to the doc string in AbsIndexScheduler."""
return self._idx_scheduler.get_put_indexes(batch_size)

- def _get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
+ def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
"""Please refer to the doc string in AbsIndexScheduler."""
- return self._idx_scheduler.get_sample_indexes(batch_size, forbid_last)
+ return self._idx_scheduler.get_sample_indexes(batch_size)


class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
Expand Down Expand Up @@ -273,7 +255,7 @@ def sample(self, batch_size: int = None) -> TransitionBatch:
Returns:
batch (TransitionBatch): The sampled batch.
"""
- indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
+ indexes = self._get_sample_indexes(batch_size)
return self.sample_by_indexes(indexes)

def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
@@ -298,10 +280,6 @@ def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
old_logps=self._old_logps[indexes],
)

- @abstractmethod
- def _get_forbid_last(self) -> bool:
- raise NotImplementedError


class RandomReplayMemory(ReplayMemory):
def __init__(
@@ -318,15 +296,11 @@ def __init__(
RandomIndexScheduler(capacity, random_overwrite),
)
self._random_overwrite = random_overwrite
- self._scheduler = RandomIndexScheduler(capacity, random_overwrite)

@property
def random_overwrite(self) -> bool:
return self._random_overwrite

- def _get_forbid_last(self) -> bool:
- return False


class FIFOReplayMemory(ReplayMemory):
def __init__(
@@ -342,9 +316,6 @@ def __init__(
FIFOIndexScheduler(capacity),
)

- def _get_forbid_last(self) -> bool:
- return not self._terminals[self._idx_scheduler.get_last_index()]


class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
"""In-memory experience storage facility for a multi trainer.
@@ -446,7 +417,7 @@ def sample(self, batch_size: int = None) -> MultiTransitionBatch:
Returns:
batch (MultiTransitionBatch): The sampled batch.
"""
- indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
+ indexes = self._get_sample_indexes(batch_size)
return self.sample_by_indexes(indexes)

def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
@@ -470,10 +441,6 @@ def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
next_agent_states=[state[indexes] for state in self._next_agent_states],
)

- @abstractmethod
- def _get_forbid_last(self) -> bool:
- raise NotImplementedError


class RandomMultiReplayMemory(MultiReplayMemory):
def __init__(
@@ -492,15 +459,11 @@ def __init__(
agent_states_dims,
)
self._random_overwrite = random_overwrite
- self._scheduler = RandomIndexScheduler(capacity, random_overwrite)

@property
def random_overwrite(self) -> bool:
return self._random_overwrite

- def _get_forbid_last(self) -> bool:
- return False


class FIFOMultiReplayMemory(MultiReplayMemory):
def __init__(
@@ -517,6 +480,3 @@ def __init__(
FIFOIndexScheduler(capacity),
agent_states_dims,
)

- def _get_forbid_last(self) -> bool:
- return not self._terminals[self._idx_scheduler.get_last_index()]
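With `forbid_last` and `get_last_index` gone, the FIFO memory no longer holds back the freshest non-terminal transition for later; it simply hands out every index between `head` and `tail` and advances `head`, while the PPO preprocessing shown earlier takes care of bootstrapping truncated segments. Below is a standalone sketch of the wrap-around arithmetic that `FIFOIndexScheduler.get_sample_indexes` now performs; the free function and its argument names are illustrative, not MARO code.

```python
import numpy as np


def fifo_sample_indexes(head: int, tail: int, capacity: int) -> np.ndarray:
    """Indexes of all items currently stored in a circular buffer, oldest first."""
    if tail > head:
        return np.arange(head, tail)
    # The occupied region wraps around the end of the buffer.
    return np.concatenate([np.arange(head, capacity), np.arange(tail)])


print(fifo_sample_indexes(head=2, tail=6, capacity=8))  # [2 3 4 5]
print(fifo_sample_indexes(head=6, tail=3, capacity=8))  # [6 7 0 1 2]
```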
2 changes: 1 addition & 1 deletion tests/rl/gym_wrapper/common.py
@@ -10,7 +10,7 @@
env_conf = {
"topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4
"start_tick": 0,
- "durations": 5000,
+ "durations": 1000,
"options": {
"random_seed": None,
},
3 changes: 2 additions & 1 deletion tests/rl/gym_wrapper/env_sampler.py
@@ -75,6 +75,7 @@ def post_collect(self, info_list: list, ep: int) -> None:
self.metrics.update(cur)
# clear validation metrics
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
+ self._sample_rewards.clear()

def post_evaluate(self, info_list: list, ep: int) -> None:
cur = {
@@ -83,5 +84,5 @@ def post_evaluate(self, info_list: list, ep: int) -> None:
"val/avg_reward": np.mean([r for _, r in self._eval_rewards]),
"val/avg_n_steps": np.mean([n for n, _ in self._eval_rewards]),
}
- self._eval_rewards.clear()
self.metrics.update(cur)
+ self._eval_rewards.clear()
4 changes: 2 additions & 2 deletions tests/rl/tasks/ac/__init__.py
@@ -26,11 +26,11 @@
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler

actor_net_conf = {
- "hidden_dims": [64, 64],
+ "hidden_dims": [64, 32],
"activation": torch.nn.Tanh,
}
critic_net_conf = {
- "hidden_dims": [64, 64],
+ "hidden_dims": [64, 32],
"activation": torch.nn.Tanh,
}
actor_learning_rate = 3e-4
2 changes: 2 additions & 0 deletions tests/rl/tasks/ppo/__init__.py
@@ -25,6 +25,8 @@ def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer:
return PPOTrainer(
name=name,
reward_discount=0.99,
+ replay_memory_capacity=4000,
+ batch_size=4000,
params=PPOParams(
get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
grad_iters=80,
4 changes: 2 additions & 2 deletions tests/rl/tasks/ppo/config.yml
@@ -6,10 +6,10 @@ scenario_path: "tests/rl/tasks/ppo"
log_path: "tests/rl/log/ppo"
main:
num_episodes: 1000
- num_steps: null
+ num_steps: 4000
eval_schedule: 5
num_eval_episodes: 10
- min_n_sample: 5000
+ min_n_sample: 1
logging:
stdout: INFO
file: DEBUG
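Taken together, the PPO trainer and config changes make each training round strictly on-policy, in line with Spinning Up's default of 4000 steps per epoch: every round collects `num_steps: 4000` fresh transitions, the replay memory holds exactly that many (`replay_memory_capacity=4000`), and a single update consumes the whole rollout (`batch_size=4000`), while `min_n_sample: 1` presumably leaves `num_steps` alone to govern how much is collected. A tiny, purely illustrative sanity check of that sizing relationship (the helper name is hypothetical):

```python
def check_on_policy_sizing(num_steps: int, replay_memory_capacity: int, batch_size: int) -> None:
    # For strictly on-policy PPO, one rollout should exactly fill the memory
    # and one update should consume it all, so no stale transitions are reused.
    assert num_steps == replay_memory_capacity == batch_size


check_on_policy_sizing(num_steps=4000, replay_memory_capacity=4000, batch_size=4000)
```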
