From 4f57aedf1478de0167399aebf0fdc4820c36bf93 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Thu, 22 Sep 2022 22:29:35 -0700 Subject: [PATCH 1/3] Refactor PPOAgent and PPORNDAgent --- examples/atari/ppo/atari_ppo_rnd.py | 4 +- examples/atari/ppo/atari_ppo_rnd_model.py | 35 ++-- rlmeta/agents/ppo/ppo_agent.py | 199 +++++++++++----------- rlmeta/agents/ppo/ppo_rnd_agent.py | 109 ++++++------ rlmeta/agents/ppo/ppo_rnd_model.py | 6 +- 5 files changed, 170 insertions(+), 183 deletions(-) diff --git a/examples/atari/ppo/atari_ppo_rnd.py b/examples/atari/ppo/atari_ppo_rnd.py index 9729434..d009d02 100644 --- a/examples/atari/ppo/atari_ppo_rnd.py +++ b/examples/atari/ppo/atari_ppo_rnd.py @@ -28,6 +28,7 @@ from rlmeta.core.server import Server, ServerList from rlmeta.samplers import UniformSampler from rlmeta.storage import TensorCircularBuffer +from rlmeta.utils.optimizer_utils import make_optimizer @hydra.main(config_path="./conf", config_name="conf_ppo") @@ -36,7 +37,8 @@ def main(cfg): env = atari_wrappers.make_atari(cfg.env) train_model = AtariPPORNDModel(env.action_space.n).to(cfg.train_device) - optimizer = torch.optim.Adam(train_model.parameters(), lr=cfg.lr) + optimizer = make_optimizer(cfg.optimizer.name, train_model.parameters(), + cfg.optimizer.args) infer_model = copy.deepcopy(train_model).to(cfg.infer_device) diff --git a/examples/atari/ppo/atari_ppo_rnd_model.py b/examples/atari/ppo/atari_ppo_rnd_model.py index c016cbf..fee7291 100644 --- a/examples/atari/ppo/atari_ppo_rnd_model.py +++ b/examples/atari/ppo/atari_ppo_rnd_model.py @@ -19,28 +19,24 @@ class AtariPPORNDModel(PPORNDModel): - def __init__(self, - action_dim: int, - observation_normalization: bool = False) -> None: + def __init__(self, action_dim: int) -> None: super().__init__() self.action_dim = action_dim - self.observation_normalization = observation_normalization - if self.observation_normalization: - self.obs_rescaler = MomentsRescaler(size=(4, 84, 84)) - self.policy_net = AtariBackbone() - self.target_net = AtariBackbone() - self.predict_net = AtariBackbone() - self.linear_p = nn.Linear(self.policy_net.output_dim, self.action_dim) - self.linear_ext_v = nn.Linear(self.policy_net.output_dim, 1) - self.linear_int_v = nn.Linear(self.policy_net.output_dim, 1) + self.ppo_net = AtariBackbone() + self.tgt_net = AtariBackbone() + self.prd_net = AtariBackbone() + + self.linear_p = nn.Linear(self.ppo_net.output_dim, self.action_dim) + self.linear_ext_v = nn.Linear(self.ppo_net.output_dim, 1) + self.linear_int_v = nn.Linear(self.ppo_net.output_dim, 1) def forward( self, obs: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x = obs.float() / 255.0 - h = self.policy_net(x) + h = self.ppo_net(x) p = self.linear_p(h) logpi = F.log_softmax(p, dim=-1) ext_v = self.linear_ext_v(h) @@ -64,7 +60,7 @@ def act( action = torch.where(d, greedy_action, sample_action) logpi = logpi.gather(dim=-1, index=action) - return action.cpu(), logpi.cpu(), ext_v.cpu(), int_v.cpu() + return action.cpu(), logpi.cpu(), ext_v.cpu(), int_v.cpu() @remote.remote_method(batch_size=None) def intrinsic_reward(self, obs: torch.Tensor) -> torch.Tensor: @@ -77,13 +73,8 @@ def rnd_loss(self, obs: torch.Tensor) -> torch.Tensor: def _rnd_error(self, obs: torch.Tensor) -> torch.Tensor: x = obs.float() / 255.0 - if self.observation_normalization: - self.obs_rescaler.update(x) - x = self.obs_rescaler.rescale(x) - with torch.no_grad(): - target = self.target_net(x) - pred = self.predict_net(x) - err = (pred - target).square().mean(-1, keepdim=True) - + 
tgt = self.tgt_net(x) + prd = self.prd_net(x) + err = (prd - tgt).square().mean(-1, keepdim=True) return err diff --git a/rlmeta/agents/ppo/ppo_agent.py b/rlmeta/agents/ppo/ppo_agent.py index 34aa0df..09dcfff 100644 --- a/rlmeta/agents/ppo/ppo_agent.py +++ b/rlmeta/agents/ppo/ppo_agent.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from rich.console import Console from rich.progress import track @@ -36,112 +37,106 @@ def __init__(self, replay_buffer: Optional[ReplayBufferLike] = None, controller: Optional[ControllerLike] = None, optimizer: Optional[torch.optim.Optimizer] = None, - batch_size: int = 128, - grad_clip: float = 1.0, gamma: float = 0.99, gae_lambda: float = 0.95, - eps_clip: float = 0.2, + ratio_clipping_eps: float = 0.2, + value_clipping_eps: Optional[float] = 0.2, vf_loss_coeff: float = 0.5, entropy_coeff: float = 0.01, - reward_rescaling: bool = True, - advantage_normalization: bool = True, - value_clip: bool = True, + rescale_reward: bool = True, + normalize_advantage: bool = True, learning_starts: Optional[int] = None, + batch_size: int = 512, + max_grad_norm: float = 1.0, push_every_n_steps: int = 1) -> None: super().__init__() - self.model = model - self.deterministic_policy = deterministic_policy - - self.replay_buffer = replay_buffer - self.controller = controller - - self.optimizer = optimizer - self.batch_size = batch_size - self.grad_clip = grad_clip - - self.gamma = gamma - self.gae_lambda = gae_lambda - self.eps_clip = eps_clip - self.vf_loss_coeff = vf_loss_coeff - self.entropy_coeff = entropy_coeff - self.reward_rescaling = reward_rescaling - if self.reward_rescaling: - self.reward_rescaler = RMSRescaler(size=1) - self.advantage_normalization = advantage_normalization - self.value_clip = value_clip - - self.learning_starts = learning_starts - self.push_every_n_steps = push_every_n_steps - self.done = False - self.trajectory = [] - self.step_counter = 0 - + self._model = model + self._deterministic_policy = deterministic_policy + + self._replay_buffer = replay_buffer + self._controller = controller + self._optimizer = optimizer + + self._gamma = gamma + self._gae_lambda = gae_lambda + self._ratio_clipping_eps = ratio_clipping_eps + self._value_clipping_eps = value_clipping_eps + self._vf_loss_coeff = vf_loss_coeff + self._entropy_coeff = entropy_coeff + self._rescale_reward = rescale_reward + self._reward_rescaler = RMSRescaler(size=1) if rescale_reward else None + self._normalize_advantage = normalize_advantage + + self._learning_starts = learning_starts + self._batch_size = batch_size + self._max_grad_norm = max_grad_norm + self._push_every_n_steps = push_every_n_steps + + self._trajectory = [] + self._step_counter = 0 self._device = None def reset(self) -> None: - self.step_counter = 0 + self._step_counter = 0 def act(self, timestep: TimeStep) -> Action: obs = timestep.observation - action, logpi, v = self.model.act( - obs, torch.tensor([self.deterministic_policy])) + action, logpi, v = self._model.act( + obs, torch.tensor([self._deterministic_policy])) return Action(action, info={"logpi": logpi, "v": v}) async def async_act(self, timestep: TimeStep) -> Action: obs = timestep.observation - action, logpi, v = await self.model.async_act( - obs, torch.tensor([self.deterministic_policy])) + action, logpi, v = await self._model.async_act( + obs, torch.tensor([self._deterministic_policy])) return Action(action, info={"logpi": logpi, "v": v}) async def async_observe_init(self, timestep: TimeStep) -> None: obs, _, done, _ = 
timestep if done: - self.trajectory = [] + self._trajectory = [] else: - self.trajectory = [{"obs": obs}] + self._trajectory = [{"obs": obs, "done": done}] async def async_observe(self, action: Action, next_timestep: TimeStep) -> None: act, info = action obs, reward, done, _ = next_timestep - cur = self.trajectory[-1] + cur = self._trajectory[-1] cur["action"] = act cur["logpi"] = info["logpi"] cur["v"] = info["v"] cur["reward"] = reward - - if not done: - self.trajectory.append({"obs": obs}) - self.done = done + self._trajectory.append({"obs": obs, "done": done}) def update(self) -> None: - if not self.done: + if not self._trajectory or not self._trajectory[-1]["done"]: return - if self.replay_buffer is not None: + if self._replay_buffer is not None: replay = self._make_replay() - self.replay_buffer.extend(replay) - self.trajectory.clear() + self._replay_buffer.extend(replay) + self._trajectory.clear() async def async_update(self) -> None: - if not self.done: + if not self._trajectory or not self._trajectory[-1]["done"]: return - if self.replay_buffer is not None: + if self._replay_buffer is not None: replay = self._make_replay() - await self.replay_buffer.async_extend(replay) - self.trajectory.clear() + await self._replay_buffer.async_extend(replay) + self._trajectory.clear() def train(self, num_steps: int) -> Optional[StatsDict]: - self.controller.set_phase(Phase.TRAIN) + self._controller.set_phase(Phase.TRAIN) - self.replay_buffer.warm_up(self.learning_starts) + self._replay_buffer.warm_up(self._learning_starts) stats = StatsDict() console.log(f"Training for num_steps = {num_steps}") for _ in track(range(num_steps), description="Training..."): t0 = time.perf_counter() - _, batch, _ = self.replay_buffer.sample(self.batch_size) + _, batch, _ = self._replay_buffer.sample(self._batch_size) t1 = time.perf_counter() step_stats = self._train_step(batch) t2 = time.perf_counter() @@ -152,13 +147,13 @@ def train(self, num_steps: int) -> Optional[StatsDict]: stats.extend(step_stats) stats.extend(time_stats) - self.step_counter += 1 - if self.step_counter % self.push_every_n_steps == 0: - self.model.push() + self._step_counter += 1 + if self._step_counter % self._push_every_n_steps == 0: + self._model.push() - episode_stats = self.controller.stats(Phase.TRAIN) + episode_stats = self._controller.stats(Phase.TRAIN) stats.update(episode_stats) - self.controller.reset_phase(Phase.TRAIN) + self._controller.reset_phase(Phase.TRAIN) return stats @@ -166,30 +161,31 @@ def eval(self, num_episodes: Optional[int] = None, keep_training_loops: bool = False) -> Optional[StatsDict]: if keep_training_loops: - self.controller.set_phase(Phase.BOTH) + self._controller.set_phase(Phase.BOTH) else: - self.controller.set_phase(Phase.EVAL) - self.controller.reset_phase(Phase.EVAL, limit=num_episodes) - while self.controller.count(Phase.EVAL) < num_episodes: + self._controller.set_phase(Phase.EVAL) + self._controller.reset_phase(Phase.EVAL, limit=num_episodes) + while self._controller.count(Phase.EVAL) < num_episodes: time.sleep(1) - stats = self.controller.stats(Phase.EVAL) + stats = self._controller.stats(Phase.EVAL) return stats def device(self) -> torch.device: if self._device is None: - self._device = next(self.model.parameters()).device + self._device = next(self._model.parameters()).device return self._device def _make_replay(self) -> List[NestedTensor]: + self._trajectory.pop() adv, ret = self._calculate_gae_and_return( - [x["v"] for x in self.trajectory], - [x["reward"] for x in self.trajectory], - 
self.reward_rescaler if self.reward_rescaling else None) - for cur, a, r in zip(self.trajectory, adv, ret): + [x["v"] for x in self._trajectory], + [x["reward"] for x in self._trajectory], self._reward_rescaler) + for cur, a, r in zip(self._trajectory, adv, ret): cur["gae"] = a cur["ret"] = r cur.pop("reward") - return self.trajectory + cur.pop("done") + return self._trajectory def _calculate_gae_and_return( self, @@ -206,8 +202,8 @@ def _calculate_gae_and_return( v = value if reward_rescaler is not None: v = reward_rescaler.recover(v) - delta = reward + self.gamma * next_v - v - gae = delta + self.gamma * self.gae_lambda * gae + delta = reward + self._gamma * next_v - v + gae = delta + self._gamma * self._gae_lambda * gae adv.append(gae) ret.append(gae + v) @@ -221,25 +217,26 @@ def _calculate_gae_and_return( def _train_step(self, batch: NestedTensor) -> Dict[str, float]: batch = nested_utils.map_nested(lambda x: x.to(self.device()), batch) - self.optimizer.zero_grad() + self._optimizer.zero_grad() obs = batch["obs"] act = batch["action"] - old_logpi = batch["logpi"] adv = batch["gae"] ret = batch["ret"] - logpi, v = self._model_forward(obs) + behavior_logpi = batch["logpi"] + behavior_v = batch["v"] + logpi, v = self._model_forward(obs) policy_loss, ratio = self._policy_loss(logpi.gather(dim=-1, index=act), - old_logpi, adv) - value_loss = self._value_loss(ret, v, batch.get("v", None)) + behavior_logpi, adv) + value_loss = self._value_loss(ret, v, behavior_v) entropy = self._entropy(logpi) - loss = policy_loss + (self.vf_loss_coeff * - value_loss) - (self.entropy_coeff * entropy) + loss = policy_loss + (self._vf_loss_coeff * + value_loss) - (self._entropy_coeff * entropy) loss.backward() - grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), - self.grad_clip) - self.optimizer.step() + grad_norm = nn.utils.clip_grad_norm_(self._model.parameters(), + self._max_grad_norm) + self._optimizer.step() return { "return": ret.detach().mean().item(), @@ -252,20 +249,19 @@ def _train_step(self, batch: NestedTensor) -> Dict[str, float]: } def _model_forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, ...]: - return self.model(obs) + return self._model(obs) - def _policy_loss(self, logpi: torch.Tensor, old_logpi: torch.Tensor, + def _policy_loss(self, logpi: torch.Tensor, behavior_logpi: torch.Tensor, adv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if self.advantage_normalization: - # Advantage normalization + if self._normalize_advantage: std, mean = torch.std_mean(adv, unbiased=False) adv = (adv - mean) / std - # Policy clip - ratio = (logpi - old_logpi).exp() - ratio_clamp = ratio.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip) + ratio = (logpi - behavior_logpi).exp() + clipped_ratio = ratio.clamp(1.0 - self._ratio_clipping_eps, + 1.0 + self._ratio_clipping_eps) surr1 = ratio * adv - surr2 = ratio_clamp * adv + surr2 = clipped_ratio * adv policy_loss = -torch.min(surr1, surr2).mean() return policy_loss, ratio @@ -273,16 +269,15 @@ def _policy_loss(self, logpi: torch.Tensor, old_logpi: torch.Tensor, def _value_loss(self, ret: torch.Tensor, v: torch.Tensor, - old_v: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.value_clip: - # Value clip - v_clamp = old_v + (v - old_v).clamp(-self.eps_clip, self.eps_clip) - vf1 = (ret - v).square() - vf2 = (ret - v_clamp).square() - value_loss = torch.max(vf1, vf2).mean() * 0.5 - else: - value_loss = (ret - v).square().mean() * 0.5 - return value_loss + behavior_v: Optional[torch.Tensor] = None) -> torch.Tensor: + if 
self._value_clipping_eps is None: + return F.mse_loss(v, ret) + + clipped_v = behavior_v + torch.clamp( + v - behavior_v, -self._value_clipping_eps, self._value_clipping_eps) + vf1 = F.mse_loss(v, ret, reduction="none") + vf2 = F.mse_loss(clipped_v, ret, reduction="none") + return torch.max(vf1, vf2).mean() def _entropy(self, logpi: torch.Tensor) -> torch.Tensor: return -(logpi.exp() * logpi).sum(dim=-1).mean() diff --git a/rlmeta/agents/ppo/ppo_rnd_agent.py b/rlmeta/agents/ppo/ppo_rnd_agent.py index 15add87..e0c67c9 100644 --- a/rlmeta/agents/ppo/ppo_rnd_agent.py +++ b/rlmeta/agents/ppo/ppo_rnd_agent.py @@ -29,44 +29,42 @@ def __init__( replay_buffer: Optional[ReplayBufferLike] = None, controller: Optional[ControllerLike] = None, optimizer: Optional[torch.optim.Optimizer] = None, - batch_size: int = 128, - grad_clip: float = 1.0, gamma: float = 0.99, gae_lambda: float = 0.95, - eps_clip: float = 0.2, + ratio_clipping_eps: float = 0.2, + value_clipping_eps: Optional[float] = 0.2, intrinsic_advantage_coeff: float = 0.5, vf_loss_coeff: float = 0.5, entropy_coeff: float = 0.01, - advantage_normalization: bool = True, - reward_rescaling: bool = True, - value_clip: bool = True, + rescale_reward: bool = True, + normalize_advantage: bool = True, learning_starts: Optional[int] = None, + batch_size: int = 128, + max_grad_norm: float = 1.0, push_every_n_steps: int = 1, collate_fn: Optional[Callable[[Sequence[NestedTensor]], NestedTensor]] = None ) -> None: super().__init__(model, deterministic_policy, replay_buffer, controller, - optimizer, batch_size, grad_clip, gamma, gae_lambda, - eps_clip, vf_loss_coeff, entropy_coeff, - advantage_normalization, reward_rescaling, value_clip, - learning_starts, push_every_n_steps) + optimizer, gamma, gae_lambda, ratio_clipping_eps, + value_clipping_eps, vf_loss_coeff, entropy_coeff, + rescale_reward, normalize_advantage, learning_starts, + batch_size, max_grad_norm, push_every_n_steps) - self.intrinsic_advantage_coeff = intrinsic_advantage_coeff + self._intrinsic_advantage_coeff = intrinsic_advantage_coeff - if self.reward_rescaling: - self.reward_rescaler = None - self.ext_reward_rescaler = RMSRescaler(size=1) - self.int_reward_rescaler = RMSRescaler(size=1) + self._reward_rescaler = None + self._ext_reward_rescaler = RMSRescaler( + size=1) if rescale_reward else None + self._int_reward_rescaler = RMSRescaler( + size=1) if rescale_reward else None - if collate_fn is not None: - self.collate_fn = collate_fn - else: - self.collate_fn = data_utils.stack_tensors + self._collate_fn = torch.stack if collate_fn is None else collate_fn def act(self, timestep: TimeStep) -> Action: obs = timestep.observation - action, logpi, ext_v, int_v = self.model.act( - obs, torch.tensor([self.deterministic_policy])) + action, logpi, ext_v, int_v = self._model.act( + obs, torch.tensor([self._deterministic_policy])) return Action(action, info={ "logpi": logpi, @@ -76,8 +74,8 @@ def act(self, timestep: TimeStep) -> Action: async def async_act(self, timestep: TimeStep) -> Action: obs = timestep.observation - action, logpi, ext_v, int_v = await self.model.async_act( - obs, torch.tensor([self.deterministic_policy])) + action, logpi, ext_v, int_v = await self._model.async_act( + obs, torch.tensor([self._deterministic_policy])) return Action(action, info={ "logpi": logpi, @@ -90,73 +88,72 @@ async def async_observe(self, action: Action, act, info = action obs, reward, done, _ = next_timestep - cur = self.trajectory[-1] + cur = self._trajectory[-1] cur["reward"] = reward cur["action"] = act 
cur["logpi"] = info["logpi"] cur["ext_v"] = info["ext_v"] cur["int_v"] = info["int_v"] - cur["next_obs"] = obs - - if not done: - self.trajectory.append({"obs": obs}) - self.done = done + self._trajectory.append({"obs": obs, "done": done}) def _make_replay(self) -> List[NestedTensor]: - next_obs = [x["next_obs"] for x in self.trajectory] - next_obs = self.collate_fn(next_obs) - int_rewards = self.model.intrinsic_reward(next_obs) + next_obs = [ + self._trajectory[i]["obs"] for i in range(1, len(self._trajectory)) + ] + next_obs = self._collate_fn(next_obs) + int_rewards = self._model.intrinsic_reward(next_obs) + int_rewards[-1] = 0.0 + self._trajectory.pop() ext_adv, ext_ret = self._calculate_gae_and_return( - [x["ext_v"] for x in self.trajectory], - [x["reward"] for x in self.trajectory], - self.ext_reward_rescaler if self.reward_rescaling else None) + [x["ext_v"] for x in self._trajectory], + [x["reward"] for x in self._trajectory], self._ext_reward_rescaler) int_adv, int_ret = self._calculate_gae_and_return( - [x["int_v"] for x in self.trajectory], torch.unbind(int_rewards), - self.int_reward_rescaler if self.reward_rescaling else None) + [x["int_v"] for x in self._trajectory], torch.unbind(int_rewards), + self._int_reward_rescaler) - for cur, ext_a, ext_r, int_a, int_r in zip(self.trajectory, ext_adv, + for cur, ext_a, ext_r, int_a, int_r in zip(self._trajectory, ext_adv, ext_ret, int_adv, int_ret): cur["ext_gae"] = ext_a cur["ext_ret"] = ext_r cur["int_gae"] = int_a cur["int_ret"] = int_r cur.pop("reward") + cur.pop("done") - return self.trajectory + return self._trajectory def _train_step(self, batch: NestedTensor) -> Dict[str, float]: batch = nested_utils.map_nested(lambda x: x.to(self.device()), batch) - self.optimizer.zero_grad() + self._optimizer.zero_grad() obs = batch["obs"] act = batch["action"] - old_logpi = batch["logpi"] ext_adv = batch["ext_gae"] ext_ret = batch["ext_ret"] int_adv = batch["int_gae"] int_ret = batch["int_ret"] - next_obs = batch["next_obs"] - logpi, ext_v, int_v = self._model_forward(obs) + behavior_logpi = batch["logpi"] + behavior_ext_v = batch["ext_v"] + behavior_int_v = batch["int_v"] - adv = ext_adv + self.intrinsic_advantage_coeff * int_adv + logpi, ext_v, int_v = self._model_forward(obs) + adv = ext_adv + self._intrinsic_advantage_coeff * int_adv policy_loss, ratio = self._policy_loss(logpi.gather(dim=-1, index=act), - old_logpi, adv) + behavior_logpi, adv) - ext_value_loss = self._value_loss(ext_ret, ext_v, - batch.get("ext_v", None)) - int_value_loss = self._value_loss(int_ret, int_v, - batch.get("int_v", None)) + ext_value_loss = self._value_loss(ext_ret, ext_v, behavior_ext_v) + int_value_loss = self._value_loss(int_ret, int_v, behavior_int_v) value_loss = ext_value_loss + int_value_loss entropy = self._entropy(logpi) - rnd_loss = self._rnd_loss(next_obs) + rnd_loss = self._rnd_loss(obs) - loss = policy_loss + (self.vf_loss_coeff * value_loss) - ( - self.entropy_coeff * entropy) + rnd_loss + loss = policy_loss + (self._vf_loss_coeff * value_loss) - ( + self._entropy_coeff * entropy) + rnd_loss loss.backward() - grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), - self.grad_clip) - self.optimizer.step() + grad_norm = nn.utils.clip_grad_norm_(self._model.parameters(), + self._max_grad_norm) + self._optimizer.step() return { "ext_return": ext_ret.detach().mean().item(), @@ -173,4 +170,4 @@ def _train_step(self, batch: NestedTensor) -> Dict[str, float]: } def _rnd_loss(self, next_obs: torch.Tensor) -> torch.Tensor: - return 
self.model.rnd_loss(next_obs) + return self._model.rnd_loss(next_obs) diff --git a/rlmeta/agents/ppo/ppo_rnd_model.py b/rlmeta/agents/ppo/ppo_rnd_model.py index 8483d53..8d76873 100644 --- a/rlmeta/agents/ppo/ppo_rnd_model.py +++ b/rlmeta/agents/ppo/ppo_rnd_model.py @@ -46,10 +46,12 @@ def act( deterministic_policy. Returns: - A tuple for pytorch tensor contains [action, logpi, v]. + A tuple for pytorch tensor contains (action, logpi, ext_v, int_v). + action: The final action selected by the model. logpi: The log probility for each action. - v: The value of the current state. + ext_v: The extrinsic value of the current state. + int_v: The intrinsic value of the current state. """ @abc.abstractmethod From 2435d09edc56b3edf051f510aeb79cb96c7f49b0 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Thu, 22 Sep 2022 22:36:46 -0700 Subject: [PATCH 2/3] Tiny changes for PPOModel --- examples/atari/ppo/atari_ppo_model.py | 2 +- rlmeta/agents/ppo/ppo_model.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/atari/ppo/atari_ppo_model.py b/examples/atari/ppo/atari_ppo_model.py index bc29965..2328b6c 100644 --- a/examples/atari/ppo/atari_ppo_model.py +++ b/examples/atari/ppo/atari_ppo_model.py @@ -48,4 +48,4 @@ def act( action = torch.where(d, greedy_action, sample_action) logpi = logpi.gather(dim=-1, index=action) - return action.cpu(), logpi.cpu(), v.cpu() + return action.cpu(), logpi.cpu(), v.cpu() diff --git a/rlmeta/agents/ppo/ppo_model.py b/rlmeta/agents/ppo/ppo_model.py index d3287fb..b220e14 100644 --- a/rlmeta/agents/ppo/ppo_model.py +++ b/rlmeta/agents/ppo/ppo_model.py @@ -44,7 +44,8 @@ def act(self, obs: torch.Tensor, deterministic_policy: torch.Tensor, *args, deterministic_policy. Returns: - A tuple for pytorch tensor contains [action, logpi, v]. + A tuple for pytorch tensor contains (action, logpi, v). + action: The final action selected by the model. logpi: The log probility for each action. v: The value of the current state. From 05d5bb634896ef3c43ab271ec6cedca315fd693c Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Thu, 22 Sep 2022 23:12:19 -0700 Subject: [PATCH 3/3] Change default TimeLimit of Atari to 3000 --- examples/atari/ppo/conf/conf_ppo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/atari/ppo/conf/conf_ppo.yaml b/examples/atari/ppo/conf/conf_ppo.yaml index e35efec..6f9c8e9 100644 --- a/examples/atari/ppo/conf/conf_ppo.yaml +++ b/examples/atari/ppo/conf/conf_ppo.yaml @@ -11,7 +11,7 @@ train_device: "cuda:0" infer_device: "cuda:1" env: "PongNoFrameskip-v4" -max_episode_steps: 2700 +max_episode_steps: 3000 num_train_rollouts: 32 num_train_workers: 16
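
Note: the refactored `_value_loss` in rlmeta/agents/ppo/ppo_agent.py above now takes a `value_clipping_eps` that may be `None` (plain MSE) or a float (PPO-style clipped value loss around the behavior-policy value). Below is a minimal standalone sketch of that loss for reference; the function name and tensor names are illustrative and not part of the patch.

# Illustrative sketch of the clipped value loss introduced in this series.
# Assumes v, ret, and behavior_v are broadcastable tensors of predicted
# values, GAE returns, and behavior-policy values, respectively.
import torch
import torch.nn.functional as F
from typing import Optional


def clipped_value_loss(v: torch.Tensor,
                       ret: torch.Tensor,
                       behavior_v: torch.Tensor,
                       value_clipping_eps: Optional[float] = 0.2
                      ) -> torch.Tensor:
    if value_clipping_eps is None:
        # No value clipping: fall back to a plain mean-squared error.
        return F.mse_loss(v, ret)
    # Clip the predicted value to stay within +/- eps of the behavior value,
    # then take the element-wise maximum of the two squared errors.
    clipped_v = behavior_v + torch.clamp(v - behavior_v, -value_clipping_eps,
                                         value_clipping_eps)
    vf1 = F.mse_loss(v, ret, reduction="none")
    vf2 = F.mse_loss(clipped_v, ret, reduction="none")
    return torch.max(vf1, vf2).mean()


if __name__ == "__main__":
    # Tiny smoke test with random tensors.
    v = torch.randn(8, 1)
    ret = torch.randn(8, 1)
    behavior_v = torch.randn(8, 1)
    print(clipped_value_loss(v, ret, behavior_v).item())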