This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Merge pull request #80 from xiaomengy/agent_update

Refactor PPOAgent and PPORNDAgent

xiaomengy committed Sep 23, 2022
2 parents ff9e8ab + 05d5bb6 commit f80dbe3
Showing 8 changed files with 174 additions and 186 deletions.
2 changes: 1 addition & 1 deletion examples/atari/ppo/atari_ppo_model.py
@@ -48,4 +48,4 @@ def act(
         action = torch.where(d, greedy_action, sample_action)
         logpi = logpi.gather(dim=-1, index=action)

-        return action.cpu(), logpi.cpu(), v.cpu()
+        return action.cpu(), logpi.cpu(), v.cpu()
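Note on the act code above: it follows a common evaluation/exploration switch. The log-probabilities are computed once, the argmax is taken where a per-sample deterministic flag d is set, a sampled action is used elsewhere, and the matching log-probability is gathered for the later PPO ratio. A minimal self-contained sketch of that selection step (the function name and shapes are illustrative, not part of this commit):

from typing import Tuple

import torch
import torch.nn.functional as F


def select_action(logits: torch.Tensor,
                  deterministic: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # logits: [batch, action_dim] policy logits.
    # deterministic: [batch, 1] bool flag, True for greedy evaluation.
    logpi = F.log_softmax(logits, dim=-1)
    greedy_action = logpi.argmax(dim=-1, keepdim=True)
    sample_action = logpi.exp().multinomial(1, replacement=True)
    # Greedy where the flag is set, sampled otherwise.
    action = torch.where(deterministic, greedy_action, sample_action)
    # Log-probability of the chosen action, kept for the PPO ratio.
    logpi = logpi.gather(dim=-1, index=action)
    return action, logpi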
4 changes: 3 additions & 1 deletion examples/atari/ppo/atari_ppo_rnd.py
@@ -28,6 +28,7 @@
 from rlmeta.core.server import Server, ServerList
 from rlmeta.samplers import UniformSampler
 from rlmeta.storage import TensorCircularBuffer
+from rlmeta.utils.optimizer_utils import make_optimizer


 @hydra.main(config_path="./conf", config_name="conf_ppo")
@@ -36,7 +37,8 @@ def main(cfg):

     env = atari_wrappers.make_atari(cfg.env)
     train_model = AtariPPORNDModel(env.action_space.n).to(cfg.train_device)
-    optimizer = torch.optim.Adam(train_model.parameters(), lr=cfg.lr)
+    optimizer = make_optimizer(cfg.optimizer.name, train_model.parameters(),
+                               cfg.optimizer.args)

     infer_model = copy.deepcopy(train_model).to(cfg.infer_device)

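Note on the change above: the training script now builds its optimizer from the Hydra config via rlmeta.utils.optimizer_utils.make_optimizer instead of hard-coding torch.optim.Adam. The helper's implementation is not part of this diff; below is a minimal sketch of what a name-plus-kwargs optimizer factory with this call signature could look like, assumed from the call site rather than copied from rlmeta:

from typing import Any, Iterable, Mapping

import torch


def make_optimizer(name: str, params: Iterable[torch.nn.Parameter],
                   args: Mapping[str, Any]) -> torch.optim.Optimizer:
    # Resolve the optimizer class by name ("Adam", "SGD", ...) and forward
    # the remaining config entries (e.g. lr, eps) as keyword arguments.
    cls = getattr(torch.optim, name)
    return cls(params, **dict(args))

With this pattern, conf_ppo.yaml would be expected to carry an optimizer block with a name and an args mapping (for example lr), which is what cfg.optimizer.name and cfg.optimizer.args in the new call refer to.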
35 changes: 13 additions & 22 deletions examples/atari/ppo/atari_ppo_rnd_model.py
@@ -19,28 +19,24 @@

 class AtariPPORNDModel(PPORNDModel):

-    def __init__(self,
-                 action_dim: int,
-                 observation_normalization: bool = False) -> None:
+    def __init__(self, action_dim: int) -> None:
         super().__init__()

         self.action_dim = action_dim
-        self.observation_normalization = observation_normalization
-        if self.observation_normalization:
-            self.obs_rescaler = MomentsRescaler(size=(4, 84, 84))

-        self.policy_net = AtariBackbone()
-        self.target_net = AtariBackbone()
-        self.predict_net = AtariBackbone()
-        self.linear_p = nn.Linear(self.policy_net.output_dim, self.action_dim)
-        self.linear_ext_v = nn.Linear(self.policy_net.output_dim, 1)
-        self.linear_int_v = nn.Linear(self.policy_net.output_dim, 1)
+        self.ppo_net = AtariBackbone()
+        self.tgt_net = AtariBackbone()
+        self.prd_net = AtariBackbone()
+
+        self.linear_p = nn.Linear(self.ppo_net.output_dim, self.action_dim)
+        self.linear_ext_v = nn.Linear(self.ppo_net.output_dim, 1)
+        self.linear_int_v = nn.Linear(self.ppo_net.output_dim, 1)

     def forward(
             self, obs: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         x = obs.float() / 255.0
-        h = self.policy_net(x)
+        h = self.ppo_net(x)
         p = self.linear_p(h)
         logpi = F.log_softmax(p, dim=-1)
         ext_v = self.linear_ext_v(h)
@@ -64,7 +60,7 @@ def act(
         action = torch.where(d, greedy_action, sample_action)
         logpi = logpi.gather(dim=-1, index=action)

-        return action.cpu(), logpi.cpu(), ext_v.cpu(), int_v.cpu()
+        return action.cpu(), logpi.cpu(), ext_v.cpu(), int_v.cpu()

     @remote.remote_method(batch_size=None)
     def intrinsic_reward(self, obs: torch.Tensor) -> torch.Tensor:
@@ -77,13 +73,8 @@ def rnd_loss(self, obs: torch.Tensor) -> torch.Tensor:

     def _rnd_error(self, obs: torch.Tensor) -> torch.Tensor:
         x = obs.float() / 255.0
-        if self.observation_normalization:
-            self.obs_rescaler.update(x)
-            x = self.obs_rescaler.rescale(x)
-
         with torch.no_grad():
-            target = self.target_net(x)
-        pred = self.predict_net(x)
-        err = (pred - target).square().mean(-1, keepdim=True)
-
+            tgt = self.tgt_net(x)
+        prd = self.prd_net(x)
+        err = (prd - tgt).square().mean(-1, keepdim=True)
         return err
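Background for the renamed networks above: tgt_net and prd_net form the usual Random Network Distillation pair. The target network stays at its random initialization while the predictor is trained to match it, and the predictor's per-sample squared error serves both as the intrinsic reward and as the RND training loss. A compact standalone sketch of that idea, where the class and argument names are illustrative and not rlmeta API:

import torch
import torch.nn as nn


class RNDBonus(nn.Module):
    """Random Network Distillation bonus: prediction error of a trained
    predictor against a fixed, randomly initialized target network."""

    def __init__(self, make_encoder) -> None:
        super().__init__()
        self.tgt_net = make_encoder()  # frozen random target
        self.prd_net = make_encoder()  # trained predictor
        for p in self.tgt_net.parameters():
            p.requires_grad_(False)

    def _error(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            tgt = self.tgt_net(x)
        prd = self.prd_net(x)
        # Per-sample squared prediction error, shape [batch, 1].
        return (prd - tgt).square().mean(-1, keepdim=True)

    def intrinsic_reward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self._error(x)

    def rnd_loss(self, x: torch.Tensor) -> torch.Tensor:
        # Optimizing this loss only updates prd_net; tgt_net has no gradients.
        return self._error(x).mean()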
2 changes: 1 addition & 1 deletion examples/atari/ppo/conf/conf_ppo.yaml
@@ -11,7 +11,7 @@ train_device: "cuda:0"
 infer_device: "cuda:1"

 env: "PongNoFrameskip-v4"
-max_episode_steps: 2700
+max_episode_steps: 3000

 num_train_rollouts: 32
 num_train_workers: 16
