This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Improve PPO implementation to avoid potential timeout for RPC #81

Merged: 2 commits, Sep 23, 2022

Changes from all commits

2 changes: 1 addition & 1 deletion examples/atari/ppo/atari_ppo.py
@@ -72,8 +72,8 @@ def main(cfg):
         replay_buffer=a_rb,
         controller=a_ctrl,
         optimizer=optimizer,
-        batch_size=cfg.batch_size,
         learning_starts=cfg.get("learning_starts", None),
+        batch_size=cfg.batch_size,
         push_every_n_steps=cfg.push_every_n_steps)
     t_agent_fac = AgentFactory(PPOAgent, t_model, replay_buffer=t_rb)
     e_agent_fac = AgentFactory(PPOAgent, e_model, deterministic_policy=False)

2 changes: 1 addition & 1 deletion examples/atari/ppo/atari_ppo_rnd.py
@@ -72,8 +72,8 @@ def main(cfg):
         replay_buffer=a_rb,
         controller=a_ctrl,
         optimizer=optimizer,
-        batch_size=cfg.batch_size,
         learning_starts=cfg.get("learning_starts", None),
+        batch_size=cfg.batch_size,
         push_every_n_steps=cfg.push_every_n_steps)
     t_agent_fac = AgentFactory(PPORNDAgent, t_model, replay_buffer=t_rb)
     e_agent_fac = AgentFactory(PPORNDAgent, e_model, deterministic_policy=False)

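Both example changes just reorder keyword arguments (`learning_starts` now precedes `batch_size`), which does not affect the calls. As a small aside on the `cfg.get("learning_starts", None)` pattern used here, the following self-contained sketch shows the optional-key fallback, assuming the examples receive a Hydra/OmegaConf config; the concrete keys are illustrative only:

```python
from omegaconf import OmegaConf

# Illustrative config; the real examples load theirs through Hydra.
cfg = OmegaConf.create({"batch_size": 512, "push_every_n_steps": 10})

# Optional keys fall back to the supplied default when they are absent.
print(cfg.get("learning_starts", None))  # None
print(cfg.batch_size)                    # 512
```
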
40 changes: 33 additions & 7 deletions rlmeta/agents/ppo/ppo_agent.py
@@ -47,12 +47,13 @@ def __init__(self,
                  normalize_advantage: bool = True,
                  learning_starts: Optional[int] = None,
                  batch_size: int = 512,
+                 local_batch_size: int = 1024,
                  max_grad_norm: float = 1.0,
                  push_every_n_steps: int = 1) -> None:
         super().__init__()
 
         self._model = model
-        self._deterministic_policy = deterministic_policy
+        self._deterministic_policy = torch.tensor([deterministic_policy])
 
         self._replay_buffer = replay_buffer
         self._controller = controller
@@ -70,6 +71,7 @@ def __init__(self,
 
         self._learning_starts = learning_starts
         self._batch_size = batch_size
+        self._local_batch_size = local_batch_size
         self._max_grad_norm = max_grad_norm
         self._push_every_n_steps = push_every_n_steps
 
@@ -82,14 +84,13 @@ def reset(self) -> None:
 
     def act(self, timestep: TimeStep) -> Action:
         obs = timestep.observation
-        action, logpi, v = self._model.act(
-            obs, torch.tensor([self._deterministic_policy]))
+        action, logpi, v = self._model.act(obs, self._deterministic_policy)
         return Action(action, info={"logpi": logpi, "v": v})
 
     async def async_act(self, timestep: TimeStep) -> Action:
         obs = timestep.observation
         action, logpi, v = await self._model.async_act(
-            obs, torch.tensor([self._deterministic_policy]))
+            obs, self._deterministic_policy)
         return Action(action, info={"logpi": logpi, "v": v})
 
     async def async_observe_init(self, timestep: TimeStep) -> None:
@@ -116,15 +117,15 @@ def update(self) -> None:
             return
         if self._replay_buffer is not None:
             replay = self._make_replay()
-            self._replay_buffer.extend(replay)
+            self._send_replay(replay)
         self._trajectory.clear()
 
     async def async_update(self) -> None:
         if not self._trajectory or not self._trajectory[-1]["done"]:
             return
         if self._replay_buffer is not None:
-            replay = self._make_replay()
-            await self._replay_buffer.async_extend(replay)
+            replay = await self._async_make_replay()
+            await self._async_send_replay(replay)
         self._trajectory.clear()
 
     def train(self, num_steps: int) -> Optional[StatsDict]:
@@ -187,6 +188,31 @@ def _make_replay(self) -> List[NestedTensor]:
             cur.pop("done")
         return self._trajectory
 
+    async def _async_make_replay(self) -> List[NestedTensor]:
+        return self._make_replay()
+
+    def _send_replay(self, replay: List[NestedTensor]) -> None:
+        batch = []
+        while replay:
+            batch.append(replay.pop())
+            if len(batch) >= self._local_batch_size:
+                self._replay_buffer.extend(batch)
+                batch.clear()
+        if batch:
+            self._replay_buffer.extend(batch)
+            batch.clear()
+
+    async def _async_send_replay(self, replay: List[NestedTensor]) -> None:
+        batch = []
+        while replay:
+            batch.append(replay.pop())
+            if len(batch) >= self._local_batch_size:
+                await self._replay_buffer.async_extend(batch)
+                batch.clear()
+        if batch:
+            await self._replay_buffer.async_extend(batch)
+            batch.clear()
+
     def _calculate_gae_and_return(
             self,
             values: Sequence[Union[float, torch.Tensor]],

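The substance of the change above is that a finished trajectory is no longer pushed to the (possibly remote) replay buffer in a single `extend`/`async_extend` call, which for long episodes means one very large RPC that can time out. Instead, `_send_replay` and `_async_send_replay` push the trajectory in chunks of at most `local_batch_size` entries. Here is a minimal, self-contained sketch of that chunking pattern; `RemoteBuffer` and `send_replay` are illustrative stand-ins, not rlmeta's replay-buffer API:

```python
from typing import Any, List


class RemoteBuffer:
    """Toy stand-in for a replay buffer living in another process."""

    def __init__(self) -> None:
        self.pushed_batch_sizes: List[int] = []

    def extend(self, batch: List[Any]) -> None:
        # A real remote buffer would serialize `batch` and send it over RPC;
        # bounding the batch size bounds the cost of each individual call.
        self.pushed_batch_sizes.append(len(batch))


def send_replay(replay: List[Any],
                buffer: RemoteBuffer,
                local_batch_size: int = 1024) -> None:
    """Push `replay` to `buffer` in chunks of at most `local_batch_size`."""
    batch: List[Any] = []
    while replay:
        batch.append(replay.pop())
        if len(batch) >= local_batch_size:
            buffer.extend(batch)
            batch = []
    if batch:
        buffer.extend(batch)


if __name__ == "__main__":
    buf = RemoteBuffer()
    send_replay(list(range(2500)), buf)
    # Three bounded pushes instead of one 2500-entry RPC.
    print(buf.pushed_batch_sizes)  # [1024, 1024, 452]
```

In the patched agent the same loop targets `self._replay_buffer.extend` (or `async_extend` on the asyncio path), so each RPC payload is bounded by `local_batch_size`.
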
58 changes: 51 additions & 7 deletions rlmeta/agents/ppo/ppo_rnd_agent.py
@@ -40,6 +40,7 @@ def __init__(
             normalize_advantage: bool = True,
             learning_starts: Optional[int] = None,
             batch_size: int = 128,
+            local_batch_size: int = 1024,
             max_grad_norm: float = 1.0,
             push_every_n_steps: int = 1,
             collate_fn: Optional[Callable[[Sequence[NestedTensor]],
@@ -49,7 +50,8 @@ def __init__(
                          optimizer, gamma, gae_lambda, ratio_clipping_eps,
                          value_clipping_eps, vf_loss_coeff, entropy_coeff,
                          rescale_reward, normalize_advantage, learning_starts,
-                         batch_size, max_grad_norm, push_every_n_steps)
+                         batch_size, local_batch_size, max_grad_norm,
+                         push_every_n_steps)
 
         self._intrinsic_advantage_coeff = intrinsic_advantage_coeff
 
@@ -64,7 +66,7 @@ def __init__(
     def act(self, timestep: TimeStep) -> Action:
         obs = timestep.observation
         action, logpi, ext_v, int_v = self._model.act(
-            obs, torch.tensor([self._deterministic_policy]))
+            obs, self._deterministic_policy)
         return Action(action,
                       info={
                           "logpi": logpi,
@@ -75,7 +77,7 @@ def act(self, timestep: TimeStep) -> Action:
     async def async_act(self, timestep: TimeStep) -> Action:
         obs = timestep.observation
         action, logpi, ext_v, int_v = await self._model.async_act(
-            obs, torch.tensor([self._deterministic_policy]))
+            obs, self._deterministic_policy)
         return Action(action,
                       info={
                           "logpi": logpi,
@@ -100,16 +102,28 @@ def _make_replay(self) -> List[NestedTensor]:
         next_obs = [
             self._trajectory[i]["obs"] for i in range(1, len(self._trajectory))
         ]
-        next_obs = self._collate_fn(next_obs)
-        int_rewards = self._model.intrinsic_reward(next_obs)
-        int_rewards[-1] = 0.0
+        int_rewards = self._compute_intrinsic_rewards(next_obs,
+                                                      done_at_last=True)
+        return self._make_replay_impl(int_rewards)
+
+    async def _async_make_replay(self) -> List[NestedTensor]:
+        next_obs = [
+            self._trajectory[i]["obs"] for i in range(1, len(self._trajectory))
+        ]
+        int_rewards = await self._async_compute_intrinsic_rewards(
+            next_obs, done_at_last=True)
+        return self._make_replay_impl(int_rewards)
+
+    def _make_replay_impl(
+            self,
+            intrinsic_rewards: Sequence[NestedTensor]) -> List[NestedTensor]:
         self._trajectory.pop()
 
         ext_adv, ext_ret = self._calculate_gae_and_return(
             [x["ext_v"] for x in self._trajectory],
             [x["reward"] for x in self._trajectory], self._ext_reward_rescaler)
         int_adv, int_ret = self._calculate_gae_and_return(
-            [x["int_v"] for x in self._trajectory], torch.unbind(int_rewards),
+            [x["int_v"] for x in self._trajectory], intrinsic_rewards,
             self._int_reward_rescaler)
 
         for cur, ext_a, ext_r, int_a, int_r in zip(self._trajectory, ext_adv,
@@ -123,6 +137,36 @@ def _make_replay(self) -> List[NestedTensor]:
 
         return self._trajectory
 
+    def _compute_intrinsic_rewards(
+            self,
+            obs: Sequence[NestedTensor],
+            done_at_last: bool = True) -> List[torch.Tensor]:
+        int_rewards = []
+        n = len(obs)
+        obs = nested_utils.collate_nested(self._collate_fn, obs)
+        for i in range(0, n, self._local_batch_size):
+            batch = obs[i:i + self._local_batch_size]
+            cur_rewards = self._model.intrinsic_reward(batch)
+            int_rewards.extend(torch.unbind(cur_rewards))
+        if done_at_last:
+            int_rewards[-1].zero_()
+        return int_rewards
+
+    async def _async_compute_intrinsic_rewards(
+            self,
+            obs: Sequence[NestedTensor],
+            done_at_last: bool = True) -> List[torch.Tensor]:
+        int_rewards = []
+        n = len(obs)
+        obs = nested_utils.collate_nested(self._collate_fn, obs)
+        for i in range(0, n, self._local_batch_size):
+            batch = obs[i:i + self._local_batch_size]
+            cur_rewards = await self._model.async_intrinsic_reward(batch)
+            int_rewards.extend(torch.unbind(cur_rewards))
+        if done_at_last:
+            int_rewards[-1].zero_()
+        return int_rewards
+
     def _train_step(self, batch: NestedTensor) -> Dict[str, float]:
         batch = nested_utils.map_nested(lambda x: x.to(self.device()), batch)
         self._optimizer.zero_grad()

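The RND agent gets the same treatment for its other potentially oversized call: intrinsic rewards used to come from one `intrinsic_reward` pass over the whole episode, while `_compute_intrinsic_rewards` now scores the collated observations in slices of `local_batch_size` and zeroes the reward of the final, terminal transition. Below is a hedged, self-contained sketch of that slicing loop; `StubRNDModel` and `compute_intrinsic_rewards` are illustrative stand-ins operating on plain tensors rather than rlmeta's nested tensors:

```python
from typing import List

import torch


class StubRNDModel:
    """Placeholder for an RND-style model that scores observation novelty."""

    def intrinsic_reward(self, obs: torch.Tensor) -> torch.Tensor:
        # Dummy per-observation scalar "novelty"; a real model would run its
        # predictor/target networks here, possibly via an RPC to a remote model.
        return obs.float().flatten(1).pow(2).mean(dim=1)


def compute_intrinsic_rewards(obs: torch.Tensor,
                              model: StubRNDModel,
                              local_batch_size: int = 1024,
                              done_at_last: bool = True) -> List[torch.Tensor]:
    """Score `obs` in slices of at most `local_batch_size` observations."""
    int_rewards: List[torch.Tensor] = []
    n = obs.size(0)
    for i in range(0, n, local_batch_size):
        cur = model.intrinsic_reward(obs[i:i + local_batch_size])
        int_rewards.extend(torch.unbind(cur))
    if done_at_last:
        # The last next-observation belongs to the terminal step, so its
        # intrinsic reward is zeroed, mirroring int_rewards[-1].zero_().
        int_rewards[-1].zero_()
    return int_rewards


if __name__ == "__main__":
    rewards = compute_intrinsic_rewards(torch.randn(2500, 8), StubRNDModel())
    print(len(rewards), rewards[-1])  # 2500 tensor(0.)
```

When the model is accessed over RPC, each `intrinsic_reward` call is itself a remote request, so bounding the slice size also keeps those payloads small.
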