Skip to content
This repository has been archived by the owner on Jan 27, 2023. It is now read-only.

Bootstrapped DQN #48

Merged
merged 15 commits into from
Nov 26, 2019
2 changes: 1 addition & 1 deletion examples/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.optim import RMSprop


def config(game: str = "Breakout") -> Config:
def config(envname: str = "Breakout") -> Config:
c = Config()
c.set_env(lambda: Atari(game, frame_stack=False))
c.set_optimizer(lambda params: RMSprop(params, lr=7e-4, alpha=0.99, eps=1e-5))
Expand Down
13 changes: 9 additions & 4 deletions examples/a2c_cart_pole.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import rainy
from rainy.utils.cli import run_cli
from rainy.envs import MultiProcEnv
from rainy.envs import ClassicControl, MultiProcEnv
from torch.optim import Adam


def config() -> rainy.Config:
def config(envname: str = "CartPole-v0", rnn: bool = False) -> rainy.Config:
c = rainy.Config()
c.set_env(lambda: ClassicControl(envname))
c.max_steps = int(1e6)
c.nworkers = 12
c.nsteps = 5
Expand All @@ -18,9 +19,13 @@ def config() -> rainy.Config:
c.eval_deterministic = True
c.eval_freq = c.max_steps // 10
c.entropy_weight = 0.001
# c.set_net_fn('actor-critic', rainy.net.actor_critic.fc_shared(rnn=rainy.net.GruBlock))
if rnn:
c.set_net_fn(
"actor-critic", rainy.net.actor_critic.fc_shared(rnn=rainy.net.GruBlock)
)
return c


if __name__ == "__main__":
run_cli(config, rainy.agents.A2CAgent, script_path=os.path.realpath(__file__))
options = [click.Option(["--rnn"], is_flag=True)]
run_cli(config, rainy.agents.A2CAgent, os.path.realpath(__file__), options)
2 changes: 1 addition & 1 deletion examples/acktr_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
}


def config(game: str = "Breakout") -> Config:
def config(envname: str = "Breakout") -> Config:
c = Config()
c.set_env(lambda: Atari(game, frame_stack=False))
c.set_optimizer(kfac.default_sgd(eta_max=0.2))
Expand Down
5 changes: 3 additions & 2 deletions examples/acktr_cart_pole.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rainy import Config
from rainy.agents import ACKTRAgent
import rainy.utils.cli as cli
from rainy.envs import MultiProcEnv
from rainy.envs import ClassicControl, MultiProcEnv
from rainy.lib import kfac


Expand All @@ -13,8 +13,9 @@
}


def config() -> Config:
def config(envname: str = "CartPole-v0") -> Config:
c = Config()
c.set_env(lambda: ClassicControl(envname))
c.max_steps = int(4e5)
c.nworkers = 12
c.nsteps = 20
Expand Down
2 changes: 1 addition & 1 deletion examples/aoc_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.optim import RMSprop


def config(game: str = "Breakout") -> Config:
def config(envname: str = "Breakout") -> Config:
c = Config()
c.set_env(lambda: Atari(game, frame_stack=False))
c.set_optimizer(lambda params: RMSprop(params, lr=7e-4, alpha=0.99, eps=1e-5))
Expand Down
5 changes: 3 additions & 2 deletions examples/aoc_cart_pole.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import rainy
from rainy.utils.cli import run_cli
from rainy.envs import MultiProcEnv
from rainy.envs import ClassicControl, MultiProcEnv
from torch import optim


def config() -> rainy.Config:
def config(envname: str = "CartPole-v0") -> rainy.Config:
c = rainy.Config()
c.set_env(lambda: ClassicControl(envname))
c.max_steps = int(4e5)
c.nworkers = 12
c.nsteps = 5
Expand Down
42 changes: 42 additions & 0 deletions examples/bootdqn_cart_pole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import click
import os
from rainy import Config
from rainy.agents import EpisodicBootDQNAgent
from rainy.envs import ClassicControl
from rainy.lib import explore
from rainy.net import bootstrap
from rainy.replay import BootDQNReplayFeed, UniformReplayBuffer
import rainy.utils.cli as cli
from torch import optim


