This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Merge pull request #38 from facebookresearch/v0.1
Add ppo_rnd algorithm implementation
xiaomengy committed Apr 11, 2022
2 parents 5ff8ac5 + f4449a6 commit 132d0c5
Showing 8 changed files with 597 additions and 94 deletions.
32 changes: 32 additions & 0 deletions examples/atari/backbone.py
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class AtariBackbone(nn.Module):

def __init__(self) -> None:
super().__init__()

layers = []
layers.append(nn.Conv2d(4, 32, kernel_size=8, stride=4))
layers.append(nn.ReLU())
layers.append(nn.Conv2d(32, 64, kernel_size=4, stride=2))
layers.append(nn.ReLU())
layers.append(nn.Conv2d(64, 64, kernel_size=3, stride=1))
layers.append(nn.ReLU())
layers.append(nn.Flatten())
layers.append(nn.Linear(3136, 512))
layers.append(nn.ReLU())
self.layers = nn.Sequential(*layers)

@property
def output_dim(self) -> int:
return 512

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x)
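For reference (not part of this commit): the backbone expects the standard 4-frame, 84x84 Atari observation, which the three conv layers reduce to 7x7x64 = 3136 features before the 512-unit projection, and output_dim lets downstream heads size themselves off the backbone. A minimal shape-check sketch under that assumption:

import torch

from examples.atari.backbone import AtariBackbone

backbone = AtariBackbone()
# Fake stacked-frame batch; the models in this commit rescale uint8 frames to [0, 1].
obs = torch.randint(0, 256, (2, 4, 84, 84), dtype=torch.uint8)
features = backbone(obs.float() / 255.0)
print(features.shape)       # torch.Size([2, 512])
print(backbone.output_dim)  # 512, used to size the policy/value/Q heads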
18 changes: 5 additions & 13 deletions examples/atari/dqn/atari_dqn_model.py
@@ -13,6 +13,7 @@
 
 import rlmeta.core.remote as remote
 
+from examples.atari.backbone import AtariBackbone
 from rlmeta.agents.dqn.dqn_model import DQNModel
 from rlmeta.core.rescalers import SqrtRescaler
 from rlmeta.core.types import NestedTensor
@@ -25,19 +26,10 @@ def __init__(self, action_dim: int, dueling_dqn: bool = True) -> None:
         self.action_dim = action_dim
         self.dueling_dqn = dueling_dqn
 
-        layers = []
-        layers.append(nn.Conv2d(4, 32, kernel_size=8, stride=4))
-        layers.append(nn.ReLU())
-        layers.append(nn.Conv2d(32, 64, kernel_size=4, stride=2))
-        layers.append(nn.ReLU())
-        layers.append(nn.Conv2d(64, 64, kernel_size=3, stride=1))
-        layers.append(nn.ReLU())
-        layers.append(nn.Flatten())
-        layers.append(nn.Linear(3136, 512))
-        layers.append(nn.ReLU())
-        self.backbone = nn.Sequential(*layers)
-        self.linear_a = nn.Linear(512, self.action_dim)
-        self.linear_v = nn.Linear(512, 1) if dueling_dqn else None
+        self.backbone = AtariBackbone()
+        self.linear_a = nn.Linear(self.backbone.output_dim, self.action_dim)
+        self.linear_v = nn.Linear(self.backbone.output_dim,
+                                  1) if dueling_dqn else None
 
     def forward(self, obs: torch.Tensor) -> torch.Tensor:
         x = obs.float() / 255.0
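The rest of forward() is collapsed in this diff view. For reference only, a hypothetical sketch of how a dueling head on top of AtariBackbone is typically combined, using the standard Q = V + A - mean(A) aggregation (an assumption, not the file's actual code):

import torch
import torch.nn as nn

from examples.atari.backbone import AtariBackbone

backbone = AtariBackbone()
action_dim = 6                                          # example value (e.g. Pong)
linear_a = nn.Linear(backbone.output_dim, action_dim)   # advantage head
linear_v = nn.Linear(backbone.output_dim, 1)            # state-value head

obs = torch.rand(2, 4, 84, 84)                 # already scaled to [0, 1]
h = backbone(obs)
a = linear_a(h)
v = linear_v(h)
q = v + a - a.mean(dim=-1, keepdim=True)       # dueling aggregation
print(q.shape)                                 # torch.Size([2, 6])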
18 changes: 4 additions & 14 deletions examples/atari/ppo/atari_ppo_model.py
@@ -11,6 +11,7 @@
 
 import rlmeta.core.remote as remote
 
+from examples.atari.backbone import AtariBackbone
 from rlmeta.agents.ppo.ppo_model import PPOModel
 
 
@@ -19,20 +20,9 @@ class AtariPPOModel(PPOModel):
     def __init__(self, action_dim: int) -> None:
         super().__init__()
         self.action_dim = action_dim
-
-        layers = []
-        layers.append(nn.Conv2d(4, 32, kernel_size=8, stride=4))
-        layers.append(nn.ReLU())
-        layers.append(nn.Conv2d(32, 64, kernel_size=4, stride=2))
-        layers.append(nn.ReLU())
-        layers.append(nn.Conv2d(64, 64, kernel_size=3, stride=1))
-        layers.append(nn.ReLU())
-        layers.append(nn.Flatten())
-        layers.append(nn.Linear(3136, 512))
-        layers.append(nn.ReLU())
-        self.backbone = nn.Sequential(*layers)
-        self.linear_p = nn.Linear(512, self.action_dim)
-        self.linear_v = nn.Linear(512, 1)
+        self.backbone = AtariBackbone()
+        self.linear_p = nn.Linear(self.backbone.output_dim, self.action_dim)
+        self.linear_v = nn.Linear(self.backbone.output_dim, 1)
 
     def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         x = obs.float() / 255.0
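As with the DQN model, forward() is collapsed here; its return annotation and the RND model below suggest it produces a (log-probabilities, value) pair. A hypothetical smoke test under that assumption (action_dim=6 is arbitrary):

import torch

from examples.atari.ppo.atari_ppo_model import AtariPPOModel

model = AtariPPOModel(action_dim=6)
obs = torch.randint(0, 256, (2, 4, 84, 84), dtype=torch.uint8)
logpi, v = model(obs)            # forward() rescales obs by 1/255 internally
print(logpi.shape, v.shape)      # expected: torch.Size([2, 6]) torch.Size([2, 1])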
128 changes: 128 additions & 0 deletions examples/atari/ppo/atari_ppo_rnd.py
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import json
import logging
import time

import hydra

import torch
import torch.multiprocessing as mp

import rlmeta.envs.atari_wrappers as atari_wrappers
import rlmeta.envs.gym_wrappers as gym_wrappers
import rlmeta.utils.hydra_utils as hydra_utils
import rlmeta.utils.remote_utils as remote_utils

from examples.atari.ppo.atari_ppo_rnd_model import AtariPPORNDModel
from rlmeta.agents.agent import AgentFactory
from rlmeta.agents.ppo.ppo_rnd_agent import PPORNDAgent
from rlmeta.core.controller import Phase, Controller
from rlmeta.core.loop import LoopList, ParallelLoop
from rlmeta.core.model import wrap_downstream_model
from rlmeta.core.replay_buffer import ReplayBuffer, make_remote_replay_buffer
from rlmeta.core.server import Server, ServerList


@hydra.main(config_path="./conf", config_name="conf_ppo")
def main(cfg):
logging.info(hydra_utils.config_to_json(cfg))

env = atari_wrappers.make_atari(cfg.env)
train_model = AtariPPORNDModel(env.action_space.n).to(cfg.train_device)
optimizer = torch.optim.Adam(train_model.parameters(), lr=cfg.lr)

infer_model = copy.deepcopy(train_model).to(cfg.infer_device)

ctrl = Controller()
rb = ReplayBuffer(cfg.replay_buffer_size)

m_server = Server(cfg.m_server_name, cfg.m_server_addr)
r_server = Server(cfg.r_server_name, cfg.r_server_addr)
c_server = Server(cfg.c_server_name, cfg.c_server_addr)
m_server.add_service(infer_model)
r_server.add_service(rb)
c_server.add_service(ctrl)
servers = ServerList([m_server, r_server, c_server])

a_model = wrap_downstream_model(train_model, m_server)
t_model = remote_utils.make_remote(infer_model, m_server)
e_model = remote_utils.make_remote(infer_model, m_server)

a_ctrl = remote_utils.make_remote(ctrl, c_server)
t_ctrl = remote_utils.make_remote(ctrl, c_server)
e_ctrl = remote_utils.make_remote(ctrl, c_server)

a_rb = make_remote_replay_buffer(rb, r_server, prefetch=cfg.prefetch)
t_rb = make_remote_replay_buffer(rb, r_server)

env_fac = gym_wrappers.AtariWrapperFactory(
cfg.env, max_episode_steps=cfg.max_episode_steps)

agent = PPORNDAgent(a_model,
replay_buffer=a_rb,
controller=a_ctrl,
optimizer=optimizer,
batch_size=cfg.batch_size,
learning_starts=cfg.get("learning_starts", None),
push_every_n_steps=cfg.push_every_n_steps)
t_agent_fac = AgentFactory(PPORNDAgent, t_model, replay_buffer=t_rb)
e_agent_fac = AgentFactory(PPORNDAgent, e_model, deterministic_policy=False)

t_loop = ParallelLoop(env_fac,
t_agent_fac,
t_ctrl,
running_phase=Phase.TRAIN,
should_update=True,
num_rollouts=cfg.num_train_rollouts,
num_workers=cfg.num_train_workers,
seed=cfg.train_seed)
e_loop = ParallelLoop(env_fac,
e_agent_fac,
e_ctrl,
running_phase=Phase.EVAL,
should_update=False,
num_rollouts=cfg.num_eval_rollouts,
num_workers=cfg.num_eval_workers,
seed=cfg.eval_seed)
loops = LoopList([t_loop, e_loop])

servers.start()
loops.start()
agent.connect()

start_time = time.perf_counter()
for epoch in range(cfg.num_epochs):
stats = agent.train(cfg.steps_per_epoch)
cur_time = time.perf_counter() - start_time
info = f"T Epoch {epoch}"
if cfg.table_view:
logging.info("\n\n" + stats.table(info, time=cur_time) + "\n")
else:
logging.info(
stats.json(info, phase="Train", epoch=epoch, time=cur_time))
time.sleep(1)

stats = agent.eval(cfg.num_eval_episodes)
cur_time = time.perf_counter() - start_time
info = f"E Epoch {epoch}"
if cfg.table_view:
logging.info("\n\n" + stats.table(info, time=cur_time) + "\n")
else:
logging.info(
stats.json(info, phase="Eval", epoch=epoch, time=cur_time))
time.sleep(1)

torch.save(train_model.state_dict(), f"ppo_rnd_agent-{epoch}.pth")

loops.terminate()
servers.terminate()


if __name__ == "__main__":
mp.set_start_method("spawn")
main()
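The training loop above writes one checkpoint per epoch via torch.save(train_model.state_dict(), f"ppo_rnd_agent-{epoch}.pth"). A hypothetical follow-up sketch for reloading a checkpoint offline; the file name and action_dim are placeholders, and action_dim must match env.action_space.n from training:

import torch

from examples.atari.ppo.atari_ppo_rnd_model import AtariPPORNDModel

model = AtariPPORNDModel(action_dim=6)  # must match the training environment
state_dict = torch.load("ppo_rnd_agent-9.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

obs = torch.randint(0, 256, (1, 4, 84, 84), dtype=torch.uint8)
with torch.no_grad():
    logpi, ext_v, int_v = model(obs)
print(logpi.exp().squeeze(0))           # action probabilities for this frame stack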
89 changes: 89 additions & 0 deletions examples/atari/ppo/atari_ppo_rnd_model.py
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import rlmeta.core.remote as remote

from examples.atari.backbone import AtariBackbone
from rlmeta.agents.ppo.ppo_rnd_model import PPORNDModel
from rlmeta.core.rescalers import MomentsRescaler
from rlmeta.core.types import NestedTensor


class AtariPPORNDModel(PPORNDModel):

def __init__(self,
action_dim: int,
observation_normalization: bool = False) -> None:
super().__init__()

self.action_dim = action_dim
self.observation_normalization = observation_normalization
if self.observation_normalization:
self.obs_rescaler = MomentsRescaler(size=(4, 84, 84))

self.policy_net = AtariBackbone()
self.target_net = AtariBackbone()
self.predict_net = AtariBackbone()
self.linear_p = nn.Linear(self.policy_net.output_dim, self.action_dim)
self.linear_ext_v = nn.Linear(self.policy_net.output_dim, 1)
self.linear_int_v = nn.Linear(self.policy_net.output_dim, 1)

def forward(
self, obs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = obs.float() / 255.0
h = self.policy_net(x)
p = self.linear_p(h)
logpi = F.log_softmax(p, dim=-1)
ext_v = self.linear_ext_v(h)
int_v = self.linear_int_v(h)

return logpi, ext_v, int_v

@remote.remote_method(batch_size=128)
def act(
self, obs: torch.Tensor, deterministic_policy: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
device = next(self.parameters()).device

with torch.no_grad():
x = obs.to(device)
d = deterministic_policy.to(device)
logpi, ext_v, int_v = self.forward(x)

greedy_action = logpi.argmax(-1, keepdim=True)
sample_action = logpi.exp().multinomial(1, replacement=True)
action = torch.where(d, greedy_action, sample_action)
logpi = logpi.gather(dim=-1, index=action)

return action.cpu(), logpi.cpu(), ext_v.cpu(), int_v.cpu()

@remote.remote_method(batch_size=None)
def intrinsic_reward(self, obs: torch.Tensor) -> torch.Tensor:
device = next(self.parameters()).device
reward = self._rnd_error(obs.to(device))
return reward.cpu()

def rnd_loss(self, obs: torch.Tensor) -> torch.Tensor:
return self._rnd_error(obs).mean() * 0.5

def _rnd_error(self, obs: torch.Tensor) -> torch.Tensor:
x = obs.float() / 255.0
if self.observation_normalization:
self.obs_rescaler.update(x)
x = self.obs_rescaler.rescale(x)

with torch.no_grad():
target = self.target_net(x)
pred = self.predict_net(x)
err = (pred - target).square().mean(-1, keepdim=True)

return err
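For reference (not part of the commit): the RND pieces pair a randomly initialized target_net, whose features are computed under no_grad, with a trained predict_net. _rnd_error() is the per-observation mean squared feature error over the 512 backbone features; intrinsic_reward() serves that error to actors as the exploration bonus, and rnd_loss() averages it over the batch (times 0.5) to train predict_net. A minimal local sketch (action_dim=6 is arbitrary):

import torch

from examples.atari.ppo.atari_ppo_rnd_model import AtariPPORNDModel

model = AtariPPORNDModel(action_dim=6)
obs = torch.randint(0, 256, (4, 4, 84, 84), dtype=torch.uint8)

err = model._rnd_error(obs)       # novelty per observation, shape (4, 1)
loss = model.rnd_loss(obs)        # 0.5 * err.mean(), the predictor's training loss
logpi, ext_v, int_v = model(obs)  # policy plus extrinsic and intrinsic value heads
print(err.shape, loss.item(), int_v.shape)

Untrained, err is essentially the disagreement between two random networks; as predict_net learns to match target_net on visited observations, the error, and therefore the intrinsic reward, decays for familiar states.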
(The remaining 3 changed files did not load in this view.)
