/
BBF.gin
executable file
·74 lines (67 loc) · 3.05 KB
/
BBF.gin
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
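# BBF agent configuration (BBFAgent bindings), using the SPR
# (Schwarzer et al., 2021) network and agent modules imported below.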
import dopamine.jax.networks
import dopamine.discrete_domains.gym_lib
import dopamine.discrete_domains.run_experiment
import bigger_better_faster.bbf.spr_networks
import bigger_better_faster.bbf.agents.spr_agent
import bigger_better_faster.bbf.replay_memory.subsequence_replay_buffer
# Bindings for the JaxDQNAgent configurable.
JaxDQNAgent.gamma = 0.997
JaxDQNAgent.min_replay_history = 2000
# Both periods are set to 1.
# NOTE(review): presumably "every agent step" -- confirm the unit against the
# agent implementation.
JaxDQNAgent.update_period = 1
JaxDQNAgent.target_update_period = 1
# Training epsilon is exactly zero; evaluation epsilon is 0.001.
JaxDQNAgent.epsilon_train = 0.00
JaxDQNAgent.epsilon_eval = 0.001
# One more than min_replay_history (2000) bound above.
JaxDQNAgent.epsilon_decay_period = 2001 # DrQ
JaxDQNAgent.optimizer = 'adam'
# Bindings for the BBFAgent configurable.
BBFAgent.noisy = False
BBFAgent.dueling = True
BBFAgent.double_dqn = True
BBFAgent.distributional = True
BBFAgent.num_atoms = 51
# NOTE(review): update_horizon/max_update_horizon and min_gamma (together with
# JaxDQNAgent.gamma = 0.997 in this file) look like schedule endpoints used
# with cycle_steps -- confirm in the agent code before changing one of them
# in isolation.
BBFAgent.update_horizon = 3
BBFAgent.max_update_horizon = 10
BBFAgent.min_gamma = 0.97
BBFAgent.cycle_steps = 10_000
BBFAgent.reset_every = 20_000 # Change if you change the replay ratio
# A single comma-separated string naming two keys: encoder, transition_model.
BBFAgent.shrink_perturb_keys = "encoder,transition_model"
BBFAgent.shrink_factor = 0.5
BBFAgent.perturb_factor = 0.5
# Currently equal to Runner.training_steps (100000) bound in this file.
BBFAgent.no_resets_after = 100_000 # Need to change if training longer
BBFAgent.log_every = 100
# The reset_every comment above says that value must be revisited whenever
# replay_ratio changes.
BBFAgent.replay_ratio = 64
BBFAgent.batches_to_group = 2
BBFAgent.batch_size = 32
BBFAgent.spr_weight = 5
BBFAgent.jumps = 5
BBFAgent.data_augmentation = True
# This file also binds replay_capacity/n_envs on both the prioritized and the
# non-prioritized subsequence replay buffer classes.
BBFAgent.replay_scheme = 'prioritized'
BBFAgent.half_precision = False
# `@name` is a Gin configurable reference: the configurable itself is passed
# as the parameter value. RainbowDQNNetwork is configured further down.
BBFAgent.network = @bbf.spr_networks.RainbowDQNNetwork
BBFAgent.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon
# The general and encoder learning rates are set to the same value here.
BBFAgent.learning_rate = 0.0001
BBFAgent.encoder_learning_rate = 0.0001
BBFAgent.target_update_tau = 0.005
BBFAgent.target_action_selection = True
# Bindings for the network passed to BBFAgent.network
# (bbf.spr_networks.RainbowDQNNetwork) and for ImpalaCNN.
bbf.spr_networks.RainbowDQNNetwork.renormalize = True
bbf.spr_networks.RainbowDQNNetwork.hidden_dim = 2048
bbf.spr_networks.RainbowDQNNetwork.encoder_type = "impala"
bbf.spr_networks.RainbowDQNNetwork.width_scale = 4
# NOTE(review): presumably ImpalaCNN is the encoder selected by
# encoder_type = "impala" above -- confirm in spr_networks.
bbf.spr_networks.ImpalaCNN.num_blocks = 2
# Bindings for bbf.agents.spr_agent.create_scaling_optimizer.
# Note: these parameters are from DER (van Hasselt et al., 2019).
# NOTE(review): confirm whether the DER attribution covers weight_decay as
# well as eps.
bbf.agents.spr_agent.create_scaling_optimizer.eps = 0.00015
bbf.agents.spr_agent.create_scaling_optimizer.weight_decay = 0.1
# Runner and environment bindings. Parameters are bound on both `Runner` and
# `DataEfficientAtariRunner`.
DataEfficientAtariRunner.game_name = 'ChopperCommand'
# Atari 100K benchmark doesn't use sticky actions.
atari_lib.create_atari_environment.sticky_actions = False
AtariPreprocessing.terminal_on_life_loss = True
Runner.num_iterations = 1
# BBFAgent.no_resets_after in this file is set to the same value (100_000).
Runner.training_steps = 100000 # agent steps
DataEfficientAtariRunner.num_eval_episodes = 100 # agent episodes
# NOTE(review): the two *_envs values below were previously annotated
# "agent episodes"; judging by the names they are environment counts --
# confirm against DataEfficientAtariRunner.
DataEfficientAtariRunner.num_eval_envs = 100 # environments (presumed), same value as num_eval_episodes
DataEfficientAtariRunner.num_train_envs = 1 # environments (presumed), same value as the replay buffers' n_envs
DataEfficientAtariRunner.max_noops = 30
Runner.max_steps_per_episode = 27000 # agent steps
# Replay buffer bindings. The prioritized and the non-prioritized buffer
# classes are given identical values. BBFAgent.replay_scheme is 'prioritized'
# in this file.
# NOTE(review): presumably only the class matching replay_scheme is
# instantiated -- confirm in the agent code.
# NOTE(review): n_envs was previously annotated "agent episodes"; judging by
# the name it is an environment count (same value as
# DataEfficientAtariRunner.num_train_envs = 1) -- confirm.
bbf.replay_memory.subsequence_replay_buffer.PrioritizedJaxSubsequenceParallelEnvReplayBuffer.replay_capacity = 200000
bbf.replay_memory.subsequence_replay_buffer.PrioritizedJaxSubsequenceParallelEnvReplayBuffer.n_envs = 1 # environments (presumed)
bbf.replay_memory.subsequence_replay_buffer.JaxSubsequenceParallelEnvReplayBuffer.replay_capacity = 200000
bbf.replay_memory.subsequence_replay_buffer.JaxSubsequenceParallelEnvReplayBuffer.n_envs = 1 # environments (presumed)