Added support for specifying custom models under all existing presets. (
michalgregor committed Aug 13, 2020
1 parent 74feb3e commit 1d2cf9f
Showing 23 changed files with 153 additions and 44 deletions.
13 changes: 10 additions & 3 deletions all/presets/atari/a2c.py
@@ -23,6 +23,10 @@ def a2c(
# Batch settings
n_envs=16,
n_steps=5,
# Model construction
feature_model_constructor=nature_features,
value_model_constructor=nature_value_head,
policy_model_constructor=nature_policy_head
):
"""
A2C Atari preset.
@@ -39,14 +43,17 @@ def a2c(
value_loss_scaling (float): Coefficient for the value function loss.
n_envs (int): Number of parallel environments.
n_steps (int): Length of each rollout.
feature_model_constructor (function): The function used to construct the neural feature model.
value_model_constructor (function): The function used to construct the neural value model.
policy_model_constructor (function): The function used to construct the neural policy model.
"""
def _a2c(envs, writer=DummyWriter()):
env = envs[0]
final_anneal_step = last_frame / (n_steps * n_envs * 4)

value_model = nature_value_head().to(device)
policy_model = nature_policy_head(env).to(device)
feature_model = nature_features().to(device)
value_model = value_model_constructor().to(device)
policy_model = policy_model_constructor(env).to(device)
feature_model = feature_model_constructor().to(device)

feature_optimizer = Adam(feature_model.parameters(), lr=lr, eps=eps)
value_optimizer = Adam(value_model.parameters(), lr=lr, eps=eps)
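With the new keyword arguments, a caller can swap in a custom network without editing the preset itself. Below is a minimal usage sketch (not part of this commit) that overrides only the value head of the Atari a2c preset; the all.presets.atari import path, the 512-feature input size, and the torch.nn architecture are illustrative assumptions.

    import torch.nn as nn
    from all.presets.atari import a2c

    def wider_value_head():
        # Assumed to receive the 512-dimensional output of the default
        # nature_features extractor (the 512 here is an assumption).
        return nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    # Only the value head is overridden; the feature and policy models keep
    # their defaults (nature_features, nature_policy_head).
    make_agent = a2c(value_model_constructor=wider_value_head)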
5 changes: 4 additions & 1 deletion all/presets/atari/c51.py
@@ -31,6 +31,8 @@ def c51(
atoms=51,
v_min=-10,
v_max=10,
# Model construction
model_constructor=nature_c51
):
"""
C51 Atari preset.
@@ -53,13 +55,14 @@
the distributional value function.
v_min (int): The expected return corresponding to the smallest atom.
v_max (int): The expected return corresponding to the largest atom.
model_constructor (function): The function used to construct the neural model.
"""
def _c51(env, writer=DummyWriter()):
action_repeat = 4
last_timestep = last_frame / action_repeat
last_update = (last_timestep - replay_start_size) / update_frequency

model = nature_c51(env, atoms=atoms).to(device)
model = model_constructor(env, atoms=atoms).to(device)
optimizer = Adam(
model.parameters(),
lr=lr,
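Because the preset now calls model_constructor(env, atoms=atoms), a replacement model must accept the env and the atoms keyword and emit one output per (action, atom) pair. A hedged sketch, not from this commit; the stacked 4x84x84 input, the env.action_space.n accessor, and the assumption that QDist reshapes the flat output are illustrative.

    import torch.nn as nn
    from all.presets.atari import c51

    def my_c51_model(env, atoms=51):
        n_actions = env.action_space.n  # assumed accessor for discrete actions
        return nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
            # one logit per (action, atom) pair; QDist is assumed to reshape
            nn.Linear(512, n_actions * atoms),
        )

    make_agent = c51(model_constructor=my_c51_model)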
5 changes: 4 additions & 1 deletion all/presets/atari/ddqn.py
@@ -32,6 +32,8 @@ def ddqn(
# Prioritized replay settings
alpha=0.5,
beta=0.5,
# Model construction
model_constructor=nature_ddqn
):
"""
Dueling Double DQN with Prioritized Experience Replay (PER).
@@ -55,14 +57,15 @@
(0 = no prioritization, 1 = full prioritization)
beta (float): The strength of the importance sampling correction for prioritized experience replay.
(0 = no correction, 1 = full correction)
model_constructor (function): The function used to construct the neural model.
"""
def _ddqn(env, writer=DummyWriter()):
action_repeat = 4
last_timestep = last_frame / action_repeat
last_update = (last_timestep - replay_start_size) / update_frequency
final_exploration_step = final_exploration_frame / action_repeat

model = nature_ddqn(env).to(device)
model = model_constructor(env).to(device)
optimizer = Adam(
model.parameters(),
lr=lr,
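The ddqn preset only requires a callable that maps env to a module, so a hand-rolled dueling architecture can also be dropped in. The sketch below is not from this commit; the 4x84x84 input shape and the env.action_space.n accessor are assumptions, and the default nature_ddqn remains the reference implementation.

    import torch.nn as nn
    from all.presets.atari import ddqn

    class DuelingQ(nn.Module):
        """Illustrative dueling Q-network (value + advantage streams)."""
        def __init__(self, n_actions):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
                nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
                nn.Flatten(),
                nn.Linear(64 * 9 * 9, 512), nn.ReLU(),
            )
            self.value = nn.Linear(512, 1)
            self.advantage = nn.Linear(512, n_actions)

        def forward(self, x):
            f = self.features(x)
            a = self.advantage(f)
            return self.value(f) + a - a.mean(dim=1, keepdim=True)

    make_agent = ddqn(model_constructor=lambda env: DuelingQ(env.action_space.n))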
5 changes: 4 additions & 1 deletion all/presets/atari/dqn.py
@@ -30,6 +30,8 @@ def dqn(
initial_exploration=1.,
final_exploration=0.01,
final_exploration_frame=4000000,
# Model construction
model_constructor=nature_dqn
):
"""
DQN Atari preset.
@@ -49,14 +51,15 @@
decayed until final_exploration_frame.
final_exploration (int): Final probability of choosing a random action.
final_exploration_frame (int): The frame where the exploration decay stops.
model_constructor (function): The function used to construct the neural model.
"""
def _dqn(env, writer=DummyWriter()):
action_repeat = 4
last_timestep = last_frame / action_repeat
last_update = (last_timestep - replay_start_size) / update_frequency
final_exploration_step = final_exploration_frame / action_repeat

model = nature_dqn(env).to(device)
model = model_constructor(env).to(device)

optimizer = Adam(
model.parameters(),
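The dqn preset likewise just calls model_constructor(env), so any function returning a torch.nn module with one Q-value per action works. A minimal sketch under the same assumptions as above (stacked 4x84x84 frames, env.action_space.n):

    import torch.nn as nn
    from all.presets.atari import dqn

    def my_q_model(env):
        return nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 9 * 9, 512), nn.ReLU(),
            nn.Linear(512, env.action_space.n),  # one Q-value per action
        )

    make_agent = dqn(model_constructor=my_q_model)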
13 changes: 10 additions & 3 deletions all/presets/atari/ppo.py
@@ -30,6 +30,10 @@ def ppo(
n_steps=128,
# GAE settings
lam=0.95,
# Model construction
feature_model_constructor=nature_features,
value_model_constructor=nature_value_head,
policy_model_constructor=nature_policy_head
):
"""
PPO Atari preset.
@@ -51,6 +55,9 @@
n_envs (int): Number of parallel actors.
n_steps (int): Length of each rollout.
lam (float): The Generalized Advantage Estimate (GAE) decay parameter.
feature_model_constructor (function): The function used to construct the neural feature model.
value_model_constructor (function): The function used to construct the neural value model.
policy_model_constructor (function): The function used to construct the neural policy model.
"""
def _ppo(envs, writer=DummyWriter()):
env = envs[0]
@@ -60,9 +67,9 @@ def _ppo(envs, writer=DummyWriter()):
# with n_envs and 4 frames per step
final_anneal_step = last_frame * epochs * minibatches / (n_steps * n_envs * 4)

value_model = nature_value_head().to(device)
policy_model = nature_policy_head(env).to(device)
feature_model = nature_features().to(device)
value_model = value_model_constructor().to(device)
policy_model = policy_model_constructor(env).to(device)
feature_model = feature_model_constructor().to(device)

feature_optimizer = Adam(
feature_model.parameters(), lr=lr, eps=eps
5 changes: 4 additions & 1 deletion all/presets/atari/rainbow.py
@@ -38,6 +38,8 @@ def rainbow(
v_max=10,
# Noisy Nets
sigma=0.5,
# Model construction
model_constructor=nature_rainbow
):
"""
Rainbow Atari Preset.
@@ -66,13 +68,14 @@
v_min (int): The expected return corresponding to the smallest atom.
v_max (int): The expected return corresponding to the largest atom.
sigma (float): Initial noisy network noise.
model_constructor (function): The function used to construct the neural model.
"""
def _rainbow(env, writer=DummyWriter()):
action_repeat = 4
last_timestep = last_frame / action_repeat
last_update = (last_timestep - replay_start_size) / update_frequency

model = nature_rainbow(env, atoms=atoms, sigma=sigma).to(device)
model = model_constructor(env, atoms=atoms, sigma=sigma).to(device)
optimizer = Adam(model.parameters(), lr=lr, eps=eps)
q = QDist(
model,
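For rainbow the constructor is called as model_constructor(env, atoms=atoms, sigma=sigma), so a replacement must accept both keywords. The default nature_rainbow uses noisy and dueling layers; the plain sketch below (not from this commit) only matches the call signature and output size and deliberately ignores sigma.

    import torch.nn as nn
    from all.presets.atari import rainbow

    def my_rainbow_model(env, atoms=51, sigma=0.5):
        # sigma is accepted for signature compatibility but unused here,
        # since this sketch has no noisy layers.
        n_actions = env.action_space.n  # assumed accessor
        return nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 9 * 9, 512), nn.ReLU(),
            nn.Linear(512, n_actions * atoms),
        )

    make_agent = rainbow(model_constructor=my_rainbow_model)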
13 changes: 10 additions & 3 deletions all/presets/atari/vac.py
@@ -20,6 +20,10 @@ def vac(
value_loss_scaling=0.25,
# Parallel actors
n_envs=16,
# Model construction
feature_model_constructor=nature_features,
value_model_constructor=nature_value_head,
policy_model_constructor=nature_policy_head
):
"""
Vanilla Actor-Critic Atari preset.
@@ -35,11 +39,14 @@
Set to 0 to disable.
value_loss_scaling (float): Coefficient for the value function loss.
n_envs (int): Number of parallel environments.
feature_model_constructor (function): The function used to construct the neural feature model.
value_model_constructor (function): The function used to construct the neural value model.
policy_model_constructor (function): The function used to construct the neural policy model.
"""
def _vac(envs, writer=DummyWriter()):
value_model = nature_value_head().to(device)
policy_model = nature_policy_head(envs[0]).to(device)
feature_model = nature_features().to(device)
value_model = value_model_constructor().to(device)
policy_model = policy_model_constructor(envs[0]).to(device)
feature_model = feature_model_constructor().to(device)

value_optimizer = Adam(value_model.parameters(), lr=lr_v, eps=eps)
policy_optimizer = Adam(policy_model.parameters(), lr=lr_pi, eps=eps)
13 changes: 10 additions & 3 deletions all/presets/atari/vpg.py
@@ -20,6 +20,10 @@ def vpg(
clip_grad=0.5,
value_loss_scaling=0.25,
min_batch_size=1000,
# Model construction
feature_model_constructor=nature_features,
value_model_constructor=nature_value_head,
policy_model_constructor=nature_policy_head
):
"""
Vanilla Policy Gradient Atari preset.
@@ -35,13 +39,16 @@
value_loss_scaling (float): Coefficient for the value function loss.
min_batch_size (int): Continue running complete episodes until at least this many
states have been seen since the last update.
feature_model_constructor (function): The function used to construct the neural feature model.
value_model_constructor (function): The function used to construct the neural value model.
policy_model_constructor (function): The function used to construct the neural policy model.
"""
final_anneal_step = last_frame / (min_batch_size * 4)

def _vpg_atari(env, writer=DummyWriter()):
value_model = nature_value_head().to(device)
policy_model = nature_policy_head(env).to(device)
feature_model = nature_features().to(device)
value_model = value_model_constructor().to(device)
policy_model = policy_model_constructor(env).to(device)
feature_model = feature_model_constructor().to(device)

feature_optimizer = Adam(feature_model.parameters(), lr=lr, eps=eps)
value_optimizer = Adam(value_model.parameters(), lr=lr, eps=eps)
5 changes: 4 additions & 1 deletion all/presets/atari/vqn.py
@@ -20,6 +20,8 @@ def vqn(
final_exploration_frame=1000000,
# Parallel actors
n_envs=64,
# Model construction
model_constructor=nature_ddqn
):
"""
Vanilla Q-Network Atari preset.
@@ -34,13 +36,14 @@
final_exploration (int): Final probability of choosing a random action.
final_exploration_frame (int): The frame where the exploration decay stops.
n_envs (int): Number of parallel environments.
model_constructor (function): The function used to construct the neural model.
"""
def _vqn(envs, writer=DummyWriter()):
action_repeat = 4
final_exploration_timestep = final_exploration_frame / action_repeat

env = envs[0]
model = nature_ddqn(env).to(device)
model = model_constructor(env).to(device)
optimizer = Adam(model.parameters(), lr=lr, eps=eps)
q = QNetwork(
model,
5 changes: 4 additions & 1 deletion all/presets/atari/vsarsa.py
@@ -20,6 +20,8 @@ def vsarsa(
initial_exploration=1.,
# Parallel actors
n_envs=64,
# Model construction
model_constructor=nature_ddqn
):
"""
Vanilla SARSA Atari preset.
@@ -34,13 +36,14 @@
final_exploration (int): Final probability of choosing a random action.
final_exploration_frame (int): The frame where the exploration decay stops.
n_envs (int): Number of parallel environments.
model_constructor (function): The function used to construct the neural model.
"""
def _vsarsa(envs, writer=DummyWriter()):
action_repeat = 4
final_exploration_timestep = final_exploration_frame / action_repeat

env = envs[0]
model = nature_ddqn(env).to(device)
model = model_constructor(env).to(device)
optimizer = Adam(model.parameters(), lr=lr, eps=eps)
q = QNetwork(
model,
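vqn and vsarsa both default to nature_ddqn and call model_constructor(env) in the same way, so a single custom constructor can serve both presets. A small sketch under the usual assumptions (4x84x84 input, env.action_space.n):

    import torch.nn as nn
    from all.presets.atari import vqn, vsarsa

    def small_q_model(env):
        # A deliberately small network shared by both presets.
        return nn.Sequential(
            nn.Conv2d(4, 16, 8, stride=4), nn.ReLU(),
            nn.Conv2d(16, 32, 4, stride=2), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 9 * 9, 256), nn.ReLU(),
            nn.Linear(256, env.action_space.n),
        )

    make_vqn = vqn(model_constructor=small_q_model)
    make_vsarsa = vsarsa(model_constructor=small_q_model)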
13 changes: 10 additions & 3 deletions all/presets/classic_control/a2c.py
@@ -18,6 +18,10 @@ def a2c(
# Batch settings
n_envs=4,
n_steps=32,
# Model construction
feature_model_constructor=fc_relu_features,
value_model_constructor=fc_value_head,
policy_model_constructor=fc_policy_head
):
"""
A2C classic control preset.
@@ -30,12 +34,15 @@
entropy_loss_scaling (float): Coefficient for the entropy term in the total loss.
n_envs (int): Number of parallel environments.
n_steps (int): Length of each rollout.
feature_model_constructor (function): The function used to construct the neural feature model.
value_model_constructor (function): The function used to construct the neural value model.
policy_model_constructor (function): The function used to construct the neural policy model.
"""
def _a2c(envs, writer=DummyWriter()):
env = envs[0]
feature_model = fc_relu_features(env).to(device)
value_model = fc_value_head().to(device)
policy_model = fc_policy_head(env).to(device)
feature_model = feature_model_constructor(env).to(device)
value_model = value_model_constructor().to(device)
policy_model = policy_model_constructor(env).to(device)

feature_optimizer = Adam(feature_model.parameters(), lr=lr)
value_optimizer = Adam(value_model.parameters(), lr=lr)
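Note that, unlike the Atari version, the classic-control a2c preset passes env to the feature constructor as well. The sketch below overrides all three models; the all.presets.classic_control import path and the env.state_space.shape / env.action_space.n accessors are assumptions, and the 128-unit feature width simply has to match across the three modules.

    import torch.nn as nn
    from all.presets.classic_control import a2c

    def my_features(env):
        # env.state_space.shape[0] is an assumed accessor for the state dimension
        return nn.Sequential(nn.Linear(env.state_space.shape[0], 128), nn.ReLU())

    def my_value_head():
        return nn.Linear(128, 1)

    def my_policy_head(env):
        return nn.Linear(128, env.action_space.n)  # assumed accessor

    make_agent = a2c(
        feature_model_constructor=my_features,
        value_model_constructor=my_value_head,
        policy_model_constructor=my_policy_head,
    )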
7 changes: 5 additions & 2 deletions all/presets/classic_control/c51.py
@@ -26,7 +26,9 @@ def c51(
# Distributional RL
atoms=101,
v_min=-100,
v_max=100
v_max=100,
# Model construction
model_constructor=fc_relu_dist_q
):
"""
C51 classic control preset.
@@ -47,9 +49,10 @@
the distributional value function.
v_min (int): The expected return corresponding to the smallest atom.
v_max (int): The expected return corresponding to the largest atom.
model_constructor (function): The function used to construct the neural model.
"""
def _c51(env, writer=DummyWriter()):
model = fc_relu_dist_q(env, atoms=atoms).to(device)
model = model_constructor(env, atoms=atoms).to(device)
optimizer = Adam(model.parameters(), lr=lr)
q = QDist(
model,
5 changes: 4 additions & 1 deletion all/presets/classic_control/ddqn.py
@@ -28,6 +28,8 @@ def ddqn(
# Prioritized replay settings
alpha=0.2,
beta=0.6,
# Model construction
model_constructor=dueling_fc_relu_q
):
"""
Dueling Double DQN with Prioritized Experience Replay (PER).
@@ -50,9 +52,10 @@
(0 = no prioritization, 1 = full prioritization)
beta (float): The strength of the importance sampling correction for prioritized experience replay.
(0 = no correction, 1 = full correction)
model_constructor (function): The function used to construct the neural model.
"""
def _ddqn(env, writer=DummyWriter()):
model = dueling_fc_relu_q(env).to(device)
model = model_constructor(env).to(device)
optimizer = Adam(model.parameters(), lr=lr)
q = QNetwork(
model,
5 changes: 4 additions & 1 deletion all/presets/classic_control/dqn.py
@@ -25,6 +25,8 @@ def dqn(
initial_exploration=1.,
final_exploration=0.,
final_exploration_frame=10000,
# Model construction
model_constructor=fc_relu_q
):
"""
DQN classic control preset.
@@ -42,9 +44,10 @@
decayed until final_exploration_frame.
final_exploration (int): Final probability of choosing a random action.
final_exploration_frame (int): The frame where the exploration decay stops.
model_constructor (function): The function used to construct the neural model.
"""
def _dqn(env, writer=DummyWriter()):
model = fc_relu_q(env).to(device)
model = model_constructor(env).to(device)
optimizer = Adam(model.parameters(), lr=lr)
q = QNetwork(
model,
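The classic-control dqn preset only needs a callable mapping env to a fully connected Q-network. A final sketch under the same assumptions as above:

    import torch.nn as nn
    from all.presets.classic_control import dqn

    def my_fc_q(env):
        # state_space / action_space accessors are assumptions based on the
        # library's gym-style environments
        return nn.Sequential(
            nn.Linear(env.state_space.shape[0], 64),
            nn.ReLU(),
            nn.Linear(64, env.action_space.n),
        )

    make_agent = dqn(model_constructor=my_fc_q)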
