Compare PPO with spinning up #579

Merged: 5 commits merged on Feb 9, 2023. Changes shown from 4 commits.
45 changes: 36 additions & 9 deletions maro/rl/training/algorithms/base/ac_ppo_base.py
@@ -202,18 +202,43 @@ def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
# Preprocess advantages
states = ndarray_to_tensor(batch.states, device=self._device) # s
actions = ndarray_to_tensor(batch.actions, device=self._device) # a
terminals = ndarray_to_tensor(batch.terminals, device=self._device)
next_states = ndarray_to_tensor(batch.next_states, device=self._device)
if self._is_discrete_action:
actions = actions.long()

with torch.no_grad():
self._v_critic_net.eval()
self._policy.eval()
values = self._v_critic_net.v_values(states).detach().cpu().numpy()
values = np.concatenate([values, np.zeros(1)])
rewards = np.concatenate([batch.rewards, np.zeros(1)])
deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s)
batch.returns = discount_cumsum(rewards, self._reward_discount)[:-1]
batch.advantages = discount_cumsum(deltas, self._reward_discount * self._lam)

batch.returns = np.zeros(batch.size, dtype=np.float32)
batch.advantages = np.zeros(batch.size, dtype=np.float32)
i = 0
while i < batch.size:
j = i
while j < batch.size - 1 and not terminals[j]:
j += 1
last_val = (
0.0
if terminals[j]
else self._v_critic_net.v_values(
next_states[j].unsqueeze(dim=0),
)
.detach()
.cpu()
.numpy()
.item()
)

cur_values = np.append(values[i : j + 1], last_val)
cur_rewards = np.append(batch.rewards[i : j + 1], last_val)
# delta = r + gamma * v(s') - v(s)
cur_deltas = cur_rewards[:-1] + self._reward_discount * cur_values[1:] - cur_values[:-1]
batch.returns[i : j + 1] = discount_cumsum(cur_rewards, self._reward_discount)[:-1]
batch.advantages[i : j + 1] = discount_cumsum(cur_deltas, self._reward_discount * self._lam)

i = j + 1

if self._clip_ratio is not None:
batch.old_logps = self._policy.get_states_actions_logps(states, actions).detach().cpu().numpy()
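
For reference, a minimal standalone sketch of the per-trajectory return and advantage computation introduced above. The discount_cumsum helper here is an assumption written to match the Spinning Up convention that the diff follows; gamma and lam stand in for self._reward_discount and self._lam.

import numpy as np

def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray:
    # y[t] = sum_{k >= t} discount^(k - t) * x[k]
    out = np.zeros(len(x), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out

def gae_for_segment(rewards, values, last_val, gamma, lam):
    # `values` are V(s_t) for one trajectory segment; `last_val` is the bootstrap
    # value: 0.0 if the segment ends in a terminal state, V(s_{T+1}) otherwise.
    vals = np.append(values, last_val)
    rews = np.append(rewards, last_val)
    deltas = rews[:-1] + gamma * vals[1:] - vals[:-1]  # r + gamma * V(s') - V(s)
    returns = discount_cumsum(rews, gamma)[:-1]
    advantages = discount_cumsum(deltas, gamma * lam)
    return returns, advantages
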
@@ -291,21 +316,23 @@ def train_step(self) -> None:
assert isinstance(self._ops, ACBasedOps)

batch = self._get_batch()
for _ in range(self._params.grad_iters):
self._ops.update_critic(batch)

for _ in range(self._params.grad_iters):
early_stop = self._ops.update_actor(batch)
if early_stop:
break

for _ in range(self._params.grad_iters):
self._ops.update_critic(batch)

async def train_step_as_task(self) -> None:
assert isinstance(self._ops, RemoteOps)

batch = self._get_batch()
for _ in range(self._params.grad_iters):
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))

for _ in range(self._params.grad_iters):
if self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch)): # early stop
break

for _ in range(self._params.grad_iters):
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
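
The reordering above runs all actor updates (with early stopping) before the critic updates, matching the update order of Spinning Up's PPO. A schematic sketch of that order, using hypothetical update_policy / update_value callables rather than MARO's actual ops API:

def ppo_update(batch, grad_iters, update_policy, update_value):
    # Policy (actor) updates first, with optional KL-based early stopping.
    for _ in range(grad_iters):
        early_stop = update_policy(batch)  # clipped surrogate objective step
        if early_stop:
            break
    # Value-function (critic) updates afterwards, on the same batch.
    for _ in range(grad_iters):
        update_value(batch)  # e.g. regression towards the precomputed returns
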
62 changes: 11 additions & 51 deletions maro/rl/training/replay_memory.py
@@ -35,29 +35,18 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
raise NotImplementedError

@abstractmethod
def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
"""Generate a list of indexes that can be used to retrieve items from the replay memory.

Args:
batch_size (int, default=None): The required batch size. If it is None, all indexes where an experience
item is present are returned.
forbid_last (bool, default=False): Whether the latest element is allowed to be sampled.
If this is true, the last index will always be excluded from the result.

Returns:
indexes (np.ndarray): The list of indexes.
"""
raise NotImplementedError

@abstractmethod
def get_last_index(self) -> int:
"""Get the index of the latest element in the memory.

Returns:
index (int): The index of the latest element in the memory.
"""
raise NotImplementedError


class RandomIndexScheduler(AbsIndexScheduler):
"""Index scheduler that returns random indexes when sampling.
@@ -93,14 +82,11 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
self._size = min(self._size + batch_size, self._capacity)
return indexes

def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}"
assert self._size > 0, "Cannot sample from an empty memory."
return np.random.choice(self._size, size=batch_size, replace=True)

def get_last_index(self) -> int:
raise NotImplementedError


class FIFOIndexScheduler(AbsIndexScheduler):
"""First-in-first-out index scheduler.
@@ -135,19 +121,15 @@ def get_put_indexes(self, batch_size: int) -> np.ndarray:
self._head = (self._head + overwrite) % self._capacity
return self.get_put_indexes(batch_size)

def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
tmp = self._tail if not forbid_last else (self._tail - 1) % self._capacity
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
indexes = (
np.arange(self._head, tmp)
if tmp > self._head
else np.concatenate([np.arange(self._head, self._capacity), np.arange(tmp)])
np.arange(self._head, self._tail)
if self._tail > self._head
else np.concatenate([np.arange(self._head, self._capacity), np.arange(self._tail)])
)
self._head = tmp
self._head = self._tail
return indexes

def get_last_index(self) -> int:
return (self._tail - 1) % self._capacity
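
A small worked example of the head/tail wrap-around handled by FIFOIndexScheduler.get_sample_indexes; the capacity and pointer values below are made up for illustration:

import numpy as np

capacity, head, tail = 8, 6, 3  # the write position has wrapped past the end of the buffer
indexes = (
    np.arange(head, tail)
    if tail > head
    else np.concatenate([np.arange(head, capacity), np.arange(tail)])
)
# indexes == array([6, 7, 0, 1, 2]): every stored item from head to tail, in insertion order
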


class AbsReplayMemory(object, metaclass=ABCMeta):
"""Abstract replay memory class with basic interfaces.
@@ -176,9 +158,9 @@ def _get_put_indexes(self, batch_size: int) -> np.ndarray:
"""Please refer to the doc string in AbsIndexScheduler."""
return self._idx_scheduler.get_put_indexes(batch_size)

def _get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
"""Please refer to the doc string in AbsIndexScheduler."""
return self._idx_scheduler.get_sample_indexes(batch_size, forbid_last)
return self._idx_scheduler.get_sample_indexes(batch_size)


class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
@@ -273,7 +255,7 @@ def sample(self, batch_size: int = None) -> TransitionBatch:
Returns:
batch (TransitionBatch): The sampled batch.
"""
indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
indexes = self._get_sample_indexes(batch_size)
return self.sample_by_indexes(indexes)

def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
@@ -298,10 +280,6 @@ def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
old_logps=self._old_logps[indexes],
)

@abstractmethod
def _get_forbid_last(self) -> bool:
raise NotImplementedError


class RandomReplayMemory(ReplayMemory):
def __init__(
@@ -318,15 +296,11 @@ def __init__(
RandomIndexScheduler(capacity, random_overwrite),
)
self._random_overwrite = random_overwrite
self._scheduler = RandomIndexScheduler(capacity, random_overwrite)

@property
def random_overwrite(self) -> bool:
return self._random_overwrite

def _get_forbid_last(self) -> bool:
return False


class FIFOReplayMemory(ReplayMemory):
def __init__(
@@ -342,9 +316,6 @@ def __init__(
FIFOIndexScheduler(capacity),
)

def _get_forbid_last(self) -> bool:
return not self._terminals[self._idx_scheduler.get_last_index()]


class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
"""In-memory experience storage facility for a multi trainer.
@@ -446,7 +417,7 @@ def sample(self, batch_size: int = None) -> MultiTransitionBatch:
Returns:
batch (MultiTransitionBatch): The sampled batch.
"""
indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
indexes = self._get_sample_indexes(batch_size)
return self.sample_by_indexes(indexes)

def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
@@ -470,10 +441,6 @@ def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
next_agent_states=[state[indexes] for state in self._next_agent_states],
)

@abstractmethod
def _get_forbid_last(self) -> bool:
raise NotImplementedError


class RandomMultiReplayMemory(MultiReplayMemory):
def __init__(
@@ -492,15 +459,11 @@ def __init__(
agent_states_dims,
)
self._random_overwrite = random_overwrite
self._scheduler = RandomIndexScheduler(capacity, random_overwrite)

@property
def random_overwrite(self) -> bool:
return self._random_overwrite

def _get_forbid_last(self) -> bool:
return False


class FIFOMultiReplayMemory(MultiReplayMemory):
def __init__(
@@ -517,6 +480,3 @@ def __init__(
FIFOIndexScheduler(capacity),
agent_states_dims,
)

def _get_forbid_last(self) -> bool:
return not self._terminals[self._idx_scheduler.get_last_index()]
2 changes: 1 addition & 1 deletion tests/rl/gym_wrapper/common.py
@@ -10,7 +10,7 @@
env_conf = {
"topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4
"start_tick": 0,
"durations": 5000,
"durations": 1000,
"options": {
"random_seed": None,
},
2 changes: 2 additions & 0 deletions tests/rl/gym_wrapper/env_sampler.py
@@ -75,6 +75,7 @@ def post_collect(self, info_list: list, ep: int) -> None:
self.metrics.update(cur)
# clear validation metrics
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
self._sample_rewards.clear()
Collaborator comment: bug fix to record this-episode-only statistics?

def post_evaluate(self, info_list: list, ep: int) -> None:
cur = {
@@ -85,3 +86,4 @@ def post_evaluate(self, info_list: list, ep: int) -> None:
}
self._eval_rewards.clear()
self.metrics.update(cur)
self._eval_rewards.clear()
4 changes: 2 additions & 2 deletions tests/rl/tasks/ac/__init__.py
@@ -26,11 +26,11 @@
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler

actor_net_conf = {
"hidden_dims": [64, 64],
"hidden_dims": [64, 32],
"activation": torch.nn.Tanh,
}
critic_net_conf = {
"hidden_dims": [64, 64],
"hidden_dims": [64, 32],
"activation": torch.nn.Tanh,
}
actor_learning_rate = 3e-4
2 changes: 2 additions & 0 deletions tests/rl/tasks/ppo/__init__.py
@@ -25,6 +25,8 @@ def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer:
return PPOTrainer(
name=name,
reward_discount=0.99,
replay_memory_capacity=4000,
batch_size=4000,
params=PPOParams(
get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
grad_iters=80,
4 changes: 2 additions & 2 deletions tests/rl/tasks/ppo/config.yml
@@ -6,10 +6,10 @@ scenario_path: "tests/rl/tasks/ppo"
log_path: "tests/rl/log/ppo"
main:
num_episodes: 1000
num_steps: null
num_steps: 4000
eval_schedule: 5
num_eval_episodes: 10
min_n_sample: 5000
min_n_sample: 1
logging:
stdout: INFO
file: DEBUG
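
For comparison, the Spinning Up PPO defaults that these settings appear to line up with. The values below are recalled from the Spinning Up documentation and are not part of this PR; worth double-checking against the current release.

spinning_up_ppo_defaults = {
    "steps_per_epoch": 4000,  # ~ num_steps, replay_memory_capacity and batch_size = 4000
    "gamma": 0.99,            # ~ reward_discount = 0.99
    "lam": 0.97,
    "clip_ratio": 0.2,
    "pi_lr": 3e-4,            # ~ actor_learning_rate = 3e-4
    "vf_lr": 1e-3,
    "train_pi_iters": 80,     # ~ grad_iters = 80
    "train_v_iters": 80,
    "max_ep_len": 1000,       # ~ gym env "durations" = 1000
    "target_kl": 0.01,
}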