Skip to content

Commit

Permalink
Merge 2d2badb into 7d632c7
Browse files Browse the repository at this point in the history
  • Loading branch information
kanaadp committed Mar 6, 2020
2 parents 7d632c7 + 2d2badb commit 328197b
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions examples/train.py
Expand Up @@ -65,6 +65,9 @@ def parse_args(args):
parser.add_argument(
'--rollout_size', type=int, default=1000,
help='How many steps are in a training batch.')
parser.add_argument(
'--checkpoint_path', type=str, default=None,
help='Directory with checkpoint to restore training from.')

return parser.parse_known_args(args)[0]

Expand Down Expand Up @@ -199,22 +202,24 @@ def setup_exps_rllib(flow_params,
flow_params, n_cpus, n_rollouts,
policy_graphs, policy_mapping_fn, policies_to_train)

ray.init(num_cpus=n_cpus + 1)
trials = run_experiments({
flow_params["exp_tag"]: {
"run": alg_run,
"env": gym_name,
"config": {
**config
},
"checkpoint_freq": 20,
"checkpoint_at_end": True,
"max_failures": 999,
"stop": {
"training_iteration": 200,
},
}
})
ray.init(num_cpus=n_cpus + 1, object_store_memory=200 * 1024 * 1024)
exp_config = {
"run": alg_run,
"env": gym_name,
"config": {
**config
},
"checkpoint_freq": 20,
"checkpoint_at_end": True,
"max_failures": 999,
"stop": {
"training_iteration": flags.num_steps,
},
}

if flags.checkpoint_path is not None:
exp_config['restore'] = flags.checkpoint_path
trials = run_experiments({flow_params["exp_tag"]: exp_config})

elif flags.rl_trainer == "Stable-Baselines":
flow_params = submodule.flow_params
Expand Down

0 comments on commit 328197b

Please sign in to comment.