/
train_baseline_a3c.py
195 lines (170 loc) · 6.85 KB
/
train_baseline_a3c.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import ray
from ray import tune
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.models import ModelCatalog
from ray.tune import run_experiments
from ray.tune.registry import register_env
import tensorflow as tf
from social_dilemmas.envs.harvest import HarvestEnv
from social_dilemmas.envs.cleanup import CleanupEnv
from models.conv_to_fc_net import ConvToFCNet
# Command-line flags (TF1-style). Values are parsed from argv by tf.app.run
# in the __main__ guard and read via the FLAGS singleton.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
    'exp_name', 'train_baseline_a3c_cleanup',
    'Name of the ray_results experiment directory where results are stored.')
tf.app.flags.DEFINE_string(
    'env', 'cleanup',
    'Name of the environment to rollout. Can be cleanup or harvest.')
tf.app.flags.DEFINE_integer(
    'num_agents', 5,
    'Number of agent policies')
tf.app.flags.DEFINE_integer(
    'num_cpus', 14,
    'Number of available CPUs')
tf.app.flags.DEFINE_integer(
    'num_gpus', 0,
    'Number of available GPUs')
tf.app.flags.DEFINE_boolean(
    'use_gpus_for_workers', False,
    'Set to true to run workers on GPUs rather than CPUs')
tf.app.flags.DEFINE_boolean(
    'use_gpu_for_driver', False,
    'Set to true to run driver on GPU rather than CPU.')
# Declared as a float so that fractional workers-per-device are allowed
# (e.g. 0.5 to give each worker two devices).
tf.app.flags.DEFINE_float(
    'num_workers_per_device', 1,
    'Number of workers to place on a single device (CPU or GPU)')
tf.app.flags.DEFINE_boolean(
    'resume', False,
    'Set to true to resume a previously stopped experiment.')
tf.app.flags.DEFINE_boolean(
    'tune', True,
    'Set to true to tune hyperparameters.')

# Per-environment default hyperparameters.
# NOTE(review): in setup() these are only referenced by commented-out
# lr_schedule / entropy_coeff config lines — confirm whether they should
# be re-enabled for non-tuning runs.
harvest_default_params = {
    'lr_init': 0.00136,
    'lr_final': 0.000028,
    'entropy_coeff': .000687}

cleanup_default_params = {
    'lr_init': 0.00126,
    'lr_final': 0.000012,
    'entropy_coeff': .00176}
