Merge pull request #81 from xiaomengy/agent_update
Improve PPO implementation to avoid potential timeout for RPC
xiaomengy committed Sep 23, 2022
2 parents f80dbe3 + 6bb64fd, commit 65ff522
Showing 4 changed files with 86 additions and 16 deletions.
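The core change: instead of pushing an entire trajectory to the replay buffer (which may sit behind an RPC boundary in distributed runs) with a single extend/async_extend call, the agent now sends it in chunks of at most local_batch_size items (default 1024), so no single call carries an unbounded payload that could hit an RPC timeout. The deterministic_policy flag is also wrapped in a tensor once in __init__ instead of on every act call. Below is a minimal, self-contained sketch of the chunked-send pattern; send_in_chunks and the duck-typed buffer argument are illustrative names, not rlmeta API.

from typing import Any, List, Sequence


def send_in_chunks(buffer: Any, items: Sequence[Any],
                   local_batch_size: int = 1024) -> int:
    """Push items to buffer.extend() in chunks of at most local_batch_size.

    Returns the number of calls issued; bounding each payload avoids the
    timeout that a single oversized RPC could trigger.
    """
    num_calls = 0
    batch: List[Any] = []
    for item in items:
        batch.append(item)
        if len(batch) >= local_batch_size:
            buffer.extend(batch)
            batch = []
            num_calls += 1
    if batch:  # flush the remainder
        buffer.extend(batch)
        num_calls += 1
    return num_calls

For example, a 10,000-step trajectory with local_batch_size=1024 is delivered in 10 bounded calls rather than one call of 10,000 items.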
examples/atari/ppo/atari_ppo.py (2 changes: 1 addition & 1 deletion)
@@ -72,8 +72,8 @@ def main(cfg):
        replay_buffer=a_rb,
        controller=a_ctrl,
        optimizer=optimizer,
        batch_size=cfg.batch_size,
        learning_starts=cfg.get("learning_starts", None),
        batch_size=cfg.batch_size,
        push_every_n_steps=cfg.push_every_n_steps)
    t_agent_fac = AgentFactory(PPOAgent, t_model, replay_buffer=t_rb)
    e_agent_fac = AgentFactory(PPOAgent, e_model, deterministic_policy=False)
examples/atari/ppo/atari_ppo_rnd.py (2 changes: 1 addition & 1 deletion)
@@ -72,8 +72,8 @@ def main(cfg):
        replay_buffer=a_rb,
        controller=a_ctrl,
        optimizer=optimizer,
        batch_size=cfg.batch_size,
        learning_starts=cfg.get("learning_starts", None),
        batch_size=cfg.batch_size,
        push_every_n_steps=cfg.push_every_n_steps)
    t_agent_fac = AgentFactory(PPORNDAgent, t_model, replay_buffer=t_rb)
    e_agent_fac = AgentFactory(PPORNDAgent, e_model, deterministic_policy=False)
rlmeta/agents/ppo/ppo_agent.py (40 changes: 33 additions & 7 deletions)
@@ -47,12 +47,13 @@ def __init__(self,
                 normalize_advantage: bool = True,
                 learning_starts: Optional[int] = None,
                 batch_size: int = 512,
                 local_batch_size: int = 1024,
                 max_grad_norm: float = 1.0,
                 push_every_n_steps: int = 1) -> None:
        super().__init__()

        self._model = model
        self._deterministic_policy = deterministic_policy
        self._deterministic_policy = torch.tensor([deterministic_policy])

        self._replay_buffer = replay_buffer
        self._controller = controller
@@ -70,6 +71,7 @@ def __init__(self,

        self._learning_starts = learning_starts
        self._batch_size = batch_size
        self._local_batch_size = local_batch_size
        self._max_grad_norm = max_grad_norm
        self._push_every_n_steps = push_every_n_steps

@@ -82,14 +84,13 @@ def reset(self) -> None:

    def act(self, timestep: TimeStep) -> Action:
        obs = timestep.observation
        action, logpi, v = self._model.act(
            obs, torch.tensor([self._deterministic_policy]))
        action, logpi, v = self._model.act(obs, self._deterministic_policy)
        return Action(action, info={"logpi": logpi, "v": v})

    async def async_act(self, timestep: TimeStep) -> Action:
        obs = timestep.observation
        action, logpi, v = await self._model.async_act(
            obs, torch.tensor([self._deterministic_policy]))
            obs, self._deterministic_policy)
        return Action(action, info={"logpi": logpi, "v": v})

    async def async_observe_init(self, timestep: TimeStep) -> None:
@@ -116,15 +117,15 @@ def update(self) -> None:
            return
        if self._replay_buffer is not None:
            replay = self._make_replay()
            self._replay_buffer.extend(replay)
            self._send_replay(replay)
        self._trajectory.clear()

    async def async_update(self) -> None:
        if not self._trajectory or not self._trajectory[-1]["done"]:
            return
        if self._replay_buffer is not None:
            replay = self._make_replay()
            await self._replay_buffer.async_extend(replay)
            replay = await self._async_make_replay()
            await self._async_send_replay(replay)
        self._trajectory.clear()

    def train(self, num_steps: int) -> Optional[StatsDict]:
@@ -187,6 +188,31 @@ def _make_replay(self) -> List[NestedTensor]:
            cur.pop("done")
        return self._trajectory

    async def _async_make_replay(self) -> List[NestedTensor]:
        return self._make_replay()

    def _send_replay(self, replay: List[NestedTensor]) -> None:
        batch = []
        while replay:
            batch.append(replay.pop())
            if len(batch) >= self._local_batch_size:
                self._replay_buffer.extend(batch)
                batch.clear()
        if batch:
            self._replay_buffer.extend(batch)
            batch.clear()

    async def _async_send_replay(self, replay: List[NestedTensor]) -> None:
        batch = []
        while replay:
            batch.append(replay.pop())
            if len(batch) >= self._local_batch_size:
                await self._replay_buffer.async_extend(batch)
                batch.clear()
        if batch:
            await self._replay_buffer.async_extend(batch)
            batch.clear()

    def _calculate_gae_and_return(
            self,
            values: Sequence[Union[float, torch.Tensor]],
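A quick way to see what the new _send_replay logic does: it drains the replay list with pop() and flushes to the replay buffer every time the local batch reaches local_batch_size, then flushes whatever is left. The stub buffer below is a hypothetical stand-in used only to count call sizes; it is not part of rlmeta.

from typing import Any, List


class StubReplayBuffer:
    """Records the size of each extend() call it receives."""

    def __init__(self) -> None:
        self.call_sizes: List[int] = []

    def extend(self, batch: List[Any]) -> None:
        self.call_sizes.append(len(batch))


def send_replay(replay_buffer: StubReplayBuffer, replay: List[Any],
                local_batch_size: int) -> None:
    # Mirrors the chunking in PPOAgent._send_replay: pop items off the
    # trajectory list and flush once local_batch_size items accumulate.
    batch: List[Any] = []
    while replay:
        batch.append(replay.pop())
        if len(batch) >= local_batch_size:
            replay_buffer.extend(batch)
            batch.clear()
    if batch:
        replay_buffer.extend(batch)
        batch.clear()


buf = StubReplayBuffer()
send_replay(buf, list(range(2500)), local_batch_size=1024)
print(buf.call_sizes)  # [1024, 1024, 452]: three bounded calls instead of one of 2500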
rlmeta/agents/ppo/ppo_rnd_agent.py (58 changes: 51 additions & 7 deletions)
@@ -40,6 +40,7 @@ def __init__(
            normalize_advantage: bool = True,
            learning_starts: Optional[int] = None,
            batch_size: int = 128,
            local_batch_size: int = 1024,
            max_grad_norm: float = 1.0,
            push_every_n_steps: int = 1,
            collate_fn: Optional[Callable[[Sequence[NestedTensor]],
@@ -49,7 +50,8 @@ def __init__(
                         optimizer, gamma, gae_lambda, ratio_clipping_eps,
                         value_clipping_eps, vf_loss_coeff, entropy_coeff,
                         rescale_reward, normalize_advantage, learning_starts,
                         batch_size, max_grad_norm, push_every_n_steps)
                         batch_size, local_batch_size, max_grad_norm,
                         push_every_n_steps)

        self._intrinsic_advantage_coeff = intrinsic_advantage_coeff

@@ -64,7 +66,7 @@ def __init__(
    def act(self, timestep: TimeStep) -> Action:
        obs = timestep.observation
        action, logpi, ext_v, int_v = self._model.act(
            obs, torch.tensor([self._deterministic_policy]))
            obs, self._deterministic_policy)
        return Action(action,
                      info={
                          "logpi": logpi,
@@ -75,7 +77,7 @@ def act(self, timestep: TimeStep) -> Action:
    async def async_act(self, timestep: TimeStep) -> Action:
        obs = timestep.observation
        action, logpi, ext_v, int_v = await self._model.async_act(
            obs, torch.tensor([self._deterministic_policy]))
            obs, self._deterministic_policy)
        return Action(action,
                      info={
                          "logpi": logpi,
@@ -100,16 +102,28 @@ def _make_replay(self) -> List[NestedTensor]:
        next_obs = [
            self._trajectory[i]["obs"] for i in range(1, len(self._trajectory))
        ]
        next_obs = self._collate_fn(next_obs)
        int_rewards = self._model.intrinsic_reward(next_obs)
        int_rewards[-1] = 0.0
        int_rewards = self._compute_intrinsic_rewards(next_obs,
                                                      done_at_last=True)
        return self._make_replay_impl(int_rewards)

    async def _async_make_replay(self) -> List[NestedTensor]:
        next_obs = [
            self._trajectory[i]["obs"] for i in range(1, len(self._trajectory))
        ]
        int_rewards = await self._async_compute_intrinsic_rewards(
            next_obs, done_at_last=True)
        return self._make_replay_impl(int_rewards)

    def _make_replay_impl(
            self,
            intrinsic_rewards: Sequence[NestedTensor]) -> List[NestedTensor]:
        self._trajectory.pop()

        ext_adv, ext_ret = self._calculate_gae_and_return(
            [x["ext_v"] for x in self._trajectory],
            [x["reward"] for x in self._trajectory], self._ext_reward_rescaler)
        int_adv, int_ret = self._calculate_gae_and_return(
            [x["int_v"] for x in self._trajectory], torch.unbind(int_rewards),
            [x["int_v"] for x in self._trajectory], intrinsic_rewards,
            self._int_reward_rescaler)

        for cur, ext_a, ext_r, int_a, int_r in zip(self._trajectory, ext_adv,
@@ -123,6 +137,36 @@ def _make_replay(self) -> List[NestedTensor]:

        return self._trajectory

    def _compute_intrinsic_rewards(
            self,
            obs: Sequence[NestedTensor],
            done_at_last: bool = True) -> List[torch.Tensor]:
        int_rewards = []
        n = len(obs)
        obs = nested_utils.collate_nested(self._collate_fn, obs)
        for i in range(0, n, self._local_batch_size):
            batch = obs[i:i + self._local_batch_size]
            cur_rewards = self._model.intrinsic_reward(batch)
            int_rewards.extend(torch.unbind(cur_rewards))
        if done_at_last:
            int_rewards[-1].zero_()
        return int_rewards

    async def _async_compute_intrinsic_rewards(
            self,
            obs: Sequence[NestedTensor],
            done_at_last: bool = True) -> List[torch.Tensor]:
        int_rewards = []
        n = len(obs)
        obs = nested_utils.collate_nested(self._collate_fn, obs)
        for i in range(0, n, self._local_batch_size):
            batch = obs[i:i + self._local_batch_size]
            cur_rewards = await self._model.async_intrinsic_reward(batch)
            int_rewards.extend(torch.unbind(cur_rewards))
        if done_at_last:
            int_rewards[-1].zero_()
        return int_rewards

    def _train_step(self, batch: NestedTensor) -> Dict[str, float]:
        batch = nested_utils.map_nested(lambda x: x.to(self.device()), batch)
        self._optimizer.zero_grad()
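For the RND agent, the same idea is applied to intrinsic-reward inference: rather than one forward pass over the whole collated trajectory, _compute_intrinsic_rewards slices the batch into local_batch_size pieces and zeroes the last intrinsic reward when the trajectory ends (done_at_last=True). A toy sketch of that minibatched-inference pattern follows; ToyIntrinsicModel, the observation shape, and the helper name are assumptions for illustration, not the rlmeta RND model.

from typing import List

import torch
import torch.nn as nn


class ToyIntrinsicModel(nn.Module):
    """Hypothetical stand-in for an RND-style intrinsic reward model."""

    def __init__(self, obs_dim: int = 8) -> None:
        super().__init__()
        self.net = nn.Linear(obs_dim, 1)

    @torch.no_grad()
    def intrinsic_reward(self, obs: torch.Tensor) -> torch.Tensor:
        # One scalar reward per observation in the batch.
        return self.net(obs).squeeze(-1)


def compute_intrinsic_rewards(model: ToyIntrinsicModel, obs: torch.Tensor,
                              local_batch_size: int,
                              done_at_last: bool = True) -> List[torch.Tensor]:
    # Run inference in bounded slices, mirroring _compute_intrinsic_rewards.
    rewards: List[torch.Tensor] = []
    n = obs.size(0)
    for i in range(0, n, local_batch_size):
        batch = obs[i:i + local_batch_size]
        rewards.extend(torch.unbind(model.intrinsic_reward(batch)))
    if done_at_last:
        rewards[-1].zero_()  # terminal transition gets no intrinsic reward
    return rewards


model = ToyIntrinsicModel()
obs = torch.randn(2500, 8)
rewards = compute_intrinsic_rewards(model, obs, local_batch_size=1024)
print(len(rewards), float(rewards[-1]))  # 2500 0.0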
