diff --git a/examples/a2c_atari.py b/examples/a2c_atari.py index d0cd3a1..f9ae72c 100644 --- a/examples/a2c_atari.py +++ b/examples/a2c_atari.py @@ -9,7 +9,7 @@ from torch.optim import RMSprop -def config(game: str = "Breakout") -> Config: +def config(envname: str = "Breakout") -> Config: c = Config() c.set_env(lambda: Atari(game, frame_stack=False)) c.set_optimizer(lambda params: RMSprop(params, lr=7e-4, alpha=0.99, eps=1e-5)) diff --git a/examples/a2c_cart_pole.py b/examples/a2c_cart_pole.py index 88c7b82..b748b30 100644 --- a/examples/a2c_cart_pole.py +++ b/examples/a2c_cart_pole.py @@ -1,12 +1,14 @@ +import click import os import rainy from rainy.utils.cli import run_cli -from rainy.envs import MultiProcEnv +from rainy.envs import ClassicControl, MultiProcEnv from torch.optim import Adam -def config() -> rainy.Config: +def config(envname: str = "CartPole-v0", rnn: bool = False) -> rainy.Config: c = rainy.Config() + c.set_env(lambda: ClassicControl(envname)) c.max_steps = int(1e6) c.nworkers = 12 c.nsteps = 5 @@ -18,9 +20,13 @@ def config() -> rainy.Config: c.eval_deterministic = True c.eval_freq = c.max_steps // 10 c.entropy_weight = 0.001 - # c.set_net_fn('actor-critic', rainy.net.actor_critic.fc_shared(rnn=rainy.net.GruBlock)) + if rnn: + c.set_net_fn( + "actor-critic", rainy.net.actor_critic.fc_shared(rnn=rainy.net.GruBlock) + ) return c if __name__ == "__main__": - run_cli(config, rainy.agents.A2CAgent, script_path=os.path.realpath(__file__)) + options = [click.Option(["--rnn"], is_flag=True)] + run_cli(config, rainy.agents.A2CAgent, os.path.realpath(__file__), options) diff --git a/examples/acktr_atari.py b/examples/acktr_atari.py index fc395be..3b3a4f8 100644 --- a/examples/acktr_atari.py +++ b/examples/acktr_atari.py @@ -12,7 +12,7 @@ } -def config(game: str = "Breakout") -> Config: +def config(envname: str = "Breakout") -> Config: c = Config() c.set_env(lambda: Atari(game, frame_stack=False)) c.set_optimizer(kfac.default_sgd(eta_max=0.2)) diff --git a/examples/acktr_cart_pole.py b/examples/acktr_cart_pole.py index 7a19235..5cba228 100644 --- a/examples/acktr_cart_pole.py +++ b/examples/acktr_cart_pole.py @@ -2,7 +2,7 @@ from rainy import Config from rainy.agents import ACKTRAgent import rainy.utils.cli as cli -from rainy.envs import MultiProcEnv +from rainy.envs import ClassicControl, MultiProcEnv from rainy.lib import kfac @@ -13,8 +13,9 @@ } -def config() -> Config: +def config(envname: str = "CartPole-v0") -> Config: c = Config() + c.set_env(lambda: ClassicControl(envname)) c.max_steps = int(4e5) c.nworkers = 12 c.nsteps = 20 diff --git a/examples/aoc_atari.py b/examples/aoc_atari.py index e23337f..0f3a324 100644 --- a/examples/aoc_atari.py +++ b/examples/aoc_atari.py @@ -9,7 +9,7 @@ from torch.optim import RMSprop -def config(game: str = "Breakout") -> Config: +def config(envname: str = "Breakout") -> Config: c = Config() c.set_env(lambda: Atari(game, frame_stack=False)) c.set_optimizer(lambda params: RMSprop(params, lr=7e-4, alpha=0.99, eps=1e-5)) diff --git a/examples/aoc_cart_pole.py b/examples/aoc_cart_pole.py index 8c87701..b126cf5 100644 --- a/examples/aoc_cart_pole.py +++ b/examples/aoc_cart_pole.py @@ -1,12 +1,13 @@ import os import rainy from rainy.utils.cli import run_cli -from rainy.envs import MultiProcEnv +from rainy.envs import ClassicControl, MultiProcEnv from torch import optim -def config() -> rainy.Config: +def config(envname: str = "CartPole-v0") -> rainy.Config: c = rainy.Config() + c.set_env(lambda: ClassicControl(envname)) c.max_steps = int(4e5) c.nworkers = 12
c.nsteps = 5 diff --git a/examples/bootdqn_cart_pole.py b/examples/bootdqn_cart_pole.py new file mode 100644 index 0000000..f39bf1e --- /dev/null +++ b/examples/bootdqn_cart_pole.py @@ -0,0 +1,42 @@ +import click +import os +from rainy import Config +from rainy.agents import EpisodicBootDQNAgent +from rainy.envs import ClassicControl +from rainy.lib import explore +from rainy.net import bootstrap +from rainy.replay import BootDQNReplayFeed, UniformReplayBuffer +import rainy.utils.cli as cli +from torch import optim + + +def config( + envname: str = "CartPole-v0", + max_steps: int = 1000000, + rpf: bool = False, + replay_prob: float = 0.5, + prior_scale: float = 1.0, +) -> Config: + c = Config() + c.set_optimizer(lambda params: optim.Adam(params)) + c.set_explorer(lambda: explore.Greedy()) + c.set_explorer(lambda: explore.Greedy(), key="eval") + c.set_env(lambda: ClassicControl(envname)) + c.max_steps = max_steps + c.episode_log_freq = 100 + c.replay_prob = replay_prob + if rpf: + c.set_net_fn("bootdqn", bootstrap.rpf_fc_separated(10, prior_scale=prior_scale)) + c.set_replay_buffer( + lambda capacity: UniformReplayBuffer(BootDQNReplayFeed, capacity=capacity) + ) + return c + + +if __name__ == "__main__": + options = [ + click.Option(["--rpf"], is_flag=True), + click.Option(["--replay-prob", "-RP"], type=float, default=0.5), + click.Option(["--prior-scale", "-PS"], type=float, default=1.0), + ] + cli.run_cli(config, EpisodicBootDQNAgent, os.path.realpath(__file__), options) diff --git a/examples/bootdqn_deepsea.py b/examples/bootdqn_deepsea.py new file mode 100644 index 0000000..8428e39 --- /dev/null +++ b/examples/bootdqn_deepsea.py @@ -0,0 +1,42 @@ +import click +import os +from rainy import Config +from rainy.agents import EpisodicBootDQNAgent +from rainy.envs import DeepSea +from rainy.lib import explore +from rainy.net import bootstrap +from rainy.replay import BootDQNReplayFeed, UniformReplayBuffer +import rainy.utils.cli as cli +from torch import optim + + +def config( + max_steps: int = 100000, + size: int = 20, + rpf: bool = False, + replay_prob: float = 0.5, + prior_scale: float = 1.0, +) -> Config: + c = Config() + c.set_optimizer(lambda params: optim.Adam(params)) + c.set_explorer(lambda: explore.Greedy()) + c.set_explorer(lambda: explore.Greedy(), key="eval") + c.set_env(lambda: DeepSea(size)) + c.max_steps = max_steps + c.episode_log_freq = 100 + c.replay_prob = replay_prob + if rpf: + c.set_net_fn("bootdqn", bootstrap.rpf_fc_separated(10, prior_scale=prior_scale)) + c.set_replay_buffer( + lambda capacity: UniformReplayBuffer(BootDQNReplayFeed, capacity=capacity) + ) + return c + + +if __name__ == "__main__": + options = [ + click.Option(["--rpf"], is_flag=True), + click.Option(["--replay-prob", "-RP"], type=float, default=0.5), + click.Option(["--prior-scale", "-PS"], type=float, default=1.0), + ] + cli.run_cli(config, EpisodicBootDQNAgent, os.path.realpath(__file__), options) diff --git a/examples/ddqn_atari.py b/examples/ddqn_atari.py index fd8c56f..16fad00 100644 --- a/examples/ddqn_atari.py +++ b/examples/ddqn_atari.py @@ -7,14 +7,14 @@ from torch.optim import RMSprop -def config(game: str = "Breakout") -> Config: +def config(envname: str = "Breakout") -> Config: c = Config() c.set_env(lambda: Atari(game)) c.set_optimizer( lambda params: RMSprop(params, lr=0.00025, alpha=0.95, eps=0.01, centered=True) ) c.set_explorer(lambda: EpsGreedy(1.0, LinearCooler(1.0, 0.1, int(1e6)))) - c.set_net_fn("value", net.value.dqn_conv()) + c.set_net_fn("dqn", net.value.dqn_conv()) 
c.replay_size = int(1e6) c.batch_size = 32 c.train_start = 50000 diff --git a/examples/ddqn_cart_pole.py b/examples/ddqn_cart_pole.py index 97b36fa..b732907 100644 --- a/examples/ddqn_cart_pole.py +++ b/examples/ddqn_cart_pole.py @@ -1,11 +1,13 @@ import os from rainy import Config from rainy.agents import DoubleDQNAgent +from rainy.envs import ClassicControl import rainy.utils.cli as cli -def config() -> Config: +def config(envname: str = "CartPole-v0") -> Config: c = Config() + c.set_env(lambda: ClassicControl(envname)) c.max_steps = 100000 return c diff --git a/examples/dqn_atari.py b/examples/dqn_atari.py index 479ea3a..082e9db 100644 --- a/examples/dqn_atari.py +++ b/examples/dqn_atari.py @@ -7,14 +7,14 @@ from torch.optim import RMSprop -def config(game: str = "Breakout") -> Config: +def config(envname: str = "Breakout") -> Config: c = Config() c.set_env(lambda: Atari(game)) c.set_optimizer( lambda params: RMSprop(params, lr=0.00025, alpha=0.95, eps=0.01, centered=True) ) c.set_explorer(lambda: EpsGreedy(1.0, LinearCooler(1.0, 0.1, int(1e6)))) - c.set_net_fn("value", net.value.dqn_conv()) + c.set_net_fn("dqn", net.value.dqn_conv()) c.replay_size = int(1e6) c.replay_batch_size = 32 c.train_start = 50000 diff --git a/examples/dqn_cart_pole.py b/examples/dqn_cart_pole.py index 8874873..07c8c1d 100644 --- a/examples/dqn_cart_pole.py +++ b/examples/dqn_cart_pole.py @@ -1,13 +1,15 @@ import os from rainy import Config from rainy.agents import DQNAgent +from rainy.envs import ClassicControl import rainy.utils.cli as cli -def config() -> Config: +def config(envname: str = "CartPole-v0", max_steps: int = 100000) -> Config: c = Config() - c.max_steps = 100000 - c.eval_freq = 100 + c.set_env(lambda: ClassicControl(envname)) + c.max_steps = max_steps + c.episode_log_freq = 100 return c diff --git a/examples/episodic_dqn_cart_pole.py b/examples/episodic_dqn_cart_pole.py new file mode 100644 index 0000000..a8e2905 --- /dev/null +++ b/examples/episodic_dqn_cart_pole.py @@ -0,0 +1,17 @@ +import os +from rainy import Config +from rainy.agents import EpisodicDQNAgent +from rainy.envs import ClassicControl +import rainy.utils.cli as cli + + +def config(envname: str = "CartPole-v0") -> Config: + c = Config() + c.set_env(lambda: ClassicControl(envname)) + c.max_steps = 100000 + c.episode_log_freq = 100 + return c + + +if __name__ == "__main__": + cli.run_cli(config, EpisodicDQNAgent, script_path=os.path.realpath(__file__)) diff --git a/examples/ppo_atari.py b/examples/ppo_atari.py index 9913b15..32f276a 100644 --- a/examples/ppo_atari.py +++ b/examples/ppo_atari.py @@ -6,7 +6,7 @@ from torch.optim import Adam -def config(game: str = "Breakout") -> Config: +def config(envname: str = "Breakout") -> Config: c = Config() c.set_env(lambda: Atari(game, frame_stack=False)) # c.set_net_fn('actor-critic', net.actor_critic.ac_conv(rnn=net.GruBlock)) diff --git a/examples/ppo_cart_pole.py b/examples/ppo_cart_pole.py index 510c55a..a4a79c5 100644 --- a/examples/ppo_cart_pole.py +++ b/examples/ppo_cart_pole.py @@ -1,12 +1,13 @@ import os import rainy from rainy.utils.cli import run_cli -from rainy.envs import MultiProcEnv +from rainy.envs import ClassicControl, MultiProcEnv from torch.optim import Adam -def config() -> rainy.Config: +def config(envname: str = "CartPole-v0") -> rainy.Config: c = rainy.Config() + c.set_env(lambda: ClassicControl(envname)) c.max_steps = int(1e5) c.nworkers = 8 c.nsteps = 32 diff --git a/examples/ppo_flicker_atari.py b/examples/ppo_flicker_atari.py index 0fad61d..f111c37 100644 --- 
a/examples/ppo_flicker_atari.py +++ b/examples/ppo_flicker_atari.py @@ -5,7 +5,7 @@ import rainy.utils.cli as cli -def config(game: str = "Breakout") -> rainy.Config: +def config(envname: str = "Breakout") -> rainy.Config: c = ppo_atari.config(game) c.set_env(lambda: Atari(game, flicker_frame=True, frame_stack=False)) c.set_parallel_env(atari_parallel(frame_stack=False)) diff --git a/rainy/agents/__init__.py b/rainy/agents/__init__.py index bac9174..784c088 100644 --- a/rainy/agents/__init__.py +++ b/rainy/agents/__init__.py @@ -2,8 +2,9 @@ from .acktr import ACKTRAgent from .aoc import AOCAgent from .base import Agent, EpisodeResult, NStepParallelAgent, OneStepAgent +from .bootdqn import BootDQNAgent, EpisodicBootDQNAgent from .ddpg import DDPGAgent -from .dqn import DQNAgent, DoubleDQNAgent +from .dqn import DQNAgent, DoubleDQNAgent, EpisodicDQNAgent from .ppo import PPOAgent from .sac import SACAgent from .td3 import TD3Agent diff --git a/rainy/agents/bootdqn.py b/rainy/agents/bootdqn.py new file mode 100644 index 0000000..887b8e6 --- /dev/null +++ b/rainy/agents/bootdqn.py @@ -0,0 +1,108 @@ +""" +This module has an implementation of Bootstrapped DQN, which is described in +- Deep Exploration via Bootstrapped DQN(https://arxiv.org/abs/1602.04621) +- Randomized Prior Functions for Deep Reinforcement Learning(https://arxiv.org/abs/1806.03335) +""" +import copy +import numpy as np +import torch +from torch import Tensor +from torch.nn import functional as F +from typing import Tuple +from .base import OneStepAgent +from ..config import Config +from ..prelude import Action, Array, State +from ..replay import BootDQNReplayFeed + + +class EpisodicBootDQNAgent(OneStepAgent): + SAVED_MEMBERS = "net", "policy", "total_steps" + + def __init__(self, config: Config) -> None: + super().__init__(config) + if not self.env.spec.is_discrete(): + raise RuntimeError("DQN only supports discrete action space.") + self.net = config.net("bootdqn") + self.optimizer = config.optimizer(self.net.parameters()) + self.policy = config.explorer() + self.eval_policy = config.explorer(key="eval") + self.replay = config.replay_buffer() + self.replay.allow_overlap = True + self.active_head = 0 + if self.replay.feed is not BootDQNReplayFeed: + raise RuntimeError("BootDQNAgent needs BootDQNReplayFeed") + + def set_mode(self, train: bool = True) -> None: + self.net.train(mode=train) + + @torch.no_grad() + def eval_action(self, state: Array) -> Action: + return self.eval_policy.select_action(state, self.net).item() # type: ignore + + def step(self, state: State) -> Tuple[State, float, bool, dict]: + with torch.no_grad(): + qs = self.net.q_i_s(self.active_head, self.env.extract(state)).detach() + action = self.policy.select_from_value(qs).item() + next_state, reward, done, info = self.env.step(action) + self._append_to_replay(state, action, reward, next_state, done) + if done: + self._train() + self.active_head = np.random.randint(self.config.num_ensembles) + return next_state, reward, done, info + + def _append_to_replay(self, *transition) -> None: + n_ens = self.config.num_ensembles + mask = np.random.uniform(0, 1, n_ens) < self.config.replay_prob + self.replay.append(*transition, mask) + + @torch.no_grad() + def _q_next(self, next_states: Array) -> Tensor: + return self.net(next_states).max(axis=-1)[0] + + def _train(self) -> None: + gamma = self.config.discount_factor + obs = self.replay.sample(self.config.replay_batch_size) + obs = [ob.to_array(self.env.extract) for ob in obs] + states, actions, rewards, next_states, 
done, mask = map(np.asarray, zip(*obs)) + q_next = self._q_next(next_states) + r = self.tensor(rewards).view(-1, 1) + q_target = r + q_next * self.tensor(1.0 - done).mul_(gamma).view(-1, 1) + q_current = self.net.q_s_a(states, actions) + loss = F.mse_loss(q_current, q_target, reduction="none") + masked_loss = loss.masked_select(self.tensor(mask, dtype=torch.bool)).mean() + self._backward(masked_loss, self.optimizer, self.net.parameters()) + self.network_log(q_value=q_current.mean().item(), value_loss=masked_loss.item()) + + +class BootDQNAgent(EpisodicBootDQNAgent): + SAVED_MEMBERS = "net", "policy", "total_steps", "target_net" + + def __init__(self, config: Config) -> None: + super().__init__(config) + self.target_net = copy.deepcopy(self.net) + self.replay.allow_overlap = False + + def step(self, state: State) -> Tuple[State, float, bool, dict]: + train_started = self.total_steps > self.config.train_start + if train_started: + with torch.no_grad(): + qs = self.net.q_i_s(self.active_head, self.env.extract(state)).detach() + action = self.policy.select_from_value(qs).item() + else: + action = self.env.spec.random_action() + next_state, reward, done, info = self.env.step(action) + self._append_to_replay(state, action, reward, next_state, done) + if done: + self.active_head = np.random.randint(self.config.num_ensembles) + if train_started: + self._train() + return next_state, reward, done, info + + @torch.no_grad() + def _q_next(self, next_states: Array) -> Tensor: + return self.target_net(next_states).max(axis=-1)[0] + + def _train(self) -> None: + super()._train() + if (self.update_steps + 1) % self.config.sync_freq == 0: + self.target_net.load_state_dict(self.net.state_dict()) diff --git a/rainy/agents/ddpg.py b/rainy/agents/ddpg.py index c1264b2..783797c 100644 --- a/rainy/agents/ddpg.py +++ b/rainy/agents/ddpg.py @@ -53,7 +53,7 @@ def _q_next(self, next_states: Array) -> Tensor: def _train(self) -> None: obs = self.replay.sample(self.config.replay_batch_size) - obs = [ob.to_ndarray(self.env.extract) for ob in obs] + obs = [ob.to_array(self.env.extract) for ob in obs] states, actions, rewards, next_states, done = map(np.asarray, zip(*obs)) mask = self.config.device.tensor(1.0 - done) q_next = self._q_next(next_states) diff --git a/rainy/agents/dqn.py b/rainy/agents/dqn.py index 68e4eec..c7bd790 100644 --- a/rainy/agents/dqn.py +++ b/rainy/agents/dqn.py @@ -1,27 +1,33 @@ from copy import deepcopy import numpy as np import torch -from torch import nn, Tensor +from torch import Tensor from torch.nn import functional as F from typing import Tuple from .base import OneStepAgent from ..config import Config from ..prelude import Action, Array, State +from ..replay import DQNReplayFeed -class DQNAgent(OneStepAgent): - SAVED_MEMBERS = "net", "target_net", "policy", "total_steps" +class EpisodicDQNAgent(OneStepAgent): + """A DQN variant that has no target network and + performs its learning update only once, at the end of each episode.
+ """ + + SAVED_MEMBERS = "net", "policy", "total_steps" def __init__(self, config: Config) -> None: super().__init__(config) if not self.env.spec.is_discrete(): raise RuntimeError("DQN only supports discrete action space.") - self.net = config.net("value") - self.target_net = deepcopy(self.net) + self.net = config.net("dqn") self.optimizer = config.optimizer(self.net.parameters()) self.policy = config.explorer() self.eval_policy = config.explorer(key="eval") self.replay = config.replay_buffer() + if self.replay.feed is not DQNReplayFeed: + raise RuntimeError("DQNAgent needs DQNReplayFeed") self.batch_indices = config.device.indices(config.replay_batch_size) def set_mode(self, train: bool = True) -> None: @@ -32,34 +38,55 @@ def eval_action(self, state: Array) -> Action: return self.eval_policy.select_action(state, self.net).item() # type: ignore def step(self, state: State) -> Tuple[State, float, bool, dict]: - train_started = self.total_steps > self.config.train_start - if train_started: - action = self.policy.select_action(self.env.extract(state), self.net).item() - else: - action = self.env.spec.random_action() + action = self.policy.select_action(self.env.extract(state), self.net).item() next_state, reward, done, info = self.env.step(action) self.replay.append(state, action, reward, next_state, done) - if train_started: + if done: self._train() return next_state, reward, done, info @torch.no_grad() def _q_next(self, next_states: Array) -> Tensor: - return self.target_net(next_states).max(1)[0] + return self.net(next_states).max(axis=-1)[0] def _train(self) -> None: obs = self.replay.sample(self.config.replay_batch_size) - obs = [ob.to_ndarray(self.env.extract) for ob in obs] + obs = [ob.to_array(self.env.extract) for ob in obs] states, actions, rewards, next_states, done = map(np.asarray, zip(*obs)) q_next = self._q_next(next_states).mul_(self.tensor(1.0 - done)) q_target = self.tensor(rewards).add_(q_next.mul_(self.config.discount_factor)) q_current = self.net(states)[self.batch_indices, actions] loss = F.mse_loss(q_current, q_target) - self.optimizer.zero_grad() - loss.backward() - nn.utils.clip_grad_norm_(self.net.parameters(), self.config.grad_clip) - self.optimizer.step() + self._backward(loss, self.optimizer, self.net.parameters()) self.network_log(q_value=q_current.mean().item(), value_loss=loss.item()) + + +class DQNAgent(EpisodicDQNAgent): + + SAVED_MEMBERS = "net", "policy", "total_steps", "target_net" + + def __init__(self, config: Config) -> None: + super().__init__(config) + self.target_net = deepcopy(self.net) + + def step(self, state: State) -> Tuple[State, float, bool, dict]: + train_started = self.total_steps > self.config.train_start + if train_started: + action = self.policy.select_action(self.env.extract(state), self.net).item() + else: + action = self.env.spec.random_action() + next_state, reward, done, info = self.env.step(action) + self.replay.append(state, action, reward, next_state, done) + if train_started: + self._train() + return next_state, reward, done, info + + @torch.no_grad() + def _q_next(self, next_states: Array) -> Tensor: + return self.target_net(next_states).max(axis=-1)[0] + + def _train(self): + super()._train() if (self.update_steps + 1) % self.config.sync_freq == 0: self.target_net.load_state_dict(self.net.state_dict()) @@ -70,5 +97,5 @@ def _q_next(self, next_states: Array) -> Tensor: """Returns Q values of next_states, supposing torch.no_grad() is called """ q_next = self.target_net(next_states) - q_values = self.net.q_values(next_states, 
nostack=True) - return q_next[self.batch_indices, q_values.argmax(dim=-1)] + q_value = self.net.q_value(next_states, nostack=True) + return q_next[self.batch_indices, q_value.argmax(dim=-1)] diff --git a/rainy/agents/sac.py b/rainy/agents/sac.py index 174aa56..2631ed6 100644 --- a/rainy/agents/sac.py +++ b/rainy/agents/sac.py @@ -103,7 +103,7 @@ def _q_next(self, next_states: Tensor, alpha: float) -> Tensor: def _train(self) -> None: obs = self.replay.sample(self.config.replay_batch_size) - obs = [ob.to_ndarray(self.env.extract) for ob in obs] + obs = [ob.to_array(self.env.extract) for ob in obs] states, actions, rewards, next_states, done = map(np.asarray, zip(*obs)) q1, q2, policy = self.net(states, actions) diff --git a/rainy/agents/td3.py b/rainy/agents/td3.py index 38fc638..c0a4fe2 100644 --- a/rainy/agents/td3.py +++ b/rainy/agents/td3.py @@ -37,7 +37,7 @@ def _q_next(self, next_states: Array) -> Tensor: def _train(self) -> None: obs = self.replay.sample(self.config.replay_batch_size) - obs = [ob.to_ndarray(self.env.extract) for ob in obs] + obs = [ob.to_array(self.env.extract) for ob in obs] states, actions, rewards, next_states, done = map(np.asarray, zip(*obs)) mask = self.config.device.tensor(1.0 - done) q_next = self._q_next(next_states).squeeze_() diff --git a/rainy/config.py b/rainy/config.py index 5ff304a..5e7a296 100644 --- a/rainy/config.py +++ b/rainy/config.py @@ -1,8 +1,8 @@ from torch import nn from torch.optim import Optimizer, RMSprop from typing import Callable, Dict, List, Optional, Sequence -from .envs import ClassicalControl, DummyParallelEnv, EnvExt, EnvGen, ParallelEnv -from .net import actor_critic, deterministic, option_critic, sac, value +from .envs import ClassicControl, DummyParallelEnv, EnvExt, EnvGen, ParallelEnv +from .net import actor_critic, bootstrap, deterministic, option_critic, sac, value from .net.prelude import NetFn from .lib.explore import DummyCooler, Cooler, LinearCooler, Explorer, EpsGreedy from .lib import mpi @@ -26,7 +26,7 @@ def __init__(self) -> None: self.eval_deterministic = True # Replay buffer - self.replay_batch_size = 10 + self.replay_batch_size = 64 self.replay_size = 10000 self.train_start = 1000 self.__replay: Callable[ @@ -38,12 +38,16 @@ def __init__(self) -> None: self.parallel_seeds: List[int] = [] # For DQN-like algorithms - self.sync_freq = 200 + self.sync_freq = 1000 self.__explore: Dict[Optional[str], Callable[[], Explorer]] = { None: lambda: EpsGreedy(1.0, LinearCooler(1.0, 0.1, 10000)), "eval": lambda: EpsGreedy(0.01, DummyCooler(0.01)), } + # For BootDQN + self.num_ensembles = 10 + self.replay_prob = 0.5 + # Reward scaling # Currently only used by SAC self.reward_scale = 1.0 @@ -100,7 +104,8 @@ def __init__(self) -> None: # Default Networks self.__net: Dict[str, NetFn] = { - "value": value.fc(), + "dqn": value.fc(), + "bootdqn": bootstrap.fc_separated(10), "actor-critic": actor_critic.fc_shared(), "ddpg": deterministic.fc_seprated(), "td3": deterministic.td3_fc_seprated(), @@ -110,7 +115,7 @@ def __init__(self) -> None: # Environments self.eval_times = 1 - self.__env = lambda: ClassicalControl() + self.__env = lambda: ClassicControl() self.__eval_env: Optional[EnvExt] = None self.__paralle_env = lambda env_gen, num_w: DummyParallelEnv(env_gen, num_w) diff --git a/rainy/envs/__init__.py b/rainy/envs/__init__.py index 8e560c7..087ac22 100644 --- a/rainy/envs/__init__.py +++ b/rainy/envs/__init__.py @@ -2,6 +2,7 @@ import gym from typing import Callable, Optional from .atari_wrappers import LazyFrames, make_atari, 
wrap_deepmind +from .deepsea import DeepSea as DeepSeaGymEnv from .ext import EnvExt, EnvSpec from .monitor import RewardMonitor from .obs_wrappers import AddTimeStep, TransposeObs @@ -96,11 +97,20 @@ def __wrap(env_gen: EnvGen, num_workers: int) -> ParallelEnv: return __wrap -class ClassicalControl(EnvExt): - def __init__(self, name: str = "CartPole-v0", max_steps: int = 200) -> None: +class ClassicControl(EnvExt): + def __init__( + self, name: str = "CartPole-v0", max_steps: Optional[int] = None + ) -> None: self.name = name super().__init__(gym.make(name)) - self._env._max_episode_steps = max_steps + if max_steps is not None: + self._env._max_episode_steps = max_steps + + +class DeepSea(EnvExt): + def __init__(self, size: int, noise: float = 0.0) -> None: + env = DeepSeaGymEnv(size, noise) + super().__init__(env) class PyBullet(EnvExt): @@ -150,3 +160,41 @@ def __wrap(env_gen: EnvGen, num_workers: int) -> ParallelEnv: return penv return __wrap + + +# Same as bsuite +gym.envs.register( + id="CartPoleSwingUp-v0", + entry_point="rainy.envs.swingup:CartPoleSwingUp", + max_episode_steps=1000, + kwargs=dict(start_position="bottom", allow_noop=True), + reward_threshold=800, +) + +# More difficult +gym.envs.register( + id="CartPoleSwingUp-v1", + entry_point="rainy.envs.swingup:CartPoleSwingUp", + max_episode_steps=1000, + kwargs=dict(start_position="bottom", allow_noop=True, height_threshold=0.9), + reward_threshold=800, +) + + +# No movecost +gym.envs.register( + id="CartPoleSwingUp-v2", + entry_point="rainy.envs.swingup:CartPoleSwingUp", + max_episode_steps=1000, + kwargs=dict(start_position="bottom", allow_noop=False), + reward_threshold=900, +) + +# Arbitary start +gym.envs.register( + id="CartPoleSwingUp-v3", + entry_point="rainy.envs.swingup:CartPoleSwingUp", + max_episode_steps=1000, + kwargs=dict(start_position="arbitary", allow_noop=False), + reward_threshold=900, +) diff --git a/rainy/envs/deepsea.py b/rainy/envs/deepsea.py new file mode 100644 index 0000000..15dcf40 --- /dev/null +++ b/rainy/envs/deepsea.py @@ -0,0 +1,82 @@ +import gym +from gym.utils import seeding +import numpy as np +from typing import List, Optional, Tuple, Union + + +class DeepSea(gym.Env): + SCREEN_SIZE = 400 + + def __init__(self, size: int, noise: float = 0.0) -> None: + self._size = size + self._move_cost = 0.01 / size + self._goal_reward = 1.0 + self._column = 0 + self._row = 0 + self.action_space = gym.spaces.Discrete(2) + low = np.zeros(size ** 2) + high = np.ones(size ** 2) + self.observation_space = gym.spaces.Box(low, high) + self.np_random = None + self.noise = noise + self._viewer = None + + def seed(self, seed: Optional[int] = None) -> List[int]: + self.np_random, seed = seeding.np_random(seed) + return [seed] + + def step(self, action: int) -> Tuple[np.ndarray, float, bool, dict]: + # Remap actions according to column (action_right = go right) + if self.noise == 0.0 or self.noise < self.np_random.uniform(0, 1): + action_right = action == 1 + else: + action_right = action != 1 + + # Compute the reward + reward = 0.0 + if self._column == self._size - 1 and action_right: + reward += self._goal_reward + + # State dynamics + if action_right: # right + self._column = np.clip(self._column + 1, 0, self._size - 1) + reward -= self._move_cost + else: # left + self._column = np.clip(self._column - 1, 0, self._size - 1) + + # Compute the observation + self._row += 1 + if self._row == self._size: + observation = self._get_observation(self._row - 1, self._column) + return observation, reward, True, {} + else: + 
observation = self._get_observation(self._row, self._column) + return observation, reward, False, {} + + def reset(self) -> np.ndarray: + self._reset_next_step = False + self._column = 0 + self._row = 0 + return self._get_observation(self._row, self._column) + + def render(self, mode: str = "human") -> Union[np.ndarray, bool]: + player_size = self.SCREEN_SIZE / self._size + if self._viewer is None: + from gym.envs.classic_control import rendering + + self._viewer = rendering.Viewer(self.SCREEN_SIZE, self.SCREEN_SIZE) + self.player_trans = rendering.Transform() + v = np.array([(0.0, 0.0), (2.0, 0.0), (1.5, -1.0), (0.5, -1.0)]) + player = rendering.make_polygon(v * player_size / 2) + player.set_color(1.0, 0.0, 0.0) + player.add_attr(self.player_trans) + self._viewer.add_geom(player) + self.player_trans.set_translation( + player_size * self._column, self.SCREEN_SIZE - player_size * self._row, + ) + return self._viewer.render(return_rgb_array=mode == "rgb_array") + + def _get_observation(self, row: int, column: int) -> np.ndarray: + observation = np.zeros(shape=(self._size, self._size), dtype=np.float32) + observation[row, column] = 1 + return observation.flatten() diff --git a/rainy/envs/ext.py b/rainy/envs/ext.py index aebe58d..44725b7 100644 --- a/rainy/envs/ext.py +++ b/rainy/envs/ext.py @@ -37,6 +37,11 @@ def random_action(self) -> Action: def is_discrete(self) -> bool: return isinstance(self.action_space, spaces.Discrete) + def __repr__(self) -> str: + return "EnvSpec(state_dim: {} action_space: {})".format( + self.state_dim, self.action_space + ) + class EnvExt(gym.Env, Generic[Action, State]): def __init__(self, env: gym.Env) -> None: diff --git a/rainy/envs/swingup.py b/rainy/envs/swingup.py new file mode 100644 index 0000000..d187f3d --- /dev/null +++ b/rainy/envs/swingup.py @@ -0,0 +1,110 @@ +from gym import spaces, logger +from gym.envs.classic_control import CartPoleEnv +import numpy as np + +F32_MAX = np.finfo(np.float32).max + + +class CartPoleSwingUp(CartPoleEnv): + START_POSITIONS = ["arbitary", "bottom"] + ACT_TO_FORCE = [-1.0, 1.0, 0.0] + + def __init__( + self, + start_position="arbitary", + height_threshold=0.5, + theta_dot_threshold=1.0, + x_reward_threshold=1.0, + # This is 2.4 in the original CartPole + x_threshold=3.0, + # Allow 'no operation' as an action + allow_noop=False, + move_cost=0.1, + ): + super().__init__() + self.x_threshold = x_threshold + self.start_position = self.START_POSITIONS.index(start_position) + self._height_threshold = height_threshold + self._theta_dot_threshold = theta_dot_threshold + self._x_reward_threshold = x_reward_threshold + self._move_cost = move_cost + if allow_noop: + self.action_space = spaces.Discrete(3) + self.allow_noop = allow_noop + high = np.array([1.0, F32_MAX, 1.0, 1.0, F32_MAX]) + self.observation_space = spaces.Box(-high, high, dtype=np.float32) + + def step(self, action): + """ + action: int + """ + if not self.action_space.contains(action): + raise ValueError(f"Invalid action: {action}") + force = self.force_mag * self.ACT_TO_FORCE[action] + if self.allow_noop and action != 2: + move_cost = self._move_cost + else: + move_cost = 0.0 + state = self.state + x, x_dot, theta, theta_dot = state + costheta, sintheta = np.cos(theta), np.sin(theta) + temp = ( + force + self.polemass_length * theta_dot * theta_dot * sintheta + ) / self.total_mass + thetaacc = (self.gravity * sintheta - costheta * temp) / ( + self.length + * (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass) + ) + xacc = temp - self.polemass_length *
thetaacc * costheta / self.total_mass + if self.kinematics_integrator == "euler": + x = x + self.tau * x_dot + x_dot = x_dot + self.tau * xacc + theta = theta + self.tau * theta_dot + theta_dot = theta_dot + self.tau * thetaacc + else: # semi-implicit euler + x_dot = x_dot + self.tau * xacc + x = x + self.tau * x_dot + theta_dot = theta_dot + self.tau * thetaacc + theta = theta + self.tau * theta_dot + + self.state = x, x_dot, theta, theta_dot + done = bool(x < -self.x_threshold or x > self.x_threshold) + + def _reward(): + is_upright = np.cos(theta) > self._height_threshold + is_upright &= np.abs(theta_dot) < self._theta_dot_threshold + is_upright &= np.abs(x) < self._x_reward_threshold + return (1.0 if is_upright else 0.0) - move_cost + + if not done: + reward = _reward() + elif self.steps_beyond_done is None: + # Pole just fell! + self.steps_beyond_done = 0 + reward = _reward() + else: + if self.steps_beyond_done == 0: + logger.warn("You are calling 'step()' after the episode ending.") + self.steps_beyond_done += 1 + reward = 0.0 + + return self._obs(), reward, done, {} + + def reset(self): + self.state = self.np_random.uniform(-0.05, 0.05, size=(4,)) + if self.start_position == 0: + self.state[2] = self.np_random.uniform(-np.pi, np.pi) + else: + self.state[2] += np.pi + self.steps_beyond_done = None + return self._obs() + + def _obs(self): + x, x_dot, theta, theta_dot = self.state + obs = np.zeros(5, dtype=np.float32) + obs[0] = x / self.x_threshold + obs[1] = x_dot / self.x_threshold + obs[2] = np.sin(theta) + obs[3] = np.cos(theta) + obs[4] = theta_dot + return obs diff --git a/rainy/envs/testing.py b/rainy/envs/testing.py index 81d38a0..1200d97 100644 --- a/rainy/envs/testing.py +++ b/rainy/envs/testing.py @@ -1,4 +1,4 @@ -"""Dummy environment for testing +"""Dummy environment and utilities for testing """ from enum import Enum from gym.spaces import Discrete diff --git a/rainy/lib/explore.py b/rainy/lib/explore.py index 672b3b8..852c195 100644 --- a/rainy/lib/explore.py +++ b/rainy/lib/explore.py @@ -51,7 +51,7 @@ def __call__(self) -> float: class Explorer(ABC): def select_action(self, state: Array, qfunc: DiscreteQFunction) -> LongTensor: - return self.select_from_value(qfunc.q_values(state).detach()) + return self.select_from_value(qfunc.q_value(state).detach()) @abstractmethod def select_from_value(self, value: Tensor) -> LongTensor: @@ -84,11 +84,11 @@ def __init__(self, epsilon: float, cooler: Cooler) -> None: def select_from_value(self, value: Tensor) -> LongTensor: old_eps = self.epsilon self.epsilon = self.cooler() - out_shape, action_dim = value.shape[:-1], value.size(-1) - greedy = value.argmax(-1).view(-1).cpu() - random = torch.randint(action_dim, value.shape[:-1]).view(-1) - res = torch.where(torch.zeros(out_shape).view(-1) < old_eps, random, greedy) - return res.reshape(out_shape).to(value.device) # type: ignore + action_dim = value.size(-1) + greedy = value.argmax(-1).cpu() + random = torch.randint_like(greedy, 0, action_dim) + random_pos = torch.empty(greedy.shape).uniform_() < old_eps + return torch.where(random_pos, random, greedy) def add_noise(self, action: Tensor) -> Tensor: raise NotImplementedError("We can't use EpsGreedy with continuous action") diff --git a/rainy/net/actor_critic.py b/rainy/net/actor_critic.py index d43f04a..49bbfb7 100644 --- a/rainy/net/actor_critic.py +++ b/rainy/net/actor_critic.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod import numpy as np from torch import nn, Tensor -from typing import Callable, List, Optional, Sequence,
Tuple, Union +from typing import Callable, List, Optional, Sequence, Tuple from .block import DQNConv, FcBody, ResNetBody, LinearHead, NetworkBlock from .init import Initializer, orthogonal from .policy import CategoricalDist, Policy, PolicyDist from .prelude import NetFn from .recurrent import DummyRnn, RnnBlock, RnnState -from ..prelude import Array +from ..prelude import ArrayLike from ..utils import Device @@ -27,7 +27,7 @@ def is_recurrent(self) -> bool: @abstractmethod def policy( self, - states: Union[Array, Tensor], + states: ArrayLike, rnns: Optional[RnnState] = None, masks: Optional[Tensor] = None, ) -> Tuple[Policy, RnnState]: @@ -36,7 +36,7 @@ def policy( @abstractmethod def value( self, - states: Union[Array, Tensor], + states: ArrayLike, rnns: Optional[RnnState] = None, masks: Optional[Tensor] = None, ) -> Tensor: @@ -45,7 +45,7 @@ def value( @abstractmethod def forward( self, - states: Union[Array, Tensor], + states: ArrayLike, rnns: Optional[RnnState] = None, masks: Optional[Tensor] = None, ) -> Tuple[Policy, Tensor, RnnState]: @@ -96,7 +96,7 @@ def recurrent_body(self) -> RnnBlock: def _features( self, - states: Union[Array, Tensor], + states: ArrayLike, rnns: Optional[RnnState] = None, masks: Optional[Tensor] = None, ) -> Tuple[Tensor, RnnState]: @@ -108,7 +108,7 @@ def _features( def policy( self, - states: Union[Array, Tensor], + states: ArrayLike, rnns: Optional[RnnState] = None, masks: Optional[Tensor] = None, ) -> Tuple[Policy, RnnState]: @@ -117,7 +117,7 @@ def policy( def value( self, - states: Union[Array, Tensor], + states: ArrayLike, rnns: Optional[RnnState] = None, masks: Optional[Tensor] = None, ) -> Tensor: @@ -126,7 +126,7 @@ def value( def forward( self, - states: Union[Array, Tensor], + states: ArrayLike, rnns: Optional[RnnState] = None, masks: Optional[Tensor] = None, ) -> Tuple[Policy, Tensor, RnnState]: diff --git a/rainy/net/block.py b/rainy/net/block.py index a120879..7988bc9 100644 --- a/rainy/net/block.py +++ b/rainy/net/block.py @@ -1,6 +1,6 @@ """Defines some reusable NN layers, called 'Block' """ -from abc import ABC, abstractmethod +from abc import ABC import numpy as np from torch import nn, Tensor from typing import List, Sequence, Tuple @@ -10,6 +10,7 @@ class NetworkBlock(nn.Module, ABC): """Defines a NN block which returns 1-dimension Tensor """ + input_dim: Sequence[int] output_dim: int @@ -28,7 +29,7 @@ def __init__( ) -> None: super().__init__() self.fc: nn.Linear = init(nn.Linear(input_dim, output_dim)) # type: ignore - self.input_dim = (input_dim, ) + self.input_dim = (input_dim,) self.output_dim = output_dim def forward(self, x: Tensor) -> Tensor: @@ -171,7 +172,7 @@ def __init__( init: Initializer = Initializer(), ) -> None: super().__init__() - self.input_dim = (input_dim, ) + self.input_dim = (input_dim,) self.output_dim = units[-1] dims = [input_dim] + units self.layers = init.make_list( diff --git a/rainy/net/bootstrap.py b/rainy/net/bootstrap.py new file mode 100644 index 0000000..6148c1c --- /dev/null +++ b/rainy/net/bootstrap.py @@ -0,0 +1,160 @@ +from abc import abstractmethod +from copy import deepcopy +import numpy as np +import torch +from torch import nn, Tensor +from torch.nn import functional as F +from typing import List, Sequence +from .block import FcBody, LinearHead, NetworkBlock +from .init import Initializer, xavier_uniform +from .prelude import NetFn +from .value import DiscreteQFunction, DiscreteQValueNet +from ..utils import Device +from ..prelude import Array, ArrayLike + + +class 
BootstrappedQFunction(DiscreteQFunction): + device: Device + + @abstractmethod + def forward(self, states: ArrayLike) -> Tensor: + pass + + @abstractmethod + def q_i_s(self, index: int, states: ArrayLike) -> Tensor: + pass + + def q_value(self, state: Array) -> Tensor: + values = self(state) + return values.mean(dim=1) + + def q_s_a(self, states: ArrayLike, actions: ArrayLike) -> Tensor: + qs = self(self.device.tensor(states)) + act = self.device.tensor(actions, dtype=torch.long) + action_mask = F.one_hot(act, num_classes=qs.size(-1)).float() + return torch.einsum("bka,ba->bk", qs, action_mask) + + +class SeparatedBootQValueNet(BootstrappedQFunction, nn.Module): + def __init__(self, q_nets: List[DiscreteQFunction]): + super().__init__() + self.q_nets = nn.ModuleList(q_nets) + self.device = q_nets[0].device + + def forward(self, x: ArrayLike) -> Tensor: + return torch.stack([q(x) for q in self.q_nets], dim=1) + + def q_i_s(self, index: int, states: ArrayLike) -> Tensor: + return self.q_nets[index](states) + + @property + def state_dim(self) -> Sequence[int]: + return self.q_nets[0].state_dim + + @property + def action_dim(self) -> int: + return self.q_nets[0].action_dim + + +class SharedBootQValueNet(DiscreteQFunction, nn.Module): + """INCOMPLETE + """ + + def __init__( + self, body: NetworkBlock, heads: List[NetworkBlock], device: Device = Device() + ) -> None: + if body.output_dim != np.prod(heads[0].input_dim): + raise ValueError("body output and head input must have a same dimention") + super().__init__() + self.body = body + self.head = nn.ModuleList(heads) + self.device = device + self.to(self.device.unwrapped) + + def forward(self, index: int, x: ArrayLike) -> Tensor: + raise NotImplementedError() + + @property + def state_dim(self) -> Sequence[int]: + return self.body.input_dim + + @property + def action_dim(self) -> int: + return self.head[0].output_dim + + +class PriorQValueNet(DiscreteQFunction, nn.Module): + """State -> [Value..]
+ """ + + def __init__( + self, + body: NetworkBlock, + head: NetworkBlock, + prior_scale: float = 1.0, + device: Device = Device(), + init: Initializer = Initializer(), + ) -> None: + if body.output_dim != np.prod(head.input_dim): + raise ValueError("body output and head input must have a same dimention") + super().__init__() + self.model = nn.Sequential(body, head) + self.prior = init(deepcopy(self.model)) + self.device = device + self.prior_scale = prior_scale + self.to(self.device.unwrapped) + + def q_value(self, state: Array, nostack: bool = False) -> Tensor: + if nostack: + return self.forward(state) + else: + return self.forward(np.stack([state])) + + def forward(self, x: ArrayLike) -> Tensor: + x = self.device.tensor(x) + raw = self.model(x) + with torch.no_grad(): + prior = self.prior(x) + return raw.add_(prior.mul_(self.prior_scale)) + + @property + def state_dim(self) -> Sequence[int]: + return self.model[0].input_dim + + @property + def action_dim(self) -> int: + return self.model[1].output_dim + + +def fc_separated(n_ensembles: int, *args, **kwargs) -> NetFn: + def _net( + state_dim: Sequence[int], action_dim: int, device: Device + ) -> SeparatedBootQValueNet: + q_nets = [] + for _ in range(n_ensembles): + body = FcBody(state_dim[0], *args, **kwargs) + head = LinearHead(body.output_dim, action_dim) + q_nets.append(DiscreteQValueNet(body, head, device=device)) + return SeparatedBootQValueNet(q_nets) + + return _net + + +def rpf_fc_separated( + n_ensembles: int, + prior_scale: float = 1.0, + init: Initializer = Initializer(weight_init=xavier_uniform()), + **kwargs +) -> NetFn: + def _net( + state_dim: Sequence[int], action_dim: int, device: Device + ) -> SeparatedBootQValueNet: + q_nets = [] + for _ in range(n_ensembles): + body = FcBody(state_dim[0], init=init, **kwargs) + head = LinearHead(body.output_dim, action_dim, init=init) + prior_model = PriorQValueNet(body, head, prior_scale, init=init) + q_nets.append(prior_model) + return SeparatedBootQValueNet(q_nets) + + return _net diff --git a/rainy/net/deterministic.py b/rainy/net/deterministic.py index 466e90d..1d36c78 100644 --- a/rainy/net/deterministic.py +++ b/rainy/net/deterministic.py @@ -3,18 +3,18 @@ from rainy.utils import Device import torch from torch import nn, Tensor -from typing import Iterable, List, Sequence, Tuple, Union +from typing import Iterable, List, Sequence, Tuple from .block import FcBody, LinearHead, NetworkBlock from .init import Initializer, kaiming_uniform from .misc import SoftUpdate from .prelude import NetFn from .value import ContinuousQFunction -from ..prelude import Array +from ..prelude import ArrayLike class DeterministicPolicyNet(ABC): @abstractmethod - def action(self, state: Union[Array, Tensor]) -> Tensor: + def action(self, state: ArrayLike) -> Tensor: pass @@ -51,21 +51,17 @@ def actor_params(self) -> Iterable[Tensor]: def critic_params(self) -> Iterable[Tensor]: return self.critic.parameters() - def action(self, states: Union[Array, Tensor]) -> Tensor: + def action(self, states: ArrayLike) -> Tensor: s = self.device.tensor(states) return self.actor(s).mul(self.action_coef) - def q_value( - self, states: Union[Array, Tensor], action: Union[Array, Tensor] - ) -> Tensor: + def q_value(self, states: ArrayLike, action: ArrayLike) -> Tensor: s = self.device.tensor(states) a = self.device.tensor(action) sa = torch.cat((s, a), dim=1) return self.critic(sa) - def forward( - self, states: Union[Array, Tensor], action: Union[Array, Tensor] - ) -> Tuple[Tensor, Tensor]: + def forward(self, 
states: ArrayLike, action: ArrayLike) -> Tuple[Tensor, Tensor]: s = self.device.tensor(states) a = self.device.tensor(action) sa = torch.cat((s, a), dim=1) @@ -99,9 +95,7 @@ def __init__( def critic_params(self) -> Iterable[Tensor]: return chain(self.critic.parameters(), self.critic2.parameters()) - def q_values( - self, states: Union[Array, Tensor], action: Union[Array, Tensor], - ) -> Tuple[Tensor, Tensor]: + def q_value(self, states: ArrayLike, action: ArrayLike) -> Tensor: s = self.device.tensor(states) a = self.device.tensor(action) sa = torch.cat((s, a), dim=1) diff --git a/rainy/net/init.py b/rainy/net/init.py index b735ca8..f6eae9f 100644 --- a/rainy/net/init.py +++ b/rainy/net/init.py @@ -29,6 +29,10 @@ def kaiming_uniform(**kwargs) -> InitFn: return partial(nn.init.kaiming_uniform_, **kwargs) +def xavier_uniform(**kwargs) -> InitFn: + return partial(nn.init.xavier_uniform_, **kwargs) + + def fanin_uniform() -> InitFn: def _fanin_uniform(w: Tensor) -> Tensor: if w.dim() <= 2: diff --git a/rainy/net/option_critic.py b/rainy/net/option_critic.py index 56dac01..e631825 100644 --- a/rainy/net/option_critic.py +++ b/rainy/net/option_critic.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from torch import nn, Tensor -from typing import Callable, Sequence, Tuple, Union +from typing import Callable, Sequence, Tuple from .actor_critic import policy_init from .block import DQNConv, FcBody, LinearHead, NetworkBlock from .policy import BernoulliDist, BernoulliPolicy, CategoricalDist, Policy, PolicyDist from .prelude import NetFn -from ..prelude import Array +from ..prelude import ArrayLike from ..utils import Device @@ -17,13 +17,11 @@ class OptionCriticNet(nn.Module, ABC): state_dim: Sequence[int] @abstractmethod - def opt_q(self, states: Union[Array, Tensor]) -> Tensor: + def opt_q(self, states: ArrayLike) -> Tensor: pass @abstractmethod - def forward( - self, states: Union[Array, Tensor] - ) -> Tuple[Policy, Tensor, BernoulliPolicy]: + def forward(self, states: ArrayLike) -> Tuple[Policy, Tensor, BernoulliPolicy]: pass @@ -53,13 +51,11 @@ def __init__( self.state_dim = self.body.input_dim self.to(device.unwrapped) - def opt_q(self, states: Union[Array, Tensor]) -> Tensor: + def opt_q(self, states: ArrayLike) -> Tensor: feature = self.body(self.device.tensor(states)) return self.optq_head(feature) - def forward( - self, states: Union[Array, Tensor] - ) -> Tuple[Policy, Tensor, BernoulliPolicy]: + def forward(self, states: ArrayLike) -> Tuple[Policy, Tensor, BernoulliPolicy]: feature = self.body(self.device.tensor(states)) policy = self.actor_head(feature).view(-1, self.num_options, self.action_dim) opt_q = self.optq_head(feature) diff --git a/rainy/net/sac.py b/rainy/net/sac.py index ca78b51..e7a2958 100644 --- a/rainy/net/sac.py +++ b/rainy/net/sac.py @@ -2,14 +2,14 @@ import itertools import torch from torch import nn, Tensor -from typing import Iterable, List, Sequence, Tuple, Union +from typing import Iterable, List, Sequence, Tuple from .block import FcBody, LinearHead, NetworkBlock from .init import Initializer, fanin_uniform, constant from .misc import SoftUpdate from .policy import Policy, PolicyDist, TanhGaussianDist from .prelude import NetFn from .value import ContinuousQFunction -from ..prelude import Array, Self +from ..prelude import ArrayLike, Self from ..utils import Device @@ -24,9 +24,7 @@ def soft_update(self, other: Self, coef: float) -> None: SoftUpdate.soft_update(self.critic1, other.critic1, coef) # type: ignore SoftUpdate.soft_update(self.critic2, 
other.critic2, coef) # type: ignore - def q_values( - self, states: Union[Array, Tensor], action: Union[Array, Tensor] - ) -> Tuple[Tensor, Tensor]: + def q_values(self, states: ArrayLike, action: ArrayLike) -> Tuple[Tensor, Tensor]: sa = torch.cat((self.device.tensor(states), self.device.tensor(action)), dim=1) return self.critic1(sa), self.critic2(sa) @@ -58,15 +56,11 @@ def __init__( self.device = device self.to(device.unwrapped) - def q_value( - self, states: Union[Array, Tensor], action: Union[Array, Tensor] - ) -> Tensor: + def q_value(self, states: ArrayLike, action: ArrayLike) -> Tensor: sa = torch.cat((self.device.tensor(states), self.device.tensor(action)), dim=1) return self.critic1(sa) - def q_values( - self, states: Union[Array, Tensor], action: Union[Array, Tensor] - ) -> Tuple[Tensor, Tensor]: + def q_values(self, states: ArrayLike, action: ArrayLike) -> Tuple[Tensor, Tensor]: sa = torch.cat((self.device.tensor(states), self.device.tensor(action)), dim=1) return self.critic1(sa), self.critic2(sa) @@ -75,7 +69,7 @@ def get_target(self) -> SACTarget: copy.deepcopy(self.critic1), copy.deepcopy(self.critic2), self.device ) - def policy(self, states: Union[Array, Tensor]) -> Policy: + def policy(self, states: ArrayLike) -> Policy: st = self.device.tensor(states) if st.dim() == 1: st = st.view(1, -1) @@ -89,7 +83,7 @@ def critic_params(self) -> Iterable[Tensor]: return itertools.chain(self.critic1.parameters(), self.critic2.parameters()) def forward( - self, states: Union[Array, Tensor], action: Union[Array, Tensor] + self, states: ArrayLike, action: ArrayLike ) -> Tuple[Tensor, Tensor, Policy]: s, a = self.device.tensor(states), self.device.tensor(action) sa = torch.cat((s, a), dim=1) diff --git a/rainy/net/value.py b/rainy/net/value.py index c5740e6..c15ff6b 100644 --- a/rainy/net/value.py +++ b/rainy/net/value.py @@ -1,24 +1,22 @@ from abc import ABC, abstractmethod import numpy as np from torch import nn, Tensor -from typing import Sequence, Tuple, Union +from typing import Sequence, Tuple from .block import DQNConv, FcBody, LinearHead, NetworkBlock from .prelude import NetFn from ..utils import Device -from ..prelude import Array +from ..prelude import Array, ArrayLike class ContinuousQFunction(ABC): @abstractmethod - def q_value( - self, states: Union[Array, Tensor], action: Union[Array, Tensor] - ) -> Tensor: + def q_value(self, states: ArrayLike, action: ArrayLike) -> Tensor: pass class DiscreteQFunction(ABC): @abstractmethod - def q_values(self, state: Array, nostack: bool = False) -> Tensor: + def q_value(self, state: Array, nostack: bool = False) -> Tensor: pass @property @@ -37,26 +35,29 @@ class DiscreteQValueNet(DiscreteQFunction, nn.Module): """ def __init__( - self, body: NetworkBlock, head: NetworkBlock, device: Device = Device() + self, + body: NetworkBlock, + head: NetworkBlock, + device: Device = Device(), + do_not_use_data_parallel: bool = False, ) -> None: - assert body.output_dim == np.prod( - head.input_dim - ), "body output and head input must have a same dimention" + if body.output_dim != np.prod(head.input_dim): + raise ValueError("body output and head input must have a same dimention") super().__init__() self.head = head self.body = body - if device.is_multi_gpu(): + if not do_not_use_data_parallel and device.is_multi_gpu(): self.body = device.data_parallel(body) # type: ignore self.device = device self.to(self.device.unwrapped) - def q_values(self, state: Array, nostack: bool = False) -> Tensor: + def q_value(self, state: Array, nostack: bool = False) 
-> Tensor: if nostack: return self.forward(state) else: return self.forward(np.stack([state])) - def forward(self, x: Union[Array, Tensor]) -> Tensor: + def forward(self, x: ArrayLike) -> Tensor: x = self.device.tensor(x) x = self.body(x) x = self.head(x) diff --git a/rainy/replay/__init__.py b/rainy/replay/__init__.py index 72b115b..0f5b059 100644 --- a/rainy/replay/__init__.py +++ b/rainy/replay/__init__.py @@ -1,3 +1,3 @@ from .array_deque import ArrayDeque from .base import ReplayBuffer -from .uniform import DQNReplayFeed, UniformReplayBuffer +from .uniform import BootDQNReplayFeed, DQNReplayFeed, UniformReplayBuffer diff --git a/rainy/replay/base.py b/rainy/replay/base.py index 07469fe..c04e4db 100644 --- a/rainy/replay/base.py +++ b/rainy/replay/base.py @@ -7,6 +7,7 @@ class ReplayBuffer(ABC, Generic[ReplayFeed]): def __init__(self, feed: Type[ReplayFeed]) -> None: self.feed = feed + self.allow_overlap = False @abstractmethod def append(self, *args) -> None: diff --git a/rainy/replay/uniform.py b/rainy/replay/uniform.py index 90e8c25..43cece8 100644 --- a/rainy/replay/uniform.py +++ b/rainy/replay/uniform.py @@ -1,3 +1,4 @@ +import numpy as np from typing import Callable, Generic, List, NamedTuple, Tuple, Type from .array_deque import ArrayDeque from .base import ReplayFeed, ReplayBuffer @@ -18,7 +19,11 @@ def append(self, *args) -> None: self.buf.pop_front() def sample(self, batch_size: int) -> List[ReplayFeed]: - return [self.buf[idx] for idx in sample_indices(len(self.buf), batch_size)] + if self.allow_overlap: + indices = np.random.randint(len(self.buf), size=batch_size) + else: + indices = sample_indices(len(self.buf), batch_size) + return [self.buf[idx] for idx in indices] def __len__(self): return len(self.buf) @@ -31,9 +36,9 @@ class DQNReplayFeed(NamedTuple, Generic[State], metaclass=GenericNamedMeta): next_state: State done: bool - def to_ndarray( + def to_array( self, wrap: Callable[[State], Array] - ) -> Tuple[Array, int, float, Array, bool]: + ) -> Tuple[Array[float], int, float, Array[float], bool]: return ( wrap(self.state), self.action, @@ -41,3 +46,24 @@ def to_ndarray( wrap(self.next_state), self.done, ) + + +class BootDQNReplayFeed(NamedTuple, Generic[State], metaclass=GenericNamedMeta): + state: State + action: int + reward: float + next_state: State + done: bool + ensemble_mask: Array[bool] + + def to_array( + self, wrap: Callable[[State], Array] + ) -> Tuple[Array[float], int, float, Array[float], bool, Array[bool]]: + return ( + wrap(self.state), + self.action, + self.reward, + wrap(self.next_state), + self.done, + self.ensemble_mask, + ) diff --git a/rainy/run.py b/rainy/run.py index 7beaba9..7c6989a 100644 --- a/rainy/run.py +++ b/rainy/run.py @@ -17,6 +17,7 @@ def train_agent( saveid_start: int = 0, save_file_name: str = SAVE_FILE_DEFAULT, action_file_name: str = ACTION_FILE_DEFAULT, + eval_render: bool = False, ) -> None: ag.logger.summary_setting( "train", @@ -46,9 +47,9 @@ def log_eval() -> None: fname = logdir.joinpath( "{}-{}{}".format(action_file.stem, episodes, action_file.suffix) ) - res = _eval_impl(ag, fname) + res = _eval_impl(ag, fname, render=eval_render) else: - res = _eval_impl(ag, None) + res = _eval_impl(ag, None, render=eval_render) rewards, length = _reward_and_length(res) ag.logger.submit( "eval", diff --git a/rainy/utils/cli.py b/rainy/utils/cli.py index e4b8fac..b9a56d8 100644 --- a/rainy/utils/cli.py +++ b/rainy/utils/cli.py @@ -1,5 +1,5 @@ import click -from typing import Callable, Optional, Tuple +from typing import Callable, List, 
Optional from ..agents import Agent from ..config import Config @@ -9,45 +9,33 @@ @click.group() -@click.option( - "--gpu", required=False, type=int, help="How many gpus you allow the script to use" -) @click.option( "--envname", type=str, default=None, help="Name of environment passed to config_gen" ) +@click.option("--max-steps", type=int, default=None, help="Max steps of the training") @click.option( "--seed", type=int, default=None, help="Random seed set before training. Left for backward comaptibility", ) -@click.option( - "--override", type=str, default="", help="Override string(see README for detail)" -) @click.pass_context def rainy_cli( ctx: click.Context, - gpu: Tuple[int], envname: Optional[str], + max_steps: Optional[int], seed: Optional[int], - override: str, + **kwargs, ) -> None: - ctx.obj["gpu"] = gpu cfg_gen = ctx.obj["config_gen"] - ctx.obj["config"] = cfg_gen(envname) if envname is not None else cfg_gen() + if envname is not None: + kwargs["envname"] = envname + if max_steps is not None: + kwargs["max_steps"] = max_steps + ctx.obj["config"] = cfg_gen(**kwargs) ctx.obj["config"].seed = seed - ctx.obj["override"] = override ctx.obj["envname"] = "Default" if envname is None else envname - if len(override) > 0: - import builtins - - try: - exec(override, builtins.__dict__, {"config": ctx.obj["config"]}) - except Exception as e: - print( - "!!! Your override string '{}' contains an error !!!".format(override) - ) - raise e + ctx.obj["kwargs"] = kwargs @rainy_cli.command(help="Train agents") @@ -59,20 +47,25 @@ def rainy_cli( help="Comment that would be wrote to fingerprint.txt", ) @click.option("--prefix", type=str, default="", help="Prefix of the log directory") -def train(ctx: click.Context, comment: Optional[str], prefix: str) -> None: +@click.option( + "--eval-render", is_flag=True, help="Render the environment when evaluating" +) +def train( + ctx: click.Context, comment: Optional[str], prefix: str, eval_render: bool = True +) -> None: c = ctx.obj["config"] script_path = ctx.obj["script_path"] if script_path is not None: fingerprint = dict( comment="" if comment is None else comment, envname=ctx.obj["envname"], - override=ctx.obj["override"], + kwargs=ctx.obj["kwargs"], ) c.logger.set_dir_from_script_path( script_path, prefix=prefix, fingerprint=fingerprint ) ag = ctx.obj["make_agent"](c) - run.train_agent(ag) + run.train_agent(ag, eval_render=eval_render) if mpi.IS_MPI_ROOT: print( "random play: {}, trained: {}".format( @@ -174,14 +167,21 @@ def ipython(ctx: click.Context, logdir: Optional[str]) -> None: _open_ipython(logdir) +def _add_options(options: List[click.Command] = []) -> click.Group: + for option in options: + rainy_cli.params.append(option) + + def run_cli( config_gen: Callable[..., Config], agent_gen: Callable[[Config], Agent], script_path: Optional[str] = None, + options: List[click.Command] = [], ) -> None: obj = { "config_gen": config_gen, "make_agent": agent_gen, "script_path": script_path, } + _add_options(options) rainy_cli(obj=obj)
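# ---- Illustrative note, not part of the patch above ----
# The masked loss in EpisodicBootDQNAgent._train is the core of the bootstrapped
# update: each transition is stored with a Bernoulli(replay_prob) mask over the
# ensemble heads (BootDQNReplayFeed.ensemble_mask), and only the masked-in heads
# receive a gradient from that transition. A minimal standalone sketch of the
# mechanism follows; the batch size, head count, and the 0.5 masking probability
# are placeholder assumptions mirroring Config.replay_prob and num_ensembles.
import torch
from torch.nn import functional as F

batch_size, n_ensembles = 32, 10
q_current = torch.randn(batch_size, n_ensembles)  # Q_k(s, a) for each head k
q_target = torch.randn(batch_size, n_ensembles)   # r + gamma * (1 - done) * max_a' Q_k(s', a')
mask = torch.rand(batch_size, n_ensembles) < 0.5  # which heads were assigned this transition

loss = F.mse_loss(q_current, q_target, reduction="none")  # per (sample, head) squared error
masked_loss = loss.masked_select(mask).mean()             # average only over masked-in pairs
# Because each head trains on a different random subset of the replay data, the
# ensemble keeps diverse value estimates, which is what the agent exploits by
# following one randomly chosen head per episode.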
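# ---- Illustrative note, not part of the patch above ----
# PriorQValueNet implements the randomized-prior trick from the second paper cited
# in rainy/agents/bootdqn.py: Q(s) = f_theta(s) + prior_scale * p(s), where p is a
# frozen, randomly initialized copy of the trainable network. A rough standalone
# sketch under assumed layer sizes (the real class composes a FcBody and LinearHead
# and re-initializes the prior via init(deepcopy(self.model))):
import copy
import torch
from torch import nn


class TinyPriorQNet(nn.Module):
    def __init__(self, state_dim: int = 4, action_dim: int = 2, prior_scale: float = 1.0) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 32), nn.ReLU(), nn.Linear(32, action_dim)
        )
        # Same architecture, freshly drawn weights, gradients disabled.
        self.prior = copy.deepcopy(self.model)
        for m in self.prior.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
        for p in self.prior.parameters():
            p.requires_grad_(False)
        self.prior_scale = prior_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():  # the prior is never updated
            prior = self.prior(x)
        return self.model(x) + self.prior_scale * prior
# Since each ensemble member gets its own random prior, every head keeps a distinct,
# never-vanishing bias in unvisited regions of the state space, which sustains
# exploration even after the trainable parts agree on the visited data.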