def config(
    envname: str = "CartPole-v0",
    max_steps: int = 1000000,
    rpf: bool = False,
    replay_prob: float = 0.5,
    prior_scale: float = 1.0,
) -> Config:
    """Assemble the Config for the bootstrapped-DQN classic-control example.

    Args:
        envname: Name handed to ``ClassicControl`` when the environment is built.
        max_steps: Stored as ``max_steps`` on the returned config.
        rpf: When true, ``bootstrap.rpf_fc_separated`` is registered under the
            "bootdqn" network key; otherwise no network is registered here.
        replay_prob: Stored as ``replay_prob`` on the returned config.
        prior_scale: Forwarded to ``bootstrap.rpf_fc_separated``; only read
            when ``rpf`` is true.

    Returns:
        The populated ``Config`` instance.
    """
    cfg = Config()

    # Factories are registered as callables; nothing is instantiated here.
    cfg.set_optimizer(lambda params: optim.Adam(params))
    # The same Greedy explorer is used for the default key and the "eval" key.
    cfg.set_explorer(lambda: explore.Greedy())
    cfg.set_explorer(lambda: explore.Greedy(), key="eval")
    cfg.set_env(lambda: ClassicControl(envname))

    cfg.max_steps = max_steps
    cfg.episode_log_freq = 100
    cfg.replay_prob = replay_prob

    if rpf:
        rpf_net = bootstrap.rpf_fc_separated(10, prior_scale=prior_scale)
        cfg.set_net_fn("bootdqn", rpf_net)

    cfg.set_replay_buffer(
        lambda capacity: UniformReplayBuffer(BootDQNReplayFeed, capacity=capacity)
    )
    return cfg


if __name__ == "__main__":
    # Extra click options handed to run_cli together with the config factory.
    script_path = os.path.realpath(__file__)
    options = [
        click.Option(["--rpf"], is_flag=True),
        click.Option(["--replay-prob", "-RP"], type=float, default=0.5),
        click.Option(["--prior-scale", "-PS"], type=float, default=1.0),
    ]
    cli.run_cli(config, EpisodicBootDQNAgent, script_path, options)
42 changes: 42 additions & 0 deletions examples/bootdqn_deepsea.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import click
import os
from rainy import Config
from rainy.agents import EpisodicBootDQNAgent
from rainy.envs import DeepSea
from rainy.lib import explore
from rainy.net import bootstrap
from rainy.replay import BootDQNReplayFeed, UniformReplayBuffer
import rainy.utils.cli as cli
from torch import optim


def config(
    max_steps: int = 100000,
    size: int = 20,
    rpf: bool = False,
    replay_prob: float = 0.5,
    prior_scale: float = 1.0,
) -> Config:
    """Assemble the Config for the bootstrapped-DQN DeepSea example.

    Args:
        max_steps: Stored as ``max_steps`` on the returned config.
        size: Argument handed to ``DeepSea`` when the environment is built.
        rpf: When true, ``bootstrap.rpf_fc_separated`` is registered under the
            "bootdqn" network key; otherwise no network is registered here.
        replay_prob: Stored as ``replay_prob`` on the returned config.
        prior_scale: Forwarded to ``bootstrap.rpf_fc_separated``; only read
            when ``rpf`` is true.

    Returns:
        The populated ``Config`` instance.
    """
    cfg = Config()

    # Factories are registered as callables; nothing is instantiated here.
    cfg.set_optimizer(lambda params: optim.Adam(params))
    # The same Greedy explorer is used for the default key and the "eval" key.
    cfg.set_explorer(lambda: explore.Greedy())
    cfg.set_explorer(lambda: explore.Greedy(), key="eval")
    cfg.set_env(lambda: DeepSea(size))

    cfg.max_steps = max_steps
    cfg.episode_log_freq = 100
    cfg.replay_prob = replay_prob

    if rpf:
        rpf_net = bootstrap.rpf_fc_separated(10, prior_scale=prior_scale)
        cfg.set_net_fn("bootdqn", rpf_net)

    cfg.set_replay_buffer(
        lambda capacity: UniformReplayBuffer(BootDQNReplayFeed, capacity=capacity)
    )
    return cfg


if __name__ == "__main__":
    # Extra click options handed to run_cli together with the config factory.
    script_path = os.path.realpath(__file__)
    options = [
        click.Option(["--rpf"], is_flag=True),
        click.Option(["--replay-prob", "-RP"], type=float, default=0.5),
        click.Option(["--prior-scale", "-PS"], type=float, default=1.0),
    ]
    cli.run_cli(config, EpisodicBootDQNAgent, script_path, options)
4 changes: 2 additions & 2 deletions examples/ddqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from torch.optim import RMSprop


