From c1b562e2219d69d8d7c05b642eb8203a9ed1b8cc Mon Sep 17 00:00:00 2001 From: Piotr Kozakowski Date: Wed, 29 Apr 2020 03:25:13 -0700 Subject: [PATCH] Enable training the value network with TD returns. PiperOrigin-RevId: 308992492 --- trax/rl/actor_critic.py | 108 ++++++++++++++++++++++++++------- trax/rl/actor_critic_test.py | 11 +++- trax/rl/task.py | 29 ++++++--- trax/shapes.py | 7 +++ trax/supervised/trainer_lib.py | 17 ++++++ 5 files changed, 142 insertions(+), 30 deletions(-) diff --git a/trax/rl/actor_critic.py b/trax/rl/actor_critic.py index 9294de4ca..7800e8c3b 100644 --- a/trax/rl/actor_critic.py +++ b/trax/rl/actor_critic.py @@ -22,6 +22,7 @@ from trax import layers as tl from trax import lr_schedules as lr +from trax import shapes from trax import supervised from trax.math import numpy as jnp from trax.rl import advantages as rl_advantages @@ -49,6 +50,7 @@ def __init__(self, task, n_shared_layers=0, added_policy_slice_length=0, n_replay_epochs=1, + scale_value_targets=False, **kwargs): # Arguments of PolicyTrainer come here. """Configures the actor-critic Trainer. @@ -73,6 +75,8 @@ def __init__(self, task, have maximum length set by max_slice_length in **kwargs n_replay_epochs: how many last epochs to take into the replay buffer; only makes sense for off-policy algorithms + scale_value_targets: whether to scale targets for the value function by + 1 / (1 - gamma) **kwargs: arguments for PolicyTrainer super-class """ self._n_shared_layers = n_shared_layers @@ -89,6 +93,14 @@ def __init__(self, task, self._added_policy_slice_length = added_policy_slice_length self._n_replay_epochs = n_replay_epochs + if scale_value_targets: + self._value_network_scale = 1 / (1 - self._task.gamma) + else: + self._value_network_scale = 1 + + self._value_eval_model = value_model(mode='eval') + self._value_eval_model.init(self._value_model_signature) + # Initialize training of the value function. value_output_dir = kwargs.get('output_dir', None) if value_output_dir is not None: @@ -106,13 +118,18 @@ def __init__(self, task, inputs=self._value_inputs, output_dir=value_output_dir, metrics={'value_loss': tl.L2Loss()}) - self._value_eval_model = value_model(mode='eval') - value_batch = next(self.value_batches_stream()) - self._value_eval_model.init(value_batch) # Initialize policy training. super(ActorCriticTrainer, self).__init__(task, **kwargs) + @property + def _value_model_signature(self): + obs_sig = shapes.signature(self._task.observation_space) + target_sig = mask_sig = shapes.ShapeDtype( + shape=(1, 1, 1), + ) + return (obs_sig.replace(shape=(1, 1) + obs_sig.shape), target_sig, mask_sig) + @property def _replay_epochs(self): if self.on_policy: @@ -124,15 +141,39 @@ def _replay_epochs(self): def value_batches_stream(self): """Use the RLTask self._task to create inputs to the value model.""" + max_slice_length = self._max_slice_length + self._added_policy_slice_length for np_trajectory in self._task.trajectory_batch_stream( - self._value_batch_size, max_slice_length=self._max_slice_length, - epochs=self._replay_epochs): + self._value_batch_size, + max_slice_length=max_slice_length, + min_slice_length=(1 + self._added_policy_slice_length), + epochs=self._replay_epochs, + ): + values = self._value_eval_model( + np_trajectory.observations, n_accelerators=1 + ) * self._value_network_scale + values = np.squeeze(values, axis=2) # Remove the singleton depth dim. + + # TODO(pkozakowski): Add some shape assertions and docs. 
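A note on scale_value_targets above: targets for the value network are divided by 1 / (1 - gamma) and the network's outputs are multiplied back by the same factor wherever they are consumed, which keeps the regression roughly on reward scale. A minimal sketch of why that factor is the natural normalizer (plain Python, illustrative numbers only, not trax code):

gamma = 0.99
r = 1.0                      # constant per-step reward
horizon = 10000
discounted_return = sum(r * gamma ** t for t in range(horizon))
scale = 1.0 / (1.0 - gamma)  # the scale_value_targets factor
print(round(discounted_return, 2))          # ~100.0
print(round(discounted_return / scale, 2))  # ~1.0, back on reward scale
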
+ # Calculate targets based on the advantages over the target network - this + # allows TD learning for value networks. + advantages = self._advantage_estimator( + np_trajectory.rewards, np_trajectory.returns, values, + gamma=self._task.gamma, + n_extra_steps=self._added_policy_slice_length, + ) + length = advantages.shape[1] + values = values[:, :length] + target_returns = values + advantages + # Insert an extra depth dimension, so the target shape is consistent with # the network output shape. yield ( - np_trajectory.observations, # Inputs to the value model. - np_trajectory.returns[:, :, None], # Targets: regress to returns. - np_trajectory.mask[:, :, None], # Mask to zero-out padding. + # Inputs: observations. + np_trajectory.observations[:, :length], + # Targets: computed returns. + target_returns[:, :, None] / self._value_network_scale, + # Mask to zero-out padding. + np_trajectory.mask[:, :length, None], ) def policy_inputs(self, trajectory, values): @@ -159,9 +200,9 @@ def policy_batches_stream(self): epochs=self._replay_epochs, max_slice_length=max_slice_length, include_final_state=False): - value_model = self._value_eval_model - value_model.weights = self._value_trainer.model_weights - values = value_model(np_trajectory.observations, n_accelerators=1) + values = self._value_eval_model( + np_trajectory.observations, n_accelerators=1 + ) * self._value_network_scale values = np.squeeze(values, axis=2) # Remove the singleton depth dim. if len(values.shape) != 2: raise ValueError('Values are expected to have shape ' + @@ -173,6 +214,19 @@ def policy_batches_stream(self): def train_epoch(self): """Trains RL for one epoch.""" + # Copy policy state accumulated during data collection to the trainer. + self._policy_trainer.model_state = self._policy_collect_model.state + + # Copy policy weights and state to value trainer. + if self._n_shared_layers > 0: + _copy_model_weights_and_state( + 0, self._n_shared_layers, self._policy_trainer, self._value_trainer + ) + + # Update the target value network. + self._value_eval_model.weights = self._value_trainer.model_weights + self._value_eval_model.state = self._value_trainer.model_state + n_value_evals = rl_training.remaining_evals( self._value_trainer.step, self._epoch, @@ -183,9 +237,11 @@ def train_epoch(self): self._value_train_steps_per_epoch // self._value_evals_per_epoch, self._value_eval_steps, ) - if self._n_shared_layers > 0: # Copy value weights to policy trainer. - _copy_model_weights(0, self._n_shared_layers, - self._value_trainer, self._policy_trainer) + # Copy value weights and state to policy trainer. + if self._n_shared_layers > 0: + _copy_model_weights_and_state( + 0, self._n_shared_layers, self._value_trainer, self._policy_trainer + ) n_policy_evals = rl_training.remaining_evals( self._policy_trainer.step, self._epoch, @@ -196,31 +252,41 @@ def train_epoch(self): n_policy_evals < self._policy_evals_per_epoch) should_copy_weights = self._n_shared_layers > 0 and not stopped_after_value if should_copy_weights: - _copy_model_weights(0, self._n_shared_layers, - self._value_trainer, self._policy_trainer) + _copy_model_weights_and_state( + 0, self._n_shared_layers, self._value_trainer, self._policy_trainer + ) + + # Update the target value network. 
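A note on the target computation in value_batches_stream above: adding the estimated advantages back onto the target network's values turns the regression target into a bootstrapped (TD-style) return rather than a pure Monte Carlo return. A self-contained numpy sketch using a one-step TD advantage (shapes and numbers are made up; this is not the trax advantages API):

import numpy as np

gamma = 0.99
rewards = np.array([[1.0, 1.0, 1.0]])      # [batch, time]
values = np.array([[0.5, 0.6, 0.7, 0.8]])  # [batch, time + 1], from the target network

# One-step TD advantage: r_t + gamma * V(s_{t+1}) - V(s_t).
advantages = rewards + gamma * values[:, 1:] - values[:, :-1]
# values + advantages recovers the bootstrapped return r_t + gamma * V(s_{t+1}).
target_returns = values[:, :-1] + advantages
np.testing.assert_allclose(target_returns, rewards + gamma * values[:, 1:])
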
+ self._value_eval_model.weights = self._value_trainer.model_weights + self._value_eval_model.state = self._value_trainer.model_state for _ in range(n_policy_evals): self._policy_trainer.train_epoch( self._policy_train_steps_per_epoch // self._policy_evals_per_epoch, self._policy_eval_steps, ) - if self._n_shared_layers > 0: # Copy policy weights to value trainer. - _copy_model_weights(0, self._n_shared_layers, - self._policy_trainer, self._value_trainer) def close(self): self._value_trainer.close() super().close() -def _copy_model_weights(start, end, from_trainer, to_trainer, # pylint: disable=invalid-name - copy_optimizer_slots=True): +def _copy_model_weights_and_state( # pylint: disable=invalid-name + start, end, from_trainer, to_trainer, copy_optimizer_slots=False +): """Copy model weights[start:end] from from_trainer to to_trainer.""" from_weights = from_trainer.model_weights to_weights = to_trainer.model_weights shared_weights = from_weights[start:end] to_weights[start:end] = shared_weights to_trainer.model_weights = to_weights + + from_state = from_trainer.model_state + to_state = to_trainer.model_state + shared_state = from_state[start:end] + to_state[start:end] = shared_state + to_trainer.model_state = to_state + if copy_optimizer_slots: # TODO(lukaszkaiser): make a nicer API in Trainer to support this. # Currently we use the hack below. Note [0] since that's the model w/o loss. diff --git a/trax/rl/actor_critic_test.py b/trax/rl/actor_critic_test.py index eb6ed38e4..708074316 100644 --- a/trax/rl/actor_critic_test.py +++ b/trax/rl/actor_critic_test.py @@ -26,6 +26,7 @@ from trax import optimizers as opt from trax import test_utils from trax.rl import actor_critic +from trax.rl import advantages from trax.rl import task as rl_task @@ -161,7 +162,10 @@ def test_awrtrainer_cartpole(self): policy_lr_schedule=lr, policy_batch_size=32, policy_train_steps_per_epoch=200, - collect_per_epoch=10) + collect_per_epoch=10, + advantage_estimator=advantages.monte_carlo, + advantage_normalization=False, + ) trainer.run(1) self.assertEqual(1, trainer.current_epoch) self.assertGreater(trainer.avg_returns[-1], 50.0) @@ -189,7 +193,10 @@ def test_awrtrainer_cartpole_shared(self): policy_lr_schedule=lr, policy_batch_size=32, policy_train_steps_per_epoch=200, - collect_per_epoch=10) + collect_per_epoch=10, + advantage_estimator=advantages.monte_carlo, + advantage_normalization=False, + ) trainer.run(1) self.assertEqual(1, trainer.current_epoch) self.assertGreater(trainer.avg_returns[-1], 50.0) diff --git a/trax/rl/task.py b/trax/rl/task.py index 6109e54d1..c5b84e1b7 100644 --- a/trax/rl/task.py +++ b/trax/rl/task.py @@ -52,7 +52,7 @@ def __init__(self, observation, action=None, reward=None, log_prob=None): 'log_probs', 'rewards', 'returns', - 'mask' + 'mask', ]) @@ -289,7 +289,7 @@ def __init__(self, env=gin.REQUIRED, initial_trajectories=1, gamma=0.99, initial_trajectories = [ # Whatever we gather here is intended to be removed # in PolicyTrainer. Here we just gather some example inputs. 
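A note on _copy_model_weights_and_state above: weights and state are held in list-like, per-layer containers, so sharing the first n layers between the policy and value trainers reduces to a slice assignment. A toy plain-Python sketch of that idea (the dictionaries stand in for real layer weights):

policy_weights = [{'w': 1}, {'w': 2}, {'w': 3}]
value_weights = [{'w': 9}, {'w': 8}, {'w': 7}]
n_shared_layers = 2

# Copy the shared prefix [0:n_shared_layers] from the policy to the value model.
value_weights[:n_shared_layers] = policy_weights[:n_shared_layers]
assert value_weights == [{'w': 1}, {'w': 2}, {'w': 7}]
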
- self.play(_random_policy(self.action_space), max_steps=2) + self.play(_random_policy(self.action_space)) ] if isinstance(initial_trajectories, list): @@ -328,11 +328,18 @@ def action_space(self): else: return self._env.action_space - def observation_shape(self): + @property + def observation_space(self): + """Returns the env's observation space in a Gym interface.""" if self._dm_suite: - return self._env.observation_spec().shape + return gym.spaces.Box( + shape=self._env.observation_spec().shape, + dtype=self._env.observation_spec().dtype, + low=float('-inf'), + high=float('+inf'), + ) else: - return self._env.observation_space.shape + return self._env.observation_space @property def trajectories(self): @@ -480,6 +487,7 @@ def n_slices(t): def trajectory_batch_stream(self, batch_size, epochs=None, max_slice_length=None, + min_slice_length=None, include_final_state=False, sample_trajectories_uniformly=False): """Return a stream of trajectory batches from the specified epochs. @@ -491,6 +499,7 @@ def trajectory_batch_stream(self, batch_size, epochs=None, batch_size: the size of the batches to return epochs: a list of epochs to use; we use all epochs if None max_slice_length: maximum length of the slices of trajectories to return + min_slice_length: minimum length of the slices of trajectories to return include_final_state: whether to include slices with the final state of the trajectory which may have no action and reward sample_trajectories_uniformly: whether to sample trajectories uniformly, @@ -498,7 +507,8 @@ def trajectory_batch_stream(self, batch_size, epochs=None, Yields: batches of trajectory slices sampled uniformly from all slices of length - upto max_slice_length in all specified epochs + at least min_slice_length and up to max_slice_length in all specified + epochs """ def pad(tensor_list): max_len = max([t.shape[0] for t in tensor_list]) @@ -513,6 +523,11 @@ def pad(tensor_list): for t in self.trajectory_stream( epochs, max_slice_length, include_final_state, sample_trajectories_uniformly): + # TODO(pkozakowski): Instead sample the trajectories out of those with + # the minimum length. + if min_slice_length is not None and len(t) < min_slice_length: + continue + cur_batch.append(t) if len(cur_batch) == batch_size: obs, act, logp, rew, ret, _ = zip(*[t.to_np(self._timestep_to_np) @@ -520,7 +535,7 @@ def pad(tensor_list): # Where act, logp, rew and ret will usually have the following shape: # [batch_size, trajectory_length-1], which we call [B, L-1]. # Observations are more complex and will usuall be [B, L] + S where S - # is the shape of the observation space (self.observation_shape). + # is the shape of the observation space (self.observation_space.shape). yield TrajectoryNp( pad(obs), pad(act), pad(logp), pad(rew), pad(ret), pad([np.ones(a.shape[:1]) for a in act])) diff --git a/trax/shapes.py b/trax/shapes.py index d1f8c3825..269f3787c 100644 --- a/trax/shapes.py +++ b/trax/shapes.py @@ -71,6 +71,13 @@ def __len__(self): def as_tuple(self): return self.shape, self.dtype + def replace(self, **kwargs): + """Creates a copy of the object with some parameters replaced.""" + return type(self)( + shape=kwargs.pop('shape', self.shape), + dtype=kwargs.pop('dtype', self.dtype), + ) + def signature(obj): """Returns a `ShapeDtype` signature for the given `obj`. 
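The ShapeDtype.replace method added above is what lets _value_model_signature prepend singleton batch and time dimensions to the observation signature. A stand-alone sketch of the same pattern (a toy ShapeDtype, not the trax class; the (4,) observation shape is just an example):

import collections

import numpy as np

class ShapeDtype(collections.namedtuple('ShapeDtype', ['shape', 'dtype'])):

  def replace(self, **kwargs):
    # Copy with some fields overridden, mirroring the method added in shapes.py.
    return self._replace(**kwargs)

obs_sig = ShapeDtype(shape=(4,), dtype=np.float32)
batched_sig = obs_sig.replace(shape=(1, 1) + obs_sig.shape)  # add [batch, time] dims
assert batched_sig.shape == (1, 1, 4)
assert batched_sig.dtype == np.float32
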
diff --git a/trax/supervised/trainer_lib.py b/trax/supervised/trainer_lib.py
index 866beaa0b..aaaa55054 100644
--- a/trax/supervised/trainer_lib.py
+++ b/trax/supervised/trainer_lib.py
@@ -210,6 +210,23 @@ def model_weights(self, weights):
     new_weights = [new_model_weights] + list(self._opt_state.weights[1:])
     self._opt_state = self._opt_state._replace(weights=new_weights)
 
+  @property
+  def model_state(self):
+    # Currently we need to pick [0] as we ignore loss state (empty).
+    state = self._model_state[0]
+    if self.n_devices > 1:
+      unreplicate = lambda x: x[0]
+      state = math.nested_map(unreplicate, state)
+    return state
+
+  @model_state.setter
+  def model_state(self, state):
+    new_model_state = self._for_n_devices(state)
+    if isinstance(self._model_state, list):
+      self._model_state[0] = new_model_state
+    else:  # state is a tuple, need to re-create
+      self._model_state = [new_model_state] + list(self._model_state[1:])
+
   @property
   def state(self):
     return TrainerState(
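A closing note on the new model_state property above: with n_devices > 1 the state leaves carry a leading device axis, and mapping x -> x[0] over the nested structure recovers a single un-replicated copy. A self-contained sketch of that behavior (the nested_map below is a minimal stand-in for trax's math.nested_map, assuming nests of dicts, lists and tuples):

import numpy as np

def nested_map(fn, tree):
  # Apply fn to every leaf of a nest built from dicts, lists and tuples.
  if isinstance(tree, dict):
    return {k: nested_map(fn, v) for k, v in tree.items()}
  if isinstance(tree, (list, tuple)):
    return type(tree)(nested_map(fn, v) for v in tree)
  return fn(tree)

n_devices = 2
replicated_state = {'batch_norm': np.zeros((n_devices, 8))}  # leading device axis
single_copy = nested_map(lambda x: x[0], replicated_state)
assert single_copy['batch_norm'].shape == (8,)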