
Commit

Merge f7e3800 into dc9e401
trax-robot committed Jun 28, 2020
2 parents dc9e401 + f7e3800 commit 037740a
Showing 63 changed files with 404 additions and 579 deletions.
2 changes: 1 addition & 1 deletion trax/intro.ipynb
@@ -226,7 +226,7 @@
" model=tiny_transformer_lm,\n",
" loss_fn=trax.layers.CrossEntropyLoss(),\n",
" optimizer=trax.optimizers.Adafactor, # Change optimizer params here.\n",
" lr_schedule=trax.lr.MultifactorSchedule, # Change lr schedule here.\n",
" lr_schedule=trax.lr.constant(0.001), # Change lr schedule here.\n",
" inputs=copy_inputs,\n",
" output_dir=output_dir)\n",
"\n",
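Context for the notebook change above: learning-rate schedules are now plain functions that return a schedule, rather than classes built from a training history. A minimal sketch of the new-style calls, assuming only the trax.lr names already used in this diff (the printed steps are illustrative, not from the notebook):

# Sketch: new function-based schedule API, assumed available as trax.lr.
import trax

constant_schedule = trax.lr.constant(0.001)      # fixed learning rate
warmup_schedule = trax.lr.multifactor(           # warmup, then constant
    constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')

# A schedule is now simply a function from training step to learning rate.
print(constant_schedule(1), warmup_schedule(1), warmup_schedule(200))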
4 changes: 2 additions & 2 deletions trax/rl/actor_critic.py
@@ -46,7 +46,7 @@ class ActorCriticTrainer(rl_training.PolicyTrainer):
def __init__(self, task,
value_model=None,
value_optimizer=None,
value_lr_schedule=lr.MultifactorSchedule,
value_lr_schedule=lr.multifactor,
value_batch_size=64,
value_train_steps_per_epoch=500,
value_evals_per_epoch=1,
@@ -151,7 +151,7 @@ def __init__(self, task,
self._value_trainer = supervised.Trainer(
model=value_model,
optimizer=value_optimizer,
lr_schedule=value_lr_schedule,
lr_schedule=value_lr_schedule(),
loss_fn=tl.L2Loss(),
inputs=self._value_inputs,
output_dir=value_output_dir,
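The two hunks above make one change in two places: the constructor default becomes the schedule factory lr.multifactor (the function itself, not an instance), and the supervised Trainer is handed the result of calling that factory, value_lr_schedule(). A minimal sketch of the factory pattern; the functools.partial wrapper is hypothetical, shown only to illustrate how a caller can pre-bind arguments:

# Sketch: schedule factories are zero-argument callables; the trainer calls them.
import functools
import trax

value_lr_schedule = functools.partial(
    trax.lr.multifactor, constant=1e-4, factors='constant')

schedule = value_lr_schedule()   # what the trainer does internally (second hunk above)
# `schedule` maps a training step to a learning rate.
print(schedule(10))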
4 changes: 2 additions & 2 deletions trax/rl/actor_critic_joint.py
@@ -36,7 +36,7 @@ class ActorCriticJointTrainer(rl_training.RLTrainer):
def __init__(self, task,
joint_model=None,
optimizer=None,
lr_schedule=lr.MultifactorSchedule,
lr_schedule=lr.multifactor,
batch_size=64,
train_steps_per_epoch=500,
supervised_evals_per_epoch=1,
@@ -79,7 +79,7 @@ def __init__(self, task,
self._n_trajectories_per_epoch = n_trajectories_per_epoch
self._max_slice_length = max_slice_length
self._policy_dist = distributions.create_distribution(task.action_space)
self._lr_schedule = lr_schedule
self._lr_schedule = lr_schedule()
self._optimizer = optimizer
self._normalize_advantages = normalize_advantages
self._n_replay_epochs = n_replay_epochs
12 changes: 6 additions & 6 deletions trax/rl/actor_critic_joint_test.py
@@ -81,8 +81,8 @@ def test_jointppotrainer_cartpole(self):
models.PolicyAndValue,
body=lambda mode: tl.Serial(tl.Dense(2), tl.Relu()),
)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')

trainer = actor_critic_joint.PPOJointTrainer(
task,
@@ -103,8 +103,8 @@ def test_jointawrtrainer_cartpole(self):
models.PolicyAndValue,
body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()),
)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
trainer = actor_critic_joint.AWRJointTrainer(
task,
joint_model=joint_model,
@@ -124,8 +124,8 @@ def test_jointa2ctrainer_cartpole(self):
models.PolicyAndValue,
body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()),
)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
trainer = actor_critic_joint.A2CJointTrainer(
task,
joint_model=joint_model,
32 changes: 16 additions & 16 deletions trax/rl/actor_critic_test.py
@@ -90,8 +90,8 @@ def test_sanity_a2ctrainer_cartpole(self):
body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
policy_model = functools.partial(models.Policy, body=body)
value_model = functools.partial(models.Value, body=body)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
trainer = actor_critic.A2CTrainer(
task,
n_shared_layers=1,
@@ -114,8 +114,8 @@ def test_sanity_ppo_cartpole(self):
task = rl_task.RLTask(
'CartPole-v1', initial_trajectories=0, max_steps=200)

lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-3,
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-3,
warmup_steps=100,
factors='constant * linear_warmup')

@@ -147,8 +147,8 @@ def test_awrtrainer_cartpole(self):
body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
policy_model = functools.partial(models.Policy, body=body)
value_model = functools.partial(models.Value, body=body)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
trainer = actor_critic.AWRTrainer(
task,
n_shared_layers=0,
@@ -181,8 +181,8 @@ def test_awrtrainer_cartpole_shared(self):
value_model = functools.partial(models.Value, body=body)
# pylint: disable=g-long-lambda
lr = (
lambda h: lr_schedules.MultifactorSchedule(
h, constant=1e-2, warmup_steps=100,
lambda: lr_schedules.multifactor(
constant=1e-2, warmup_steps=100,
factors='constant * linear_warmup')
)
# pylint: enable=g-long-lambda
@@ -223,8 +223,8 @@ def test_sanity_awrtrainer_transformer_cartpole(self):
d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode)
policy_model = functools.partial(models.Policy, body=body)
value_model = functools.partial(models.Value, body=body)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
trainer = actor_critic.AWRTrainer(
task,
n_shared_layers=0,
@@ -252,8 +252,8 @@ def test_sampling_awrtrainer_cartpole(self):
body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
policy_model = functools.partial(models.Policy, body=body)
value_model = functools.partial(models.Value, body=body)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
trainer = actor_critic.SamplingAWRTrainer(
task,
n_shared_layers=0,
@@ -285,8 +285,8 @@ def test_sampling_awrtrainer_cartpole_sample_all_discrete(self):
body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
policy_model = functools.partial(models.Policy, body=body)
value_model = functools.partial(models.Value, body=body)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
trainer = actor_critic.SamplingAWRTrainer(
task,
n_shared_layers=0,
@@ -318,8 +318,8 @@ def test_sampling_awrtrainer_mountain_acr(self):
body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
policy_model = functools.partial(models.Policy, body=body)
value_model = functools.partial(models.Value, body=body)
lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda
h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda
constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
trainer = actor_critic.SamplingAWRTrainer(
task,
n_shared_layers=0,
18 changes: 9 additions & 9 deletions trax/rl/configs/light_atari.gin
@@ -25,12 +25,12 @@ Policy.body = @trax.models.AtariCnnBody
# ==============================================================================
Value.body = @trax.models.AtariCnnBody

# Parameters for MultifactorSchedule:
# Parameters for multifactor:
# ==============================================================================
value/MultifactorSchedule.constant = 0.0001
value/MultifactorSchedule.factors = 'constant'
policy/MultifactorSchedule.constant = 0.0001
policy/MultifactorSchedule.factors = 'constant'
value/multifactor.constant = 0.0001
value/multifactor.factors = 'constant'
policy/multifactor.constant = 0.0001
policy/multifactor.factors = 'constant'

# Parameters for RLTask:
# ==============================================================================
@@ -50,14 +50,14 @@ AWRTrainer.value_model = @trax.models.Value
AWRTrainer.value_optimizer = @trax.optimizers.Adam
AWRTrainer.value_batch_size = 32
AWRTrainer.value_train_steps_per_epoch = 1000
AWRTrainer.value_lr_schedule = @value/MultifactorSchedule
AWRTrainer.value_lr_schedule = @value/multifactor
AWRTrainer.value_evals_per_epoch = 10
AWRTrainer.value_eval_steps = 10
AWRTrainer.policy_model = @trax.models.Policy
AWRTrainer.policy_optimizer = @trax.optimizers.Adam
AWRTrainer.policy_batch_size = 32
AWRTrainer.policy_train_steps_per_epoch = 1000
AWRTrainer.policy_lr_schedule = @policy/MultifactorSchedule
AWRTrainer.policy_lr_schedule = @policy/multifactor
AWRTrainer.policy_evals_per_epoch = 10
AWRTrainer.policy_eval_steps = 10
AWRTrainer.n_trajectories_per_epoch = 10
@@ -75,12 +75,12 @@ PPOTrainer.value_batch_size = 32
PPOTrainer.value_train_steps_per_epoch = 10
PPOTrainer.value_evals_per_epoch = 1
PPOTrainer.value_eval_steps = 1
PPOTrainer.value_lr_schedule = @value/MultifactorSchedule
PPOTrainer.value_lr_schedule = @value/multifactor
PPOTrainer.policy_model = @trax.models.Policy
PPOTrainer.policy_optimizer = @trax.optimizers.Adam
PPOTrainer.policy_batch_size = 32
PPOTrainer.policy_train_steps_per_epoch = 10
PPOTrainer.policy_lr_schedule = @policy/MultifactorSchedule
PPOTrainer.policy_lr_schedule = @policy/multifactor
PPOTrainer.policy_evals_per_epoch = 1
PPOTrainer.policy_eval_steps = 1
PPOTrainer.advantage_estimator = @trax.rl.advantages.td_lambda
18 changes: 9 additions & 9 deletions trax/rl/configs/light_awr_boxing.gin
@@ -49,12 +49,12 @@ q_value_n_samples = 18
Adam.weight_decay_rate = 0.0
Adam.clip_grad_norm = 50.0

# Parameters for MultifactorSchedule:
# Parameters for multifactor:
# ==============================================================================
policy/MultifactorSchedule.constant = %policy_lr
policy/MultifactorSchedule.factors = 'constant'
value/MultifactorSchedule.constant = %value_lr
value/MultifactorSchedule.factors = 'constant'
policy/multifactor.constant = %policy_lr
policy/multifactor.factors = 'constant'
value/multifactor.constant = %value_lr
value/multifactor.factors = 'constant'

# Parameters for Momentum:
# ==============================================================================
@@ -80,14 +80,14 @@ AWRTrainer.value_model = @trax.models.Value
AWRTrainer.value_optimizer = @trax.optimizers.Adam
AWRTrainer.value_batch_size = %batch_size
AWRTrainer.value_train_steps_per_epoch = %value_train_steps
AWRTrainer.value_lr_schedule = @value/MultifactorSchedule
AWRTrainer.value_lr_schedule = @value/multifactor
AWRTrainer.value_evals_per_epoch = 1
AWRTrainer.value_eval_steps = 10
AWRTrainer.policy_model = @trax.models.Policy
AWRTrainer.policy_optimizer = @trax.optimizers.Adam
AWRTrainer.policy_batch_size = %batch_size
AWRTrainer.policy_train_steps_per_epoch = %policy_train_steps
AWRTrainer.policy_lr_schedule = @policy/MultifactorSchedule
AWRTrainer.policy_lr_schedule = @policy/multifactor
AWRTrainer.policy_evals_per_epoch = 1
AWRTrainer.policy_eval_steps = 10
AWRTrainer.n_trajectories_per_epoch = None
@@ -112,14 +112,14 @@ SamplingAWRTrainer.value_model = @trax.models.Value
SamplingAWRTrainer.value_optimizer = @trax.optimizers.Adam
SamplingAWRTrainer.value_batch_size = %batch_size
SamplingAWRTrainer.value_train_steps_per_epoch = %value_train_steps
SamplingAWRTrainer.value_lr_schedule = @value/MultifactorSchedule
SamplingAWRTrainer.value_lr_schedule = @value/multifactor
SamplingAWRTrainer.value_evals_per_epoch = 2
SamplingAWRTrainer.value_eval_steps = 1
SamplingAWRTrainer.policy_model = @trax.models.Policy
SamplingAWRTrainer.policy_optimizer = @trax.optimizers.Adam
SamplingAWRTrainer.policy_batch_size = %batch_size
SamplingAWRTrainer.policy_train_steps_per_epoch = %policy_train_steps
SamplingAWRTrainer.policy_lr_schedule = @policy/MultifactorSchedule
SamplingAWRTrainer.policy_lr_schedule = @policy/multifactor
SamplingAWRTrainer.policy_evals_per_epoch = 2
SamplingAWRTrainer.policy_eval_steps = 1
SamplingAWRTrainer.n_trajectories_per_epoch = None
8 changes: 4 additions & 4 deletions trax/rl/configs/light_awr_joint_atari.gin
@@ -21,10 +21,10 @@ import trax.rl
# ==============================================================================
PolicyAndValue.body = @trax.models.AtariJointCnnBody

# Parameters for MultifactorSchedule:
# Parameters for multifactor:
# ==============================================================================
MultifactorSchedule.constant = 0.01
MultifactorSchedule.factors = 'constant'
multifactor.constant = 0.01
multifactor.factors = 'constant'

# Parameters for RLTask:
# ==============================================================================
@@ -40,7 +40,7 @@ AWRJointTrainer.joint_model = @trax.models.PolicyAndValue
AWRJointTrainer.optimizer = @trax.optimizers.Adam
AWRJointTrainer.batch_size = 32
AWRJointTrainer.train_steps_per_epoch = 1000
AWRJointTrainer.lr_schedule = @MultifactorSchedule
AWRJointTrainer.lr_schedule = @multifactor
AWRJointTrainer.n_trajectories_per_epoch = 10
AWRJointTrainer.beta = 1.0
AWRJointTrainer.w_max = 20
8 changes: 4 additions & 4 deletions trax/rl/configs/light_awr_joint_cartpole.gin
@@ -27,10 +27,10 @@ PureMLP.flatten = False
PureMLP.layer_widths = (128,)
PureMLP.out_activation = True

# Parameters for MultifactorSchedule:
# Parameters for multifactor:
# ==============================================================================
MultifactorSchedule.constant = 0.01
MultifactorSchedule.factors = 'constant'
multifactor.constant = 0.01
multifactor.factors = 'constant'

# Parameters for RLTask:
# ==============================================================================
@@ -45,7 +45,7 @@ AWRJointTrainer.joint_model = @trax.models.PolicyAndValue
AWRJointTrainer.optimizer = @trax.optimizers.Adam
AWRJointTrainer.batch_size = 32
AWRJointTrainer.train_steps_per_epoch = 1000
AWRJointTrainer.lr_schedule = @MultifactorSchedule
AWRJointTrainer.lr_schedule = @multifactor
AWRJointTrainer.n_trajectories_per_epoch = 10
AWRJointTrainer.beta = 1.0
AWRJointTrainer.w_max = 20
18 changes: 9 additions & 9 deletions trax/rl/configs/light_cartpole.gin
@@ -41,12 +41,12 @@ PureMLP.flatten = False
PureMLP.layer_widths = (64,)
PureMLP.out_activation = True

# Parameters for MultifactorSchedule:
# Parameters for multifactor:
# ==============================================================================
policy/MultifactorSchedule.constant = 0.0001
policy/MultifactorSchedule.factors = 'constant'
value/MultifactorSchedule.constant = 0.001
value/MultifactorSchedule.factors = 'constant'
policy/multifactor.constant = 0.0001
policy/multifactor.factors = 'constant'
value/multifactor.constant = 0.001
value/multifactor.factors = 'constant'

# Parameters for RLTask:
# ==============================================================================
@@ -76,14 +76,14 @@ PPOTrainer.value_model = @trax.models.Value
PPOTrainer.value_optimizer = @trax.optimizers.Adam
PPOTrainer.value_batch_size = 32
PPOTrainer.value_train_steps_per_epoch = 10
PPOTrainer.value_lr_schedule = @value/MultifactorSchedule
PPOTrainer.value_lr_schedule = @value/multifactor
PPOTrainer.value_evals_per_epoch = 1
PPOTrainer.value_eval_steps = 1
PPOTrainer.policy_model = @trax.models.Policy
PPOTrainer.policy_optimizer = @trax.optimizers.Adam
PPOTrainer.policy_batch_size = 32
PPOTrainer.policy_train_steps_per_epoch = 10
PPOTrainer.policy_lr_schedule = @policy/MultifactorSchedule
PPOTrainer.policy_lr_schedule = @policy/multifactor
PPOTrainer.policy_evals_per_epoch = 1
PPOTrainer.policy_eval_steps = 1
PPOTrainer.advantage_estimator = @trax.rl.advantages.td_lambda
@@ -102,14 +102,14 @@ AWRTrainer.value_model = @trax.models.Value
AWRTrainer.value_optimizer = @trax.optimizers.Momentum
AWRTrainer.value_batch_size = 256
AWRTrainer.value_train_steps_per_epoch = 40
AWRTrainer.value_lr_schedule = @value/MultifactorSchedule
AWRTrainer.value_lr_schedule = @value/multifactor
AWRTrainer.value_evals_per_epoch = 1
AWRTrainer.value_eval_steps = 10
AWRTrainer.policy_model = @trax.models.Policy
AWRTrainer.policy_optimizer = @trax.optimizers.Momentum
AWRTrainer.policy_batch_size = 256
AWRTrainer.policy_train_steps_per_epoch = 1080
AWRTrainer.policy_lr_schedule = @policy/MultifactorSchedule
AWRTrainer.policy_lr_schedule = @policy/multifactor
AWRTrainer.policy_evals_per_epoch = 1
AWRTrainer.policy_eval_steps = 10
AWRTrainer.n_trajectories_per_epoch = 10
10 changes: 5 additions & 5 deletions trax/rl/configs/light_cartpole_transformer.gin
@@ -34,10 +34,10 @@ TransformerDecoder.d_ff = %depth
TransformerDecoder.n_layers = 2
TransformerDecoder.n_heads = 2

# Parameters for MultifactorSchedule:
# Parameters for multifactor:
# ==============================================================================
MultifactorSchedule.constant = 0.001
MultifactorSchedule.factors = 'constant'
multifactor.constant = 0.001
multifactor.factors = 'constant'

# Parameters for RLTask:
# ==============================================================================
@@ -52,14 +52,14 @@ AWRTrainer.value_model = @trax.models.Value
AWRTrainer.value_optimizer = @trax.optimizers.Adam
AWRTrainer.value_batch_size = 32
AWRTrainer.value_train_steps_per_epoch = 200
AWRTrainer.value_lr_schedule = @value/MultifactorSchedule
AWRTrainer.value_lr_schedule = @value/multifactor
AWRTrainer.value_evals_per_epoch = 2
AWRTrainer.value_eval_steps = 1
AWRTrainer.policy_model = @trax.models.Policy
AWRTrainer.policy_optimizer = @trax.optimizers.Adam
AWRTrainer.policy_batch_size = 32
AWRTrainer.policy_train_steps_per_epoch = 500
AWRTrainer.policy_lr_schedule = @policy/MultifactorSchedule
AWRTrainer.policy_lr_schedule = @policy/multifactor
AWRTrainer.policy_evals_per_epoch = 2
AWRTrainer.policy_eval_steps = 1
AWRTrainer.n_trajectories_per_epoch = 200
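How the gin bindings above reach the trainers: each @multifactor (or scoped @value/multifactor, @policy/multifactor) reference hands the trainer the configurable function itself, and the trainer's lr_schedule() call then picks up the bound constant/factors. A rough sketch of that flow, using the unscoped bindings from light_awr_joint_atari.gin; the exact binding mechanics here are an assumption, not taken from this commit:

# Sketch: gin supplies multifactor's arguments; the zero-arg call builds the schedule.
import gin
import trax

gin.parse_config("""
multifactor.constant = 0.01
multifactor.factors = 'constant'
""")

schedule = trax.lr.multifactor()   # arguments come from the gin bindings above
print(schedule(1))                 # constant-only factors -> 0.01 at every step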