def config(game: str = "Breakout") -> Config:
def config(envname: str = "Breakout") -> Config:
c = Config()
c.set_env(lambda: Atari(game))
c.set_optimizer(
lambda params: RMSprop(params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
)
c.set_explorer(lambda: EpsGreedy(1.0, LinearCooler(1.0, 0.1, int(1e6))))
c.set_net_fn("value", net.value.dqn_conv())
c.set_net_fn("dqn", net.value.dqn_conv())
c.replay_size = int(1e6)
c.batch_size = 32
c.train_start = 50000
Expand Down
4 changes: 3 additions & 1 deletion examples/ddqn_cart_pole.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from rainy import Config
from rainy.agents import DoubleDQNAgent
from rainy.envs import ClassicControl
import rainy.utils.cli as cli


def config() -> Config:
def config(envname: str = "CartPole-v0") -> Config:
c = Config()
c.set_env(lambda: ClassicControl(envname))
c.max_steps = 100000
return c

Expand Down
4 changes: 2 additions & 2 deletions examples/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from torch.optim import RMSprop


def config(game: str = "Breakout") -> Config:
def config(envname: str = "Breakout") -> Config:
c = Config()
c.set_env(lambda: Atari(game))
c.set_optimizer(
lambda params: RMSprop(params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
)
c.set_explorer(lambda: EpsGreedy(1.0, LinearCooler(1.0, 0.1, int(1e6))))
c.set_net_fn("value", net.value.dqn_conv())
c.set_net_fn("dqn", net.value.dqn_conv())
c.replay_size = int(1e6)
c.replay_batch_size = 32
c.train_start = 50000
Expand Down
8 changes: 5 additions & 3 deletions examples/dqn_cart_pole.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
from rainy import Config
from rainy.agents import DQNAgent
from rainy.envs import ClassicControl
import rainy.utils.cli as cli


def config() -> Config:
def config(envname: str = "CartPole-v0", max_steps: int = 100000) -> Config:
c = Config()
c.max_steps = 100000
c.eval_freq = 100
c.set_env(lambda: ClassicControl(envname))
c.max_steps = max_steps
c.episode_log_freq = 100
return c


Expand Down
17 changes: 17 additions & 0 deletions examples/episodic_dqn_cart_pole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os
from rainy import Config
from rainy.agents import EpisodicDQNAgent
from rainy.envs import ClassicControl
import rainy.utils.cli as cli


def config(envname: str = "CartPole-v0") -> Config:
    """Assemble the Config for the episodic-DQN classic-control example.

    Args:
        envname: Name handed to ``ClassicControl`` when the environment is built.

    Returns:
        A ``Config`` with the environment factory registered, ``max_steps``
        set to 100000 and ``episode_log_freq`` set to 100.
    """
    cfg = Config()
    # The environment is created lazily through the registered factory.
    cfg.set_env(lambda: ClassicControl(envname))
    cfg.max_steps = 100000
    cfg.episode_log_freq = 100
    return cfg


if __name__ == "__main__":
    # Resolve this script's real path first, then hand everything to the CLI runner.
    this_script = os.path.realpath(__file__)
    cli.run_cli(config, EpisodicDQNAgent, script_path=this_script)
2 changes: 1 addition & 1 deletion examples/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.optim import Adam


def config(game: str = "Breakout") -> Config:
def config(envname: str = "Breakout") -> Config:
c = Config()
c.set_env(lambda: Atari(game, frame_stack=False))
# c.set_net_fn('actor-critic', net.actor_critic.ac_conv(rnn=net.GruBlock))
Expand Down
5 changes: 3 additions & 2 deletions examples/ppo_cart_pole.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import rainy
from rainy.utils.cli import run_cli
from rainy.envs import MultiProcEnv
from rainy.envs import ClassicControl, MultiProcEnv
from torch.optim import Adam


def config() -> rainy.Config:
def config(envname: str = "CartPole-v0") -> rainy.Config:
c = rainy.Config()
c.set_env(lambda: ClassicControl(envname))
c.max_steps = int(1e5)
c.nworkers = 8
c.nsteps = 32
Expand Down
2 changes: 1 addition & 1 deletion examples/ppo_flicker_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import rainy.utils.cli as cli


def config(game: str = "Breakout") -> rainy.Config:
def config(envname: str = "Breakout") -> rainy.Config:
c = ppo_atari.config(game)
c.set_env(lambda: Atari(game, flicker_frame=True, frame_stack=False))
c.set_parallel_env(atari_parallel(frame_stack=False))
Expand Down
3 changes: 2 additions & 1 deletion rainy/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from .acktr import ACKTRAgent
from .aoc import AOCAgent
from .base import Agent, EpisodeResult, NStepParallelAgent, OneStepAgent
from .bootdqn import BootDQNAgent, EpisodicBootDQNAgent
from .ddpg import DDPGAgent
from .dqn import DQNAgent, DoubleDQNAgent
from .dqn import DQNAgent, DoubleDQNAgent, EpisodicDQNAgent
from .ppo import PPOAgent
from .sac import SACAgent
from .td3 import TD3Agent
Loading