Training with pong and DQN. The same parameters also work for boxing.
PiperOrigin-RevId: 332792528
henrykmichalewski authored and Copybara-Service committed Sep 21, 2020
1 parent 84fe848 commit 903a70c
Showing 1 changed file: trax/rl/configs/light_atari.gin (55 additions, 5 deletions)
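
For context, a minimal sketch of driving this config from Python. It assumes gin's standard file-parsing API; the trax class locations (trax.rl.task.RLTask, a DQN agent in trax.rl.training) and the constructor/run() signatures are assumptions, not verified against this commit.

import gin
from trax.rl.task import RLTask
from trax.rl import training  # assumed home of the DQN agent class

gin.parse_config_file('trax/rl/configs/light_atari.gin')

task = RLTask()  # env="pong", gamma=0.99, max_steps=2000 supplied by gin
agent = training.DQN(task, output_dir='/tmp/dqn_pong')  # hypothetical kwargs
agent.run(n_epochs=50000)  # matches train_rl.n_epochs at the end of the file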
@@ -17,6 +17,20 @@ import trax.models
import trax.optimizers
import trax.rl

# Parameters for RMSProp:
# ==============================================================================
RMSProp.clip_grad_norm = 0.5

# Parameters for Adam:
# ==============================================================================
Adam.clip_grad_norm = 0.5

# Parameters for the AtariCnnBody:
# ==============================================================================
AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit
AtariCnnBody.n_frames = 1
AtariCnnBody.padding = 'VALID'

# Parameters for Policy:
# ==============================================================================
Policy.body = @trax.models.AtariCnnBody
@@ -25,25 +39,61 @@ Policy.body = @trax.models.AtariCnnBody
# ==============================================================================
Value.body = @trax.models.AtariCnnBody

# Parameters for Quality:
# ==============================================================================
Quality.body = @trax.models.AtariCnnBody

# Scalar constants:
# ==============================================================================
value_lr = 0.00025
initial_exploration_rate = 1.
exploration_decay_factor = 0.998
minimum_exploration = 0.1
steps_per_decay = 1


# Parameters for multifactor:
# ==============================================================================
value/multifactor.constant = 0.0001
value/multifactor.factors = 'constant'
policy/multifactor.constant = 0.0001
policy/multifactor.factors = 'constant'
exploration_rate/multifactor.constant = %initial_exploration_rate
exploration_rate/multifactor.decay_factor = %exploration_decay_factor
exploration_rate/multifactor.minimum = %minimum_exploration
exploration_rate/multifactor.steps_per_decay = %steps_per_decay
exploration_rate/multifactor.factors = 'constant * decay_every'
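
Spelled out, the exploration_rate/multifactor bindings above define a stepwise epsilon-decay schedule. A plain-Python sketch of the intended curve (the real implementation is trax's multifactor schedule; flooring at the minimum is how the binding is read here, an assumption):

def exploration_rate(step,
                     constant=1.0,         # %initial_exploration_rate
                     decay_factor=0.998,   # %exploration_decay_factor
                     minimum=0.1,          # %minimum_exploration
                     steps_per_decay=1):   # %steps_per_decay
    # factors = 'constant * decay_every': start at `constant` and multiply by
    # `decay_factor` once every `steps_per_decay` steps, never dropping below
    # `minimum`.
    return max(constant * decay_factor ** (step // steps_per_decay), minimum)

With these values the rate hits the 0.1 floor after roughly 1150 steps (0.998^1150 ≈ 0.1).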

# Parameters for RLTask:
# ==============================================================================
RLTask.env = "freeway"
RLTask.initial_trajectories = 100
RLTask.gamma = 0.999
RLTask.max_steps = 200
RLTask.env = "pong"
RLTask.initial_trajectories = 10
RLTask.gamma = 0.99
RLTask.max_steps = 2000
RLTask.dm_suite = True
RLTask.num_stacked_frames = 4
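
Per the commit message, the same hyperparameters also work for boxing. A sketch of overriding just the environment via gin's string-binding API (importing trax.rl is assumed to register the configurables):

import gin
import trax.rl  # assumed to register RLTask and friends with gin

gin.parse_config_file('trax/rl/configs/light_atari.gin')
gin.parse_config('RLTask.env = "boxing"')  # everything else stays unchanged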

# Parameters for td_lambda:
# ==============================================================================
td_lambda.lambda_ = 0.95
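
For reference, lambda_ parameterizes the TD(lambda) advantage estimator. A generic numpy sketch of the textbook recursion (not trax's internal code):

import numpy as np

def td_lambda_advantages(rewards, values, gamma=0.99, lambda_=0.95):
    # rewards: 1-D array r_0..r_{T-1}; values: 1-D array V(s_0)..V(s_T),
    # whose last entry bootstraps the tail of the trajectory.
    deltas = rewards + gamma * values[1:] - values[:-1]  # one-step TD errors
    advantages = np.zeros_like(deltas)
    acc = 0.0
    for t in reversed(range(len(deltas))):
        acc = deltas[t] + gamma * lambda_ * acc  # geometric sum of TD errors
        advantages[t] = acc
    return advantages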

# Parameters for DQN:
# ==============================================================================
DQN.value_optimizer = @trax.optimizers.Adam
DQN.value_body = @trax.models.AtariCnnBody
DQN.value_batch_size = 32
DQN.value_train_steps_per_epoch = 500
DQN.value_evals_per_epoch = 1
DQN.value_eval_steps = 1
DQN.exploration_rate = @exploration_rate/multifactor
DQN.value_lr_schedule = @value/multifactor
DQN.n_eval_episodes = 0
DQN.only_eval = False
DQN.n_replay_epochs = 100
DQN.max_slice_length = 4
DQN.sync_freq = 1000
DQN.scale_value_targets = False
DQN.n_interactions_per_epoch = 2000
DQN.advantage_estimator = @trax.rl.advantages.td_k
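
Schematically, max_slice_length = 4 with the td_k advantage estimator means each update regresses Q toward a multi-step return bootstrapped from a target network, and sync_freq = 1000 means that target network is refreshed from the online network every 1000 training steps. A generic sketch of the k-step target (not trax's actual loop):

def td_k_target(rewards, bootstrap_q, gamma=0.99):
    # rewards: the k rewards in one replay slice; bootstrap_q: the target
    # network's value at the slice's final state, e.g. max_a Q_target(s_k, a).
    target = bootstrap_q
    for r in reversed(rewards):
        target = r + gamma * target  # r_0 + g*r_1 + ... + g^k * bootstrap_q
    return target

# Every sync_freq steps the target network is refreshed:
#   if step % 1000 == 0: target_params = online_params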

# Parameters for AWR:
# ==============================================================================
AWR.value_model = @trax.models.Value
@@ -90,4 +140,4 @@ PPO.n_trajectories_per_epoch = 10
# ==============================================================================
train_rl.light_rl = True
train_rl.light_rl_trainer = @trax.rl.AWR
-train_rl.n_epochs = 5000
+train_rl.n_epochs = 50000
