Is there an example showing how to train a specific option policy?

For example, starting from the AgentFlow tutorial snippet below, how can we set up training for `ExamplePolicy`? The problem is that the timesteps produced by `env.step` in the main loop may have a different `observation_spec` and `action_spec` than what the policy itself expects (a sketch of what I mean follows the snippet).
```python
import dm_env
import numpy as np

# Stubs for pulling observation and sending action to some external system.
observation_cb = ExampleObservationUpdater()
action_cb = ExampleActionSender()

# Create an environment that forwards the observation and action calls.
env = ProxyEnvironment(observation_cb, action_cb)

# Stub policy that runs the desired agent.
policy = ExamplePolicy(action_cb.action_spec(), "agent")

# Wrap policy into an agent that logs to the terminal.
task = ExampleSubTask(env.observation_spec(), action_cb.action_spec(), 10)
logger = print_logger.PrintLogger()
aggregator = subtask_logger.EpisodeReturnAggregator()
logging_observer = subtask_logger.SubTaskLogger(logger, aggregator)
agent = subtask.SubTaskOption(task, policy, [logging_observer])

reset_op = ExampleScriptedOption(action_cb.action_spec(), "reset", 3)
main_loop = loop_ops.Repeat(5, sequence.Sequence([reset_op, agent]))

# Run the episode.
timestep = env.reset()
while True:
  action = main_loop.step(timestep)
  timestep = env.step(action)

  # Terminate if the environment or main_loop requests it.
  if timestep.last() or (main_loop.pterm(timestep) > np.random.rand()):
    if not timestep.last():
      termination_timestep = timestep._replace(step_type=dm_env.StepType.LAST)
      main_loop.step(termination_timestep)
    break
```
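In case a concrete illustration helps, here is a minimal sketch of the kind of adapter I have in mind. Everything in it is hypothetical and not part of the tutorial: `SpecAdapterEnv`, the `parent_to_agent_timestep` / `agent_to_parent_action` translation functions, and the learner are placeholder names. The idea is to expose the policy's own specs to a standard RL training loop and translate to and from the environment's specs at the boundary, i.e. the same mapping a `SubTask` performs between parent and agent timesteps.

```python
import dm_env


class SpecAdapterEnv(dm_env.Environment):
  """Hypothetical wrapper exposing the policy-side specs of `env`.

  `parent_to_agent_timestep` and `agent_to_parent_action` are user-supplied
  translation functions (placeholders, not AgentFlow APIs) performing the
  same mapping a SubTask would between parent and agent specs.
  """

  def __init__(self, env, observation_spec, action_spec,
               parent_to_agent_timestep, agent_to_parent_action):
    self._env = env
    self._observation_spec = observation_spec
    self._action_spec = action_spec
    self._to_agent = parent_to_agent_timestep
    self._to_parent = agent_to_parent_action

  def reset(self):
    # Translate the environment's first timestep into the policy's specs.
    return self._to_agent(self._env.reset())

  def step(self, agent_action):
    # Translate the policy's action out, and the resulting timestep back in.
    return self._to_agent(self._env.step(self._to_parent(agent_action)))

  def observation_spec(self):
    return self._observation_spec

  def action_spec(self):
    return self._action_spec


# Hypothetical usage with any dm_env-compatible learner:
# train_env = SpecAdapterEnv(env, policy_obs_spec, policy_act_spec,
#                            my_timestep_mapper, my_action_mapper)
# learner.train(train_env)
```

With something like this, the policy could be trained against the adapted environment and then dropped back into the `SubTaskOption` for execution. Is that the intended pattern, or does AgentFlow already provide this mapping out of the box, e.g. via the `SubTask` itself?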