108 changes: 87 additions & 21 deletions trax/rl/actor_critic.py
@@ -22,6 +22,7 @@

from trax import layers as tl
from trax import lr_schedules as lr
from trax import shapes
from trax import supervised
from trax.math import numpy as jnp
from trax.rl import advantages as rl_advantages
@@ -49,6 +50,7 @@ def __init__(self, task,
n_shared_layers=0,
added_policy_slice_length=0,
n_replay_epochs=1,
scale_value_targets=False,
**kwargs): # Arguments of PolicyTrainer come here.
"""Configures the actor-critic Trainer.

@@ -73,6 +75,8 @@ def __init__(self, task,
have maximum length set by max_slice_length in **kwargs
n_replay_epochs: how many last epochs to take into the replay buffer;
only makes sense for off-policy algorithms
scale_value_targets: whether to scale targets for the value function by
1 / (1 - gamma)
**kwargs: arguments for PolicyTrainer super-class
"""
self._n_shared_layers = n_shared_layers
@@ -89,6 +93,14 @@ def __init__(self, task,
self._added_policy_slice_length = added_policy_slice_length
self._n_replay_epochs = n_replay_epochs

if scale_value_targets:
self._value_network_scale = 1 / (1 - self._task.gamma)
else:
self._value_network_scale = 1

self._value_eval_model = value_model(mode='eval')
self._value_eval_model.init(self._value_model_signature)

# Initialize training of the value function.
value_output_dir = kwargs.get('output_dir', None)
if value_output_dir is not None:
@@ -106,13 +118,18 @@ def __init__(self, task,
inputs=self._value_inputs,
output_dir=value_output_dir,
metrics={'value_loss': tl.L2Loss()})
self._value_eval_model = value_model(mode='eval')
value_batch = next(self.value_batches_stream())
self._value_eval_model.init(value_batch)

# Initialize policy training.
super(ActorCriticTrainer, self).__init__(task, **kwargs)

@property
def _value_model_signature(self):
obs_sig = shapes.signature(self._task.observation_space)
target_sig = mask_sig = shapes.ShapeDtype(
shape=(1, 1, 1),
)
return (obs_sig.replace(shape=(1, 1) + obs_sig.shape), target_sig, mask_sig)
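For concreteness, a rough sketch of what this signature evaluates to for a CartPole-like environment with a 4-dimensional Box observation space (the environment is an assumption for illustration, not part of the change):

import gym
import numpy as np
from trax import shapes

obs_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)
obs_sig = shapes.signature(obs_space)                     # ShapeDtype: shape (4,), dtype float32
obs_sig = obs_sig.replace(shape=(1, 1) + obs_sig.shape)   # (1, 1, 4): (batch, time) + obs shape
target_sig = mask_sig = shapes.ShapeDtype(shape=(1, 1, 1))
# The value model is then initialized with (obs_sig, target_sig, mask_sig).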

@property
def _replay_epochs(self):
if self.on_policy:
@@ -124,15 +141,39 @@ def _replay_epochs(self):

def value_batches_stream(self):
"""Use the RLTask self._task to create inputs to the value model."""
max_slice_length = self._max_slice_length + self._added_policy_slice_length
for np_trajectory in self._task.trajectory_batch_stream(
self._value_batch_size, max_slice_length=self._max_slice_length,
epochs=self._replay_epochs):
self._value_batch_size,
max_slice_length=max_slice_length,
min_slice_length=(1 + self._added_policy_slice_length),
epochs=self._replay_epochs,
):
values = self._value_eval_model(
np_trajectory.observations, n_accelerators=1
) * self._value_network_scale
values = np.squeeze(values, axis=2) # Remove the singleton depth dim.

# TODO(pkozakowski): Add some shape assertions and docs.
# Calculate targets based on the advantages over the target network - this
# allows TD learning for value networks.
advantages = self._advantage_estimator(
np_trajectory.rewards, np_trajectory.returns, values,
gamma=self._task.gamma,
n_extra_steps=self._added_policy_slice_length,
)
length = advantages.shape[1]
values = values[:, :length]
target_returns = values + advantages

# Insert an extra depth dimension, so the target shape is consistent with
# the network output shape.
yield (
np_trajectory.observations, # Inputs to the value model.
np_trajectory.returns[:, :, None], # Targets: regress to returns.
np_trajectory.mask[:, :, None], # Mask to zero-out padding.
# Inputs: observations.
np_trajectory.observations[:, :length],
# Targets: computed returns.
target_returns[:, :, None] / self._value_network_scale,
# Mask to zero-out padding.
np_trajectory.mask[:, :length, None],
)

def policy_inputs(self, trajectory, values):
@@ -159,9 +200,9 @@ def policy_batches_stream(self):
epochs=self._replay_epochs,
max_slice_length=max_slice_length,
include_final_state=False):
value_model = self._value_eval_model
value_model.weights = self._value_trainer.model_weights
values = value_model(np_trajectory.observations, n_accelerators=1)
values = self._value_eval_model(
np_trajectory.observations, n_accelerators=1
) * self._value_network_scale
values = np.squeeze(values, axis=2) # Remove the singleton depth dim.
if len(values.shape) != 2:
raise ValueError('Values are expected to have shape ' +
@@ -173,6 +214,19 @@

def train_epoch(self):
"""Trains RL for one epoch."""
# Copy policy state accumulated during data collection to the trainer.
self._policy_trainer.model_state = self._policy_collect_model.state

# Copy policy weights and state to value trainer.
if self._n_shared_layers > 0:
_copy_model_weights_and_state(
0, self._n_shared_layers, self._policy_trainer, self._value_trainer
)

# Update the target value network.
self._value_eval_model.weights = self._value_trainer.model_weights
self._value_eval_model.state = self._value_trainer.model_state

n_value_evals = rl_training.remaining_evals(
self._value_trainer.step,
self._epoch,
@@ -183,9 +237,11 @@
self._value_train_steps_per_epoch // self._value_evals_per_epoch,
self._value_eval_steps,
)
if self._n_shared_layers > 0: # Copy value weights to policy trainer.
_copy_model_weights(0, self._n_shared_layers,
self._value_trainer, self._policy_trainer)
# Copy value weights and state to policy trainer.
if self._n_shared_layers > 0:
_copy_model_weights_and_state(
0, self._n_shared_layers, self._value_trainer, self._policy_trainer
)
n_policy_evals = rl_training.remaining_evals(
self._policy_trainer.step,
self._epoch,
@@ -196,31 +252,41 @@
n_policy_evals < self._policy_evals_per_epoch)
should_copy_weights = self._n_shared_layers > 0 and not stopped_after_value
if should_copy_weights:
_copy_model_weights(0, self._n_shared_layers,
self._value_trainer, self._policy_trainer)
_copy_model_weights_and_state(
0, self._n_shared_layers, self._value_trainer, self._policy_trainer
)

# Update the target value network.
self._value_eval_model.weights = self._value_trainer.model_weights
self._value_eval_model.state = self._value_trainer.model_state

for _ in range(n_policy_evals):
self._policy_trainer.train_epoch(
self._policy_train_steps_per_epoch // self._policy_evals_per_epoch,
self._policy_eval_steps,
)
if self._n_shared_layers > 0: # Copy policy weights to value trainer.
_copy_model_weights(0, self._n_shared_layers,
self._policy_trainer, self._value_trainer)

def close(self):
self._value_trainer.close()
super().close()


def _copy_model_weights(start, end, from_trainer, to_trainer, # pylint: disable=invalid-name
copy_optimizer_slots=True):
def _copy_model_weights_and_state( # pylint: disable=invalid-name
start, end, from_trainer, to_trainer, copy_optimizer_slots=False
):
"""Copy model weights[start:end] from from_trainer to to_trainer."""
from_weights = from_trainer.model_weights
to_weights = to_trainer.model_weights
shared_weights = from_weights[start:end]
to_weights[start:end] = shared_weights
to_trainer.model_weights = to_weights

from_state = from_trainer.model_state
to_state = to_trainer.model_state
shared_state = from_state[start:end]
to_state[start:end] = shared_state
to_trainer.model_state = to_state

if copy_optimizer_slots:
# TODO(lukaszkaiser): make a nicer API in Trainer to support this.
# Currently we use the hack below. Note [0] since that's the model w/o loss.
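A note on the value-target changes above: with scale_value_targets=True the value head regresses returns divided by 1 / (1 - gamma), and its predictions are multiplied back by the same factor before advantages are computed. A minimal numerical sketch of the flow in value_batches_stream (the array values are illustrative, not taken from the code):

import numpy as np

gamma = 0.99
scale = 1.0 / (1.0 - gamma)   # 100.0; roughly bounds the discounted return when rewards are O(1)

v_hat = np.array([[0.2, 0.3, 0.1]])   # raw value-network output, shape [B, L]
values = v_hat * scale                # back in return units, as in the stream
adv = np.array([[1.5, -0.5]])         # advantage estimates, shape [B, L - n_extra_steps]

length = adv.shape[1]
target_returns = values[:, :length] + adv                # TD-style targets in return units
regression_target = target_returns[:, :, None] / scale   # what the L2 value loss actually sees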
11 changes: 9 additions & 2 deletions trax/rl/actor_critic_test.py
@@ -26,6 +26,7 @@
from trax import optimizers as opt
from trax import test_utils
from trax.rl import actor_critic
from trax.rl import advantages
from trax.rl import task as rl_task


@@ -161,7 +162,10 @@ def test_awrtrainer_cartpole(self):
policy_lr_schedule=lr,
policy_batch_size=32,
policy_train_steps_per_epoch=200,
collect_per_epoch=10)
collect_per_epoch=10,
advantage_estimator=advantages.monte_carlo,
advantage_normalization=False,
)
trainer.run(1)
self.assertEqual(1, trainer.current_epoch)
self.assertGreater(trainer.avg_returns[-1], 50.0)
@@ -189,7 +193,10 @@ def test_awrtrainer_cartpole_shared(self):
policy_lr_schedule=lr,
policy_batch_size=32,
policy_train_steps_per_epoch=200,
collect_per_epoch=10)
collect_per_epoch=10,
advantage_estimator=advantages.monte_carlo,
advantage_normalization=False,
)
trainer.run(1)
self.assertEqual(1, trainer.current_epoch)
self.assertGreater(trainer.avg_returns[-1], 50.0)
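The tests now pin advantage_estimator to advantages.monte_carlo and turn off normalization. Conceptually, a Monte Carlo advantage is just the empirical return minus the value baseline; a simplified sketch (not the library's exact signature):

import numpy as np

def monte_carlo_advantage(returns, values):
  # A_t = R_t - V(s_t): empirical discounted return minus the predicted value.
  return returns - values

returns = np.array([[10.0, 9.0, 8.0]])
values = np.array([[9.5, 9.2, 7.5]])
print(monte_carlo_advantage(returns, values))   # approx. [[ 0.5 -0.2  0.5]]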
29 changes: 22 additions & 7 deletions trax/rl/task.py
@@ -52,7 +52,7 @@ def __init__(self, observation, action=None, reward=None, log_prob=None):
'log_probs',
'rewards',
'returns',
'mask'
'mask',
])


@@ -289,7 +289,7 @@ def __init__(self, env=gin.REQUIRED, initial_trajectories=1, gamma=0.99,
initial_trajectories = [
# Whatever we gather here is intended to be removed
# in PolicyTrainer. Here we just gather some example inputs.
self.play(_random_policy(self.action_space), max_steps=2)
self.play(_random_policy(self.action_space))
]

if isinstance(initial_trajectories, list):
@@ -328,11 +328,18 @@ def action_space(self):
else:
return self._env.action_space

def observation_shape(self):
@property
def observation_space(self):
"""Returns the env's observation space in a Gym interface."""
if self._dm_suite:
return self._env.observation_spec().shape
return gym.spaces.Box(
shape=self._env.observation_spec().shape,
dtype=self._env.observation_spec().dtype,
low=float('-inf'),
high=float('+inf'),
)
else:
return self._env.observation_space.shape
return self._env.observation_space

@property
def trajectories(self):
@@ -480,6 +487,7 @@ def n_slices(t):

def trajectory_batch_stream(self, batch_size, epochs=None,
max_slice_length=None,
min_slice_length=None,
include_final_state=False,
sample_trajectories_uniformly=False):
"""Return a stream of trajectory batches from the specified epochs.
@@ -491,14 +499,16 @@ def trajectory_batch_stream(self, batch_size, epochs=None,
batch_size: the size of the batches to return
epochs: a list of epochs to use; we use all epochs if None
max_slice_length: maximum length of the slices of trajectories to return
min_slice_length: minimum length of the slices of trajectories to return
include_final_state: whether to include slices with the final state of
the trajectory which may have no action and reward
sample_trajectories_uniformly: whether to sample trajectories uniformly,
or proportionally to the number of slices in each trajectory (default)

Yields:
batches of trajectory slices sampled uniformly from all slices of length
upto max_slice_length in all specified epochs
at least min_slice_length and up to max_slice_length in all specified
epochs
"""
def pad(tensor_list):
max_len = max([t.shape[0] for t in tensor_list])
@@ -513,14 +523,19 @@ def pad(tensor_list):
for t in self.trajectory_stream(
epochs, max_slice_length,
include_final_state, sample_trajectories_uniformly):
# TODO(pkozakowski): Instead sample the trajectories out of those with
# the minimum length.
if min_slice_length is not None and len(t) < min_slice_length:
continue

cur_batch.append(t)
if len(cur_batch) == batch_size:
obs, act, logp, rew, ret, _ = zip(*[t.to_np(self._timestep_to_np)
for t in cur_batch])
# Where act, logp, rew and ret will usually have the following shape:
# [batch_size, trajectory_length-1], which we call [B, L-1].
# Observations are more complex and will usually be [B, L] + S where S
# is the shape of the observation space (self.observation_shape).
# is the shape of the observation space (self.observation_space.shape).
yield TrajectoryNp(
pad(obs), pad(act), pad(logp), pad(rew), pad(ret),
pad([np.ones(a.shape[:1]) for a in act]))
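A usage sketch of the new min_slice_length argument (the environment and numbers are illustrative, following the test setup):

from trax.rl import task as rl_task

task = rl_task.RLTask('CartPole-v0', initial_trajectories=10, max_steps=200)
stream = task.trajectory_batch_stream(
    batch_size=32,
    max_slice_length=4,
    min_slice_length=2,   # slices shorter than 2 timesteps are skipped
)
batch = next(stream)      # TrajectoryNp of padded [B, L] arrays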
7 changes: 7 additions & 0 deletions trax/shapes.py
@@ -71,6 +71,13 @@ def __len__(self):
def as_tuple(self):
return self.shape, self.dtype

def replace(self, **kwargs):
"""Creates a copy of the object with some parameters replaced."""
return type(self)(
shape=kwargs.pop('shape', self.shape),
dtype=kwargs.pop('dtype', self.dtype),
)


def signature(obj):
"""Returns a `ShapeDtype` signature for the given `obj`.
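A quick usage sketch of the new replace method:

import numpy as np
from trax import shapes

sd = shapes.ShapeDtype((4,), np.float32)
sd2 = sd.replace(shape=(1, 1) + sd.shape)   # shape (1, 1, 4), dtype unchanged
sd3 = sd.replace(dtype=np.int32)            # shape unchanged, dtype replaced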
17 changes: 17 additions & 0 deletions trax/supervised/trainer_lib.py
@@ -210,6 +210,23 @@ def model_weights(self, weights):
new_weights = [new_model_weights] + list(self._opt_state.weights[1:])
self._opt_state = self._opt_state._replace(weights=new_weights)

@property
def model_state(self):
# Currently we need to pick [0] as we ignore loss state (empty).
state = self._model_state[0]
if self.n_devices > 1:
unreplicate = lambda x: x[0]
state = math.nested_map(unreplicate, state)
return state

@model_state.setter
def model_state(self, state):
new_model_state = self._for_n_devices(state)
if isinstance(self._model_state, list):
self._model_state[0] = new_model_state
else:  # state is a tuple, need to re-create
self._model_state = [new_model_state] + list(self._model_state[1:])

@property
def state(self):
return TrainerState(
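A minimal sketch of the replicate/unreplicate convention these accessors rely on, using plain numpy stand-ins for _for_n_devices and math.nested_map (shapes and names are illustrative):

import numpy as np

n_devices = 2
state = {'counter': np.zeros((8,))}   # unreplicated per-model state

# Roughly what _for_n_devices does: add a leading device axis.
replicated = {k: np.stack([v] * n_devices) for k, v in state.items()}   # (2, 8)

# Roughly what the getter does to undo it: pick device 0.
unreplicate = lambda x: x[0]
restored = {k: unreplicate(v) for k, v in replicated.items()}           # (8,)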