This repository has been archived by the owner on Jan 27, 2023. It is now read-only.

Commit 8089e75
Refactor DQN and BootDQN
kngwyu committed Nov 24, 2019
1 parent b03a3b5
Showing 6 changed files with 135 additions and 81 deletions.
4 changes: 2 additions & 2 deletions examples/bootdqn_cart_pole.py
@@ -1,7 +1,7 @@
import click
import os
from rainy import Config
from rainy.agents import BootDQNAgent
from rainy.agents import EpisodicBootDQNAgent
from rainy.envs import ClassicControl
from rainy.lib import explore
from rainy.net import bootstrap
@@ -38,4 +38,4 @@ def config(
click.Option(["--replay-prob", "-RP"], type=float, default=0.5),
click.Option(["--prior-scale", "-PS"], type=float, default=1.0),
]
cli.run_cli(config, BootDQNAgent, os.path.realpath(__file__), options)
cli.run_cli(config, EpisodicBootDQNAgent, os.path.realpath(__file__), options)
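
Note: the --replay-prob option shown above controls the Bernoulli mask that decides which ensemble heads train on each transition. A standalone sketch of that sampling step (the names here are illustrative, not rainy's API):

import numpy as np

num_ensembles = 10   # one Q head per ensemble member
replay_prob = 0.5    # the value passed via --replay-prob

# One boolean per head: True means this head will train on the transition.
# The mask is stored in the replay buffer together with the transition.
mask = np.random.uniform(0.0, 1.0, num_ensembles) < replay_prob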
2 changes: 1 addition & 1 deletion rainy/agents/__init__.py
@@ -2,7 +2,7 @@
from .acktr import ACKTRAgent
from .aoc import AOCAgent
from .base import Agent, EpisodeResult, NStepParallelAgent, OneStepAgent
from .bootdqn import BootDQNAgent
from .bootdqn import BootDQNAgent, EpisodicBootDQNAgent
from .ddpg import DDPGAgent
from .dqn import DQNAgent, DoubleDQNAgent, EpisodicDQNAgent
from .ppo import PPOAgent
62 changes: 50 additions & 12 deletions rainy/agents/bootdqn.py
@@ -3,6 +3,7 @@
- Deep Exploration via Bootstrapped DQN(https://arxiv.org/abs/1602.04621)
- Randomized Prior Functions for Deep Reinforcement Learning(https://arxiv.org/abs/1806.03335)
"""
import copy
import numpy as np
import torch
from torch import Tensor
@@ -14,10 +15,7 @@
from ..replay import BootDQNReplayFeed


class BootDQNAgent(OneStepAgent):
"""It's the 2nd version of BootDQN described in RPF paper.
"""

class EpisodicBootDQNAgent(OneStepAgent):
SAVED_MEMBERS = "net", "policy", "total_steps"

def __init__(self, config: Config) -> None:
@@ -30,6 +28,7 @@ def __init__(self, config: Config) -> None:
self.eval_policy = config.explorer(key="eval")
self.replay = config.replay_buffer()
self.replay.allow_overlap = True
self.active_head = 0
if self.replay.feed is not BootDQNReplayFeed:
raise RuntimeError("BootDQNAgent needs BootDQNReplayFeed")

@@ -41,16 +40,21 @@ def eval_action(self, state: Array) -> Action:
return self.eval_policy.select_action(state, self.net).item() # type: ignore

def step(self, state: State) -> Tuple[State, float, bool, dict]:
action = self.policy.select_action(self.env.extract(state), self.net).item()
with torch.no_grad():
qs = self.net.q_i_s(self.active_head, self.env.extract(state)).detach()
action = self.policy.select_from_value(qs).item()
next_state, reward, done, info = self.env.step(action)
n_ens = self.config.num_ensembles
mask = np.random.uniform(0, 1, n_ens) < self.config.replay_prob
self.replay.append(state, action, reward, next_state, done, mask)
self._append_to_replay(state, action, reward, next_state, done)
if done:
self._train()
self.net.active_head = np.random.randint(n_ens)
self.active_head = np.random.randint(self.config.num_ensembles)
return next_state, reward, done, info

def _append_to_replay(self, *transition) -> None:
n_ens = self.config.num_ensembles
mask = np.random.uniform(0, 1, n_ens) < self.config.replay_prob
self.replay.append(*transition, mask)

@torch.no_grad()
def _q_next(self, next_states: Array) -> Tensor:
return self.net(next_states).max(axis=-1)[0]
@@ -64,7 +68,41 @@ def _train(self) -> None:
r = self.tensor(rewards).view(-1, 1)
q_target = r + q_next * self.tensor(1.0 - done).mul_(gamma).view(-1, 1)
q_current = self.net.q_s_a(states, actions)
mse = F.mse_loss(q_current, q_target, reduction="none")
masked_loss = mse.mul_(self.tensor(mask)).mean()
self._backward(masked_loss.mean(), self.optimizer, self.net.parameters())
loss = F.mse_loss(q_current, q_target, reduction="none")
masked_loss = loss.masked_select(self.tensor(mask, dtype=torch.bool)).mean()
self._backward(masked_loss, self.optimizer, self.net.parameters())
self.network_log(q_value=q_current.mean().item(), value_loss=masked_loss.item())


class BootDQNAgent(EpisodicBootDQNAgent):
SAVED_MEMBERS = "net", "policy", "total_steps", "target_net"

def __init__(self, config: Config) -> None:
super().__init__(config)
self.target_net = copy.deepcopy(self.net)
self.replay.allow_overlap = False

def step(self, state: State) -> Tuple[State, float, bool, dict]:
train_started = self.total_steps > self.config.train_start
if train_started:
with torch.no_grad():
qs = self.net.q_i_s(self.active_head, self.env.extract(state)).detach()
action = self.policy.select_from_value(qs).item()
else:
action = self.env.spec.random_action()
next_state, reward, done, info = self.env.step(action)
self._append_to_replay(state, action, reward, next_state, done)
if done:
self.active_head = np.random.randint(self.config.num_ensembles)
if train_started:
self._train()
return next_state, reward, done, info

@torch.no_grad()
def _q_next(self, next_states: Array) -> Tensor:
return self.net(next_states).max(axis=-1)[0]

def _train(self) -> None:
super()._train()
if (self.update_steps + 1) % self.config.sync_freq == 0:
self.target_net.load_state_dict(self.net.state_dict())
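
For reference, a self-contained sketch of the masked ensemble loss that replaces the old mse.mul_(mask) formulation; it uses plain PyTorch tensors instead of rainy's replay types, so the shapes and names are assumptions:

import torch
import torch.nn.functional as F

batch, n_heads = 32, 10
q_current = torch.randn(batch, n_heads, requires_grad=True)  # Q(s, a) per head
q_target = torch.randn(batch, n_heads)                       # r + gamma * max_a' Q(s', a')
mask = torch.rand(batch, n_heads) < 0.5                      # bootstrap mask (bool)

per_element = F.mse_loss(q_current, q_target, reduction="none")
# masked_select keeps only the entries whose head was assigned this transition,
# so unselected heads contribute nothing to the gradient.
loss = per_element.masked_select(mask).mean()
loss.backward()

One side effect of the switch: the mean is now taken over the selected entries only, whereas mse.mul_(mask).mean() averaged over all entries (with zeros for the masked-out ones), which scaled the loss down by the mask density.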
65 changes: 33 additions & 32 deletions rainy/agents/dqn.py
@@ -10,15 +10,17 @@
from ..replay import DQNReplayFeed


class DQNAgent(OneStepAgent):
SAVED_MEMBERS = "net", "target_net", "policy", "total_steps"
class EpisodicDQNAgent(OneStepAgent):
"""A DQN variant, which has no target network and
updates the target only once per episode.
"""
SAVED_MEMBERS = "net", "policy", "total_steps"

def __init__(self, config: Config) -> None:
super().__init__(config)
if not self.env.spec.is_discrete():
raise RuntimeError("DQN only supports discrete action space.")
self.net = config.net("dqn")
self.target_net = deepcopy(self.net)
self.optimizer = config.optimizer(self.net.parameters())
self.policy = config.explorer()
self.eval_policy = config.explorer(key="eval")
@@ -35,20 +37,16 @@ def eval_action(self, state: Array) -> Action:
return self.eval_policy.select_action(state, self.net).item() # type: ignore

def step(self, state: State) -> Tuple[State, float, bool, dict]:
train_started = self.total_steps > self.config.train_start
if train_started:
action = self.policy.select_action(self.env.extract(state), self.net).item()
else:
action = self.env.spec.random_action()
action = self.policy.select_action(self.env.extract(state), self.net).item()
next_state, reward, done, info = self.env.step(action)
self.replay.append(state, action, reward, next_state, done)
if train_started:
if done:
self._train()
return next_state, reward, done, info

@torch.no_grad()
def _q_next(self, next_states: Array) -> Tensor:
return self.target_net(next_states).max(axis=-1)[0]
return self.net(next_states).max(axis=-1)[0]

def _train(self) -> None:
obs = self.replay.sample(self.config.replay_batch_size)
@@ -60,40 +58,43 @@ def _train(self) -> None:
loss = F.mse_loss(q_current, q_target)
self._backward(loss, self.optimizer, self.net.parameters())
self.network_log(q_value=q_current.mean().item(), value_loss=loss.item())
if (self.update_steps + 1) % self.config.sync_freq == 0:
self.target_net.load_state_dict(self.net.state_dict())


class DoubleDQNAgent(DQNAgent):
@torch.no_grad()
def _q_next(self, next_states: Array) -> Tensor:
"""Returns Q values of next_states, supposing torch.no_grad() is called
"""
q_next = self.target_net(next_states)
q_value = self.net.q_value(next_states, nostack=True)
return q_next[self.batch_indices, q_value.argmax(dim=-1)]
class DQNAgent(EpisodicDQNAgent):


class EpisodicDQNAgent(DQNAgent):
"""Same as DQN, but does an update per episode.
"""

SAVED_MEMBERS = "net", "policy", "total_steps"
SAVED_MEMBERS = "net", "policy", "total_steps", "target_net"

def __init__(self, config: Config) -> None:
super().__init__(config)
self.config.sync_freq = self.config.max_steps * 10
self.replay.allow_overlap = True
del self.target_net
self.target_net = deepcopy(self.net)

def step(self, state: State) -> Tuple[State, float, bool, dict]:
action = self.policy.select_action(self.env.extract(state), self.net).item()
train_started = self.total_steps > self.config.train_start
if train_started:
action = self.policy.select_action(self.env.extract(state), self.net).item()
else:
action = self.env.spec.random_action()
next_state, reward, done, info = self.env.step(action)
self.replay.append(state, action, reward, next_state, done)
if done:
if train_started:
self._train()
return next_state, reward, done, info

@torch.no_grad()
def _q_next(self, next_states: Array) -> Tensor:
return self.net(next_states).max(axis=-1)[0]
return self.target_net(next_states).max(axis=-1)[0]

def _train(self):
super()._train()
if (self.update_steps + 1) % self.config.sync_freq == 0:
self.target_net.load_state_dict(self.net.state_dict())


class DoubleDQNAgent(DQNAgent):
@torch.no_grad()
def _q_next(self, next_states: Array) -> Tensor:
"""Returns Q values of next_states, supposing torch.no_grad() is called
"""
q_next = self.target_net(next_states)
q_value = self.net.q_value(next_states, nostack=True)
return q_next[self.batch_indices, q_value.argmax(dim=-1)]
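
DoubleDQNAgent's _q_next above is the Double DQN target: the online network selects the greedy action and the target network evaluates it. A minimal sketch with plain tensors (shapes assumed; rainy's q_value/nostack helpers are not used):

import torch

batch, n_actions = 32, 4
q_online = torch.randn(batch, n_actions)  # online net applied to next_states
q_tgt = torch.randn(batch, n_actions)     # target net applied to next_states

batch_indices = torch.arange(batch)
greedy = q_online.argmax(dim=-1)          # action selection by the online net
q_next = q_tgt[batch_indices, greedy]     # action evaluation by the target net
# q_next has shape (batch,) and feeds r + gamma * (1 - done) * q_next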
67 changes: 38 additions & 29 deletions rainy/envs/__init__.py
@@ -15,35 +15,6 @@
from ..prelude import Self, State


# Same as bsuite
gym.envs.register(
id="CartPoleSwingUp-v0",
entry_point="rainy.envs.swingup:CartPoleSwingUp",
max_episode_steps=1000,
kwargs=dict(start_position="bottom", allow_noop=True),
reward_threshold=800,
)

# More difficult
gym.envs.register(
id="CartPoleSwingUp-v1",
entry_point="rainy.envs.swingup:CartPoleSwingUp",
max_episode_steps=1000,
kwargs=dict(start_position="bottom", allow_noop=True, height_threshold=0.9),
reward_threshold=800,
)


# No movecost
gym.envs.register(
id="CartPoleSwingUp-v2",
entry_point="rainy.envs.swingup:CartPoleSwingUp",
max_episode_steps=1000,
kwargs=dict(start_position="botttom", allow_noop=False),
reward_threshold=900,
)


class AtariConfig:
STYLES = ["deepmind", "baselines", "dopamine", "rnd"]

@@ -182,3 +153,41 @@ def __wrap(env_gen: EnvGen, num_workers: int) -> ParallelEnv:
return penv

return __wrap


# Same as bsuite
gym.envs.register(
id="CartPoleSwingUp-v0",
entry_point="rainy.envs.swingup:CartPoleSwingUp",
max_episode_steps=1000,
kwargs=dict(start_position="bottom", allow_noop=True),
reward_threshold=800,
)

# More difficult
gym.envs.register(
id="CartPoleSwingUp-v1",
entry_point="rainy.envs.swingup:CartPoleSwingUp",
max_episode_steps=1000,
kwargs=dict(start_position="bottom", allow_noop=True, height_threshold=0.9),
reward_threshold=800,
)


# No movecost
gym.envs.register(
id="CartPoleSwingUp-v2",
entry_point="rainy.envs.swingup:CartPoleSwingUp",
max_episode_steps=1000,
kwargs=dict(start_position="bottom", allow_noop=False),
reward_threshold=900,
)

# Arbitrary start
gym.envs.register(
id="CartPoleSwingUp-v3",
entry_point="rainy.envs.swingup:CartPoleSwingUp",
max_episode_steps=1000,
kwargs=dict(start_position="arbitary", allow_noop=False),
reward_threshold=900,
)
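
The register() calls are an almost straight move from the top of the module to the bottom (the v2 kwargs also drop the "botttom" typo, and a new v3 variant with an arbitrary start position is added); they still run at import time, so usage is unchanged. A minimal sketch with the gym API of that era:

import gym
import rainy.envs  # importing the module runs the register() calls above

env = gym.make("CartPoleSwingUp-v1")  # harder variant: height_threshold=0.9
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())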
16 changes: 11 additions & 5 deletions rainy/net/bootstrap.py
@@ -14,13 +14,20 @@


class BootstrappedQFunction(DiscreteQFunction):
active_head: int
device: Device

@abstractmethod
def forward(self, states: ArrayLike) -> Tensor:
pass

@abstractmethod
def q_i_s(self, index: int, states: ArrayLike) -> Tensor:
pass

def q_value(self, state: Array) -> Tensor:
values = self(state)
return values.mean(dim=1)

def q_s_a(self, states: ArrayLike, actions: ArrayLike) -> Tensor:
qs = self(self.device.tensor(states))
act = self.device.tensor(actions, dtype=torch.long)
@@ -32,15 +39,14 @@ class SeparatedBootQValueNet(BootstrappedQFunction, nn.Module):
def __init__(self, q_nets: List[DiscreteQFunction]):
super().__init__()
self.q_nets = nn.ModuleList(q_nets)
self.active_head = 0
self.device = q_nets[0].device

def q_value(self, state: Array) -> Tensor:
return self.q_nets[self.active_head].q_value(state)

def forward(self, x: ArrayLike) -> Tensor:
return torch.stack([q(x) for q in self.q_nets], dim=1)

def q_i_s(self, index: int, states: ArrayLike) -> Tensor:
return self.q_nets[index](states)

@property
def state_dim(self) -> Sequence[int]:
return self.q_nets[0].state_dim
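
The new abstract q_i_s(index, states) lets the agent query a specific head directly; the agent now keeps track of active_head itself and passes it in (see the bootdqn.py changes above). A simplified, self-contained ensemble showing the shape of that interface (a toy stand-in, not rainy's SeparatedBootQValueNet):

import torch
from torch import Tensor, nn

class TinyBootQNet(nn.Module):
    """Toy ensemble of independent linear Q heads with a q_i_s-style accessor."""

    def __init__(self, state_dim: int, n_actions: int, n_heads: int) -> None:
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(state_dim, n_actions) for _ in range(n_heads)])

    def forward(self, states: Tensor) -> Tensor:
        # All heads at once: (batch, n_heads, n_actions), like forward() above.
        return torch.stack([head(states) for head in self.heads], dim=1)

    def q_i_s(self, index: int, states: Tensor) -> Tensor:
        # Q values from one head only, as the agents' step() now uses.
        return self.heads[index](states)

net = TinyBootQNet(state_dim=4, n_actions=2, n_heads=10)
qs = net.q_i_s(3, torch.randn(8, 4))  # (8, 2): Q values from head 3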
