Adding support for the lifting-the-veil script.
PiperOrigin-RevId: 505818590
Johan Obando Ceron authored and joshgreaves committed Feb 27, 2023
1 parent 6762221 commit 156a760
Showing 2 changed files with 16 additions and 5 deletions.
16 changes: 12 additions & 4 deletions dopamine/jax/agents/full_rainbow/full_rainbow_agent.py
@@ -80,10 +80,12 @@ def get_q_values(model, states, rng):
return model(states, key=rng).q_values


-@functools.partial(jax.jit, static_argnums=(0, 3, 12, 13, 14))
+@functools.partial(jax.jit, static_argnames=('network_def', 'optimizer',
+'cumulative_gamma', 'double_dqn',
+'distributional', 'mse_loss'))
def train(network_def, online_params, target_params, optimizer, optimizer_state,
states, actions, next_states, rewards, terminals, loss_weights,
-support, cumulative_gamma, double_dqn, distributional, rng):
+support, cumulative_gamma, double_dqn, distributional, mse_loss, rng):
"""Run a training step."""

# Split the current rng into 2 for updating the rng after this call
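
The decorator change above replaces positional `static_argnums` with `static_argnames`, naming the arguments JAX must treat as compile-time constants; adding the new `mse_loss` flag then only means adding a name rather than renumbering indices. A minimal sketch of the difference on a toy function (not the Dopamine `train` step):

```python
import functools
import jax
import jax.numpy as jnp

# Positional form: the index tuple has to be kept in sync with the signature.
@functools.partial(jax.jit, static_argnums=(2,))
def scale_v1(x, y, use_abs):
  return jnp.abs(x * y) if use_abs else x * y

# Named form: inserting a new argument elsewhere in the signature
# does not invalidate the static-argument spec.
@functools.partial(jax.jit, static_argnames=('use_abs',))
def scale_v2(x, y, use_abs):
  return jnp.abs(x * y) if use_abs else x * y

print(scale_v2(jnp.array(-2.0), jnp.array(3.0), use_abs=True))  # 6.0
```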
@@ -112,7 +114,9 @@ def q_online(state, key):
q_values = get_q_values(q_online, states, rng)
q_values = jnp.squeeze(q_values)
replay_chosen_q = jax.vmap(lambda x, y: x[y])(q_values, actions)
-loss = jax.vmap(losses.huber_loss)(target, replay_chosen_q)
+
+loss = losses.mse_loss if mse_loss else losses.huber_loss
+loss = jax.vmap(loss)(target, replay_chosen_q)

mean_loss = jnp.mean(loss_multipliers * loss)
return mean_loss, loss
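
Because `mse_loss` is listed under `static_argnames`, the Python-level `if`/`else` that picks the loss runs at trace time, and each setting of the flag gets its own compiled train step. Below is a minimal sketch of the selection logic with stand-in loss definitions (the real ones live in `dopamine.jax.losses`; names and values here are illustrative):

```python
import jax
import jax.numpy as jnp

def huber_loss(target, prediction, delta=1.0):
  # Quadratic within +/- delta, linear outside (a standard Huber definition).
  err = jnp.abs(target - prediction)
  quadratic = jnp.minimum(err, delta)
  linear = err - quadratic
  return 0.5 * quadratic**2 + delta * linear

def mse_loss(target, prediction):
  return jnp.power(target - prediction, 2)

targets = jnp.array([1.0, 2.0, 3.0])
replay_chosen_q = jnp.array([0.5, 2.5, 0.0])
loss_weights = jnp.ones(3)

use_mse = True  # plays the role of the new `mse_loss` flag
loss_fn = mse_loss if use_mse else huber_loss
per_sample = jax.vmap(loss_fn)(targets, replay_chosen_q)  # shape (3,)
mean_loss = jnp.mean(loss_weights * per_sample)
print(per_sample, mean_loss)
```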
@@ -177,6 +181,7 @@ def __init__(self,
dueling=True,
double_dqn=True,
distributional=True,
+mse_loss=False,
num_updates_per_train_step=1,
network=networks.FullRainbowNetwork,
num_atoms=51,
@@ -195,6 +200,7 @@ def __init__(self,
dueling: bool, Whether to use dueling network architecture or not.
double_dqn: bool, Whether to use Double DQN or not.
distributional: bool, whether to use distributional RL or not.
+mse_loss: bool, whether to use MSE loss instead of Huber loss.
num_updates_per_train_step: int, Number of gradient updates every training
step. Defaults to 1.
network: flax.linen Module, neural network used by the agent initialized
@@ -221,6 +227,7 @@ def __init__(self,
logging.info('\t noisy_networks: %s', noisy)
logging.info('\t dueling_dqn: %s', dueling)
logging.info('\t distributional: %s', distributional)
+logging.info('\t mse_loss: %s', mse_loss)
logging.info('\t num_atoms: %d', num_atoms)
logging.info('\t replay_scheme: %s', replay_scheme)
logging.info('\t num_updates_per_train_step: %d',
@@ -235,6 +242,7 @@ def __init__(self,
self._noisy = noisy
self._dueling = dueling
self._distributional = distributional
+self._mse_loss = mse_loss
self._num_updates_per_train_step = num_updates_per_train_step

super().__init__(
@@ -295,7 +303,7 @@ def _training_step_update(self):
self.replay_elements['action'], next_states,
self.replay_elements['reward'], self.replay_elements['terminal'],
loss_weights, self._support, self.cumulative_gamma, self._double_dqn,
-self._distributional, self._rng)
+self._distributional, self._mse_loss, self._rng)

if self._replay_scheme == 'prioritized':
# Rainbow and prioritized replay are parametrized by an exponent
5 changes: 4 additions & 1 deletion dopamine/labs/atari_100k/atari_100k_rainbow_agent.py
@@ -85,6 +85,7 @@ class Atari100kRainbowAgent(full_rainbow_agent.JaxFullRainbowAgent):
def __init__(self,
num_actions,
data_augmentation=False,
+mse_loss=False,
summary_writer=None,
network=networks.FullRainbowNetwork,
seed=None):
@@ -95,6 +96,7 @@ def __init__(self,
Args:
num_actions: int, number of actions the agent can take at any state.
data_augmentation: bool, whether to use data augmentation.
+mse_loss: bool, whether to use MSE loss instead of Huber loss.
summary_writer: SummaryWriter object, for outputting training statistics.
network: flax.linen Module, neural network used by the agent initialized
by shape in _create_network below. See
@@ -109,6 +111,7 @@ def __init__(self,
seed=seed)
logging.info('\t data_augmentation: %s', data_augmentation)
self._data_augmentation = data_augmentation
+self._mse_loss = mse_loss
logging.info('\t data_augmentation: %s', data_augmentation)
# Preprocessing during training and evaluation can be possibly different,
# for example, when using data augmentation during training.
@@ -141,7 +144,7 @@ def _training_step_update(self):
self.replay_elements['action'], next_states,
self.replay_elements['reward'], self.replay_elements['terminal'],
loss_weights, self._support, self.cumulative_gamma, self._double_dqn,
-self._distributional, self._rng)
+self._distributional, self._mse_loss, self._rng)

if self._replay_scheme == 'prioritized':
self._replay.set_priority(self.replay_elements['indices'],
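
For the Atari 100k agent the new option is just a constructor keyword that is stored after `super().__init__`, so it can be flipped directly or through a gin binding. A hypothetical usage sketch (not part of this diff; the argument values are assumptions):

```python
from dopamine.labs.atari_100k import atari_100k_rainbow_agent

# Hypothetical construction; argument values are placeholders, not from the diff.
agent = atari_100k_rainbow_agent.Atari100kRainbowAgent(
    num_actions=6,           # e.g. a small Atari action set
    data_augmentation=True,  # existing DrQ-style augmentation flag
    mse_loss=True,           # new: train with MSE instead of Huber loss
)
```

In a gin config the same switch would presumably be a binding along the lines of `Atari100kRainbowAgent.mse_loss = True` (binding name assumed, not shown in this commit).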
