Skip to content

Commit

Permalink
Merge ea308d6 into 7824d88
Browse files Browse the repository at this point in the history
  • Loading branch information
AboudyKreidieh committed Jun 21, 2020
2 parents 7824d88 + ea308d6 commit 1a76388
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 38 deletions.
37 changes: 4 additions & 33 deletions examples/train.py
Expand Up @@ -213,45 +213,17 @@ def train_rllib(submodule, flags):
run_experiments({flow_params["exp_tag"]: exp_config})


def train_h_baselines(flow_params, args, multiagent):
def train_h_baselines(env_name, args, multiagent):
"""Train policies using SAC and TD3 with h-baselines."""
from hbaselines.algorithms import OffPolicyRLAlgorithm
from hbaselines.utils.train import parse_options, get_hyperparameters
from hbaselines.envs.mixed_autonomy import FlowEnv

flow_params = deepcopy(flow_params)

# Get the command-line arguments that are relevant here
args = parse_options(description="", example_usage="", args=args)

# the base directory that the logged data will be stored in
base_dir = "training_data"

# Create the training environment.
env = FlowEnv(
flow_params,
multiagent=multiagent,
shared=args.shared,
maddpg=args.maddpg,
render=args.render,
version=0
)

# Create the evaluation environment.
if args.evaluate:
eval_flow_params = deepcopy(flow_params)
eval_flow_params['env'].evaluate = True
eval_env = FlowEnv(
eval_flow_params,
multiagent=multiagent,
shared=args.shared,
maddpg=args.maddpg,
render=args.render_eval,
version=1
)
else:
eval_env = None

for i in range(args.n_training):
# value of the next seed
seed = args.seed + i
Expand Down Expand Up @@ -299,8 +271,8 @@ def train_h_baselines(flow_params, args, multiagent):
# Create the algorithm object.
alg = OffPolicyRLAlgorithm(
policy=policy,
env=env,
eval_env=eval_env,
env="flow:{}".format(env_name),
eval_env="flow:{}".format(env_name) if args.evaluate else None,
**hp
)

Expand Down Expand Up @@ -393,8 +365,7 @@ def main(args):
elif flags.rl_trainer.lower() == "stable-baselines":
train_stable_baselines(submodule, flags)
elif flags.rl_trainer.lower() == "h-baselines":
flow_params = submodule.flow_params
train_h_baselines(flow_params, args, multiagent)
train_h_baselines(flags.exp_config, args, multiagent)
else:
raise ValueError("rl_trainer should be either 'rllib', 'h-baselines', "
"or 'stable-baselines'.")
Expand Down
10 changes: 5 additions & 5 deletions tests/fast_tests/test_examples.py
Expand Up @@ -229,22 +229,22 @@ class TestHBaselineExamples(unittest.TestCase):
confirming that it runs.
"""
@staticmethod
def run_exp(flow_params, multiagent):
def run_exp(env_name, multiagent):
train_h_baselines(
flow_params=flow_params,
env_name=env_name,
args=[
flow_params["env_name"].__name__,
env_name,
"--initial_exploration_steps", "1",
"--total_steps", "10"
],
multiagent=multiagent,
)

def test_singleagent_ring(self):
self.run_exp(singleagent_ring.copy(), multiagent=False)
self.run_exp("singleagent_ring", multiagent=False)

def test_multiagent_ring(self):
self.run_exp(multiagent_ring.copy(), multiagent=True)
self.run_exp("multiagent_ring", multiagent=True)


class TestRllibExamples(unittest.TestCase):
Expand Down

0 comments on commit 1a76388

Please sign in to comment.