def setup(env, hparams, num_cpus, num_gpus, num_agents, use_gpus_for_workers=False,
          use_gpu_for_driver=False, num_workers_per_device=1, tune_hparams=False):
    """Build the RLlib A3C experiment config for a social-dilemma environment.

    Registers the chosen environment and the custom conv-to-fc model with
    RLlib, creates one A3C policy graph per agent, and computes the
    driver/worker device placement from the available CPU/GPU budget.

    Args:
        env: 'harvest' selects HarvestEnv; any other value selects CleanupEnv.
        hparams: dict with 'lr_init', 'lr_final', 'entropy_coeff'. Falls back
            to the module-level per-environment defaults when None.
        num_cpus: total CPUs available to the experiment.
        num_gpus: total GPUs available to the experiment.
        num_agents: number of independent agent policies to create.
        use_gpus_for_workers: place rollout workers on GPUs instead of CPUs.
        use_gpu_for_driver: place the driver on a GPU instead of a CPU.
        num_workers_per_device: workers sharing a single device (may be
            fractional, e.g. 0.5 gives each worker two devices).
        tune_hparams: if True, add tune.grid_search sweeps over learning-rate
            schedule and entropy coefficient.

    Returns:
        (algorithm_name, registered_env_name, rllib_config) tuple.

    Raises:
        ValueError: if the device budget leaves no workers.
    """
    if env == 'harvest':
        def env_creator(_):
            return HarvestEnv(num_agents=num_agents)
        single_env = HarvestEnv()
        default_params = harvest_default_params
    else:
        def env_creator(_):
            return CleanupEnv(num_agents=num_agents)
        single_env = CleanupEnv()
        default_params = cleanup_default_params

    # BUG FIX: the caller-supplied hparams used to be unconditionally
    # overwritten by the module defaults; now defaults apply only when the
    # caller passes None. (hparams currently feeds only the commented-out
    # lr_schedule / entropy_coeff lines below.)
    if hparams is None:
        hparams = default_params

    env_name = env + "_env"
    register_env(env_name, env_creator)

    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # Each policy can have a different configuration (including custom model).
    def gen_policy():
        return (A3CPolicyGraph, obs_space, act_space, {})

    # One independent policy graph per agent; agent ids map 1:1 to policies.
    policy_graphs = {}
    for i in range(num_agents):
        policy_graphs['agent-' + str(i)] = gen_policy()

    def policy_mapping_fn(agent_id):
        return agent_id

    # Register the custom model so "custom_model" below resolves.
    model_name = "conv_to_fc_net"
    ModelCatalog.register_custom_model(model_name, ConvToFCNet)

    algorithm = 'A3C'
    agent_cls = get_agent_class(algorithm)
    config = agent_cls._default_config.copy()

    # Information stashed in env_config so rollouts can be replayed later.
    config['env_config']['func_create'] = tune.function(env_creator)
    config['env_config']['env_name'] = env_name
    config['env_config']['run'] = algorithm

    # Calculate device configurations. The driver takes exactly one device;
    # the rest are split among the workers.
    gpus_for_driver = int(use_gpu_for_driver)
    cpus_for_driver = 1 - gpus_for_driver
    if use_gpus_for_workers:
        spare_gpus = (num_gpus - gpus_for_driver)
        num_workers = int(spare_gpus * num_workers_per_device)
        if num_workers <= 0:
            # BUG FIX: previously this fell through to a ZeroDivisionError.
            raise ValueError('Device budget leaves no GPU workers; increase '
                             'num_gpus or num_workers_per_device.')
        num_gpus_per_worker = spare_gpus / num_workers  # Can be a fraction
        num_cpus_per_worker = 0
    else:
        spare_cpus = (num_cpus - cpus_for_driver)
        num_workers = int(spare_cpus * num_workers_per_device)
        if num_workers <= 0:
            raise ValueError('Device budget leaves no CPU workers; increase '
                             'num_cpus or num_workers_per_device.')
        num_gpus_per_worker = 0
        num_cpus_per_worker = int(spare_cpus / num_workers)

    # Settings shared by the tuning and non-tuning paths (previously
    # duplicated in two near-identical config.update calls).
    config.update({
        "horizon": 1000,
        "num_workers": num_workers,
        "num_gpus": gpus_for_driver,  # The number of GPUs for the driver
        "num_cpus_for_driver": cpus_for_driver,
        "num_gpus_per_worker": num_gpus_per_worker,  # Can be a fraction
        "num_cpus_per_worker": num_cpus_per_worker,  # Can be a fraction
        "multiagent": {
            "policy_graphs": policy_graphs,
            "policy_mapping_fn": tune.function(policy_mapping_fn),
        },
        "model": {"custom_model": "conv_to_fc_net", "use_lstm": True,
                  "lstm_cell_size": 128}
    })
    if tune_hparams:
        # Grid-search over learning-rate schedule and entropy bonus.
        config.update({
            "train_batch_size": 128,
            "lr_schedule": [[0, tune.grid_search([5e-4, 5e-3])],
                            [20000000, tune.grid_search([5e-4, 5e-5])]],
            "entropy_coeff": tune.grid_search([0, -1e-1, -1e-2]),
        })
    # else: non-tuning runs keep the algorithm defaults; the hparams-driven
    # settings below were disabled upstream and are preserved for reference:
    # "train_batch_size": 128,
    # "lr_schedule": [[0, hparams['lr_init']],
    #                 [20000000, hparams['lr_final']]],
    # "entropy_coeff": hparams['entropy_coeff'],

    return algorithm, env_name, config
def main(unused_argv):
    """Entry point: connect to ray, build the A3C config, launch training.

    Reads all settings from the module-level FLAGS (parsed by tf.app.run).
    """
    # ray.init(num_cpus=FLAGS.num_cpus, object_store_memory=int(10e10),
    # redis_max_memory=int(20e10))
    # NOTE(review): attaches to an already-running ray cluster on this host
    # (hard-coded address) rather than starting a local one — confirm a
    # cluster is expected to be up before launching this script.
    ray.init(redis_address="localhost:6379")
    if FLAGS.env == 'harvest':
        hparams = harvest_default_params
    else:
        hparams = cleanup_default_params
    alg_run, env_name, config = setup(FLAGS.env, hparams, FLAGS.num_cpus,
                                      FLAGS.num_gpus, FLAGS.num_agents,
                                      FLAGS.use_gpus_for_workers,
                                      FLAGS.use_gpu_for_driver,
                                      FLAGS.num_workers_per_device, FLAGS.tune)
    # NOTE(review): exp_name has a non-None default above, so this fallback
    # branch only triggers if the flag is explicitly cleared.
    if FLAGS.exp_name is None:
        exp_name = FLAGS.env + '_A3C'
    else:
        exp_name = FLAGS.exp_name
    print('Commencing experiment', exp_name)
    run_experiments({
        exp_name: {
            "run": alg_run,
            "env": env_name,
            "stop": {
                "training_iteration": 20000
            },
            'checkpoint_freq': 500,
            "config": config,
            # NOTE(review): this S3 path says "harvest" even when running
            # cleanup — confirm the intended upload bucket per environment.
            'upload_dir': 's3://njaques.experiments/sixth_reproduction/causal_influence_baseline_harvest'
        }
    }, resume=FLAGS.resume)


if __name__ == '__main__':
    tf.app.run(main)