Model-Based RL: Player (tensorflow#1330)
* SimulatedEnv with gym-like interface.

* Initial Player

* Player: Add reward header, keyboard reset, CLI, move to player.py

* Player: add WAIT mode and a few CLI options.

* Introduce Policy Inferencer

* Recording videos for ppo and player, some refactoring.

* Player refactor. Add real env recording with PPO agent.

* Extend CLI documentation. Remove some imports.

* Pylint

* Extend documentation.

* Correct dopamine import.

* Move gym.utils.play to global imports.

* Remove SimulatedEnv (unnecessary wrapper for FlatBatchEnv<SimulatedBatchGymEnv>)

* Replace join_and_check with os.path.join.

* Move generation of initial_frame_chooser function to rl_utils.

* Move make_simulated_env_fn from trainer_model_based.py to rl.py

* Remove trainer_model_based imports, clean up player and record_ppo FLAGS.

* Move setup_env and load_t2t_gym_env to T2TGymEnv.

* Correct relative imports.

* Custom policy world_model and data paths for player.

* Enable BatchGymEnv to load directly from checkpoint file.

* Small fix to record_ppo.

* Remove unused record_ppo.py.
konradczechowski authored and kpe committed Mar 2, 2019
1 parent 80f1bf3 commit e66a484
Showing 9 changed files with 613 additions and 337 deletions.
55 changes: 54 additions & 1 deletion tensor2tensor/data_generators/gym_env.py
@@ -21,7 +21,10 @@

import collections
import itertools
import os
import random
import re

from gym.spaces import Box
import numpy as np

@@ -185,7 +188,10 @@ def current_epoch_rollouts(self, split=None, minimal_rollout_frames=0):
    if not rollouts_by_split:
      if split is not None:
        raise ValueError(
-            "generate_data() should first be called in the current epoch"
+            "Data is not split into train/dev/test. If data was created by "
+            "environment interaction (NOT loaded from disk) you should call "
+            "generate_data() first. Note that generate_data() will write to "
+            "disk and can corrupt your experiment data."
        )
      else:
        rollouts = self._current_epoch_rollouts
@@ -636,6 +642,53 @@ def base_env_name(self):
  def num_channels(self):
    return self.observation_space.shape[2]

  @staticmethod
  def infer_last_epoch_num(data_dir):
    """Infer highest epoch number from file names in data_dir."""
    names = os.listdir(data_dir)
    epochs_str = [re.findall(pattern=r".*\.(-?\d+)$", string=name)
                  for name in names]
    epochs_str = sum(epochs_str, [])
    return max([int(epoch_str) for epoch_str in epochs_str])

  @staticmethod
  def setup_env_from_hparams(hparams, batch_size, max_num_noops):
    """Construct a T2TGymEnv for hparams.game with the given batch size."""
    game_mode = "NoFrameskip-v4"
    camel_game_name = misc_utils.snakecase_to_camelcase(hparams.game)
    camel_game_name += game_mode
    env_name = camel_game_name

    env = T2TGymEnv(base_env_name=env_name,
                    batch_size=batch_size,
                    grayscale=hparams.grayscale,
                    resize_width_factor=hparams.resize_width_factor,
                    resize_height_factor=hparams.resize_height_factor,
                    rl_env_max_episode_steps=hparams.rl_env_max_episode_steps,
                    max_num_noops=max_num_noops, maxskip_envs=True)
    return env

  @staticmethod
  def setup_and_load_epoch(hparams, data_dir, which_epoch_data=None):
    """Load T2TGymEnv with data from one epoch.

    Args:
      hparams: hparams with the real-env settings (game, real_batch_size,
        max_num_noops, ...).
      data_dir: directory with data written by generate_data().
      which_epoch_data: which epoch's data to load; "last" selects the highest
        epoch number found in data_dir, None starts a fresh epoch without
        loading from disk.

    Returns:
      The constructed T2TGymEnv.
    """
    t2t_env = T2TGymEnv.setup_env_from_hparams(
        hparams, batch_size=hparams.real_batch_size,
        max_num_noops=hparams.max_num_noops
    )
    # Load data.
    if which_epoch_data is not None:
      if which_epoch_data == "last":
        which_epoch_data = T2TGymEnv.infer_last_epoch_num(data_dir)
      assert isinstance(which_epoch_data, int), \
        "{}".format(type(which_epoch_data))
      t2t_env.start_new_epoch(which_epoch_data, data_dir)
    else:
      t2t_env.start_new_epoch(-999)
    return t2t_env

  def _derive_observation_space(self, orig_observ_space):
    height, width, channels = orig_observ_space.shape
    if self.grayscale:
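A minimal usage sketch of the new helpers (not part of the diff), assuming an hparams object that carries the fields referenced above (game, real_batch_size, max_num_noops, ...) and a placeholder data_dir containing rollouts previously written by generate_data():

# Hypothetical usage; `hparams` and the data_dir path are assumptions, not from the commit.
env = T2TGymEnv.setup_and_load_epoch(
    hparams, data_dir="/tmp/rl_data", which_epoch_data="last")
rollouts = env.current_epoch_rollouts()
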
22 changes: 22 additions & 0 deletions tensor2tensor/models/research/rl.py
@@ -203,6 +203,28 @@ def env_fn(in_graph):
  return env_fn


def make_simulated_env_fn_from_hparams(
    real_env, hparams, batch_size, initial_frame_chooser, model_dir,
    sim_video_dir=None):
  """Creates a simulated env_fn."""
  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  if hparams.wm_policy_param_sharing:
    model_hparams.optimizer_zero_grads = True
  return make_simulated_env_fn(
      reward_range=real_env.reward_range,
      observation_space=real_env.observation_space,
      action_space=real_env.action_space,
      frame_stack_size=hparams.frame_stack_size,
      frame_height=real_env.frame_height, frame_width=real_env.frame_width,
      initial_frame_chooser=initial_frame_chooser, batch_size=batch_size,
      model_name=hparams.generative_model,
      # Pass the locally adjusted hparams through instead of recreating them.
      model_hparams=model_hparams,
      model_dir=model_dir,
      intrinsic_reward_scale=hparams.intrinsic_reward_scale,
      sim_video_dir=sim_video_dir,
  )


def get_policy(observations, hparams, action_space):
"""Get a policy network.
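A minimal sketch (not part of the diff) of building a simulated environment function from a trained world model; `initial_frame_chooser` is assumed to be the helper this commit moves into rl_utils, and `world_model_dir` is a placeholder checkpoint location:

# Hypothetical usage; `real_env`, `hparams`, `initial_frame_chooser` and
# `world_model_dir` are assumed to come from the surrounding training setup.
env_fn = make_simulated_env_fn_from_hparams(
    real_env, hparams, batch_size=1,
    initial_frame_chooser=initial_frame_chooser,
    model_dir=world_model_dir)
sim_env = env_fn(in_graph=False)  # env_fn takes in_graph, per the hunk above
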
9 changes: 6 additions & 3 deletions tensor2tensor/rl/envs/simulated_batch_env.py
@@ -161,9 +161,12 @@ def initialize(self, sess):
    model_loader = tf.train.Saver(
        var_list=tf.global_variables(scope="next_frame*")  # pylint:disable=unexpected-keyword-arg
    )
-   trainer_lib.restore_checkpoint(
-       self._model_dir, saver=model_loader, sess=sess, must_restore=True
-   )
+   if os.path.isdir(self._model_dir):
+     trainer_lib.restore_checkpoint(
+         self._model_dir, saver=model_loader, sess=sess, must_restore=True
+     )
+   else:
+     model_loader.restore(sess=sess, save_path=self._model_dir)

  def __str__(self):
    return "SimulatedEnv"
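With the change above, `self._model_dir` may now be either a checkpoint directory (restored via trainer_lib.restore_checkpoint) or a single checkpoint path (restored directly through the Saver). Illustrative placeholder values, not from the commit:

# model_dir = "/tmp/world_model"                    -> latest checkpoint in the directory
# model_dir = "/tmp/world_model/model.ckpt-100000"  -> exactly that checkpoint
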
3 changes: 1 addition & 2 deletions tensor2tensor/rl/envs/simulated_batch_gym_env.py
@@ -26,8 +26,7 @@


class FlatBatchEnv(Env):
-  """TODO(konradczechowski): Add doc-string."""
-
+  """Gym environment interface for batched environments (with batch size 1)."""
  def __init__(self, batch_env):
    if batch_env.batch_size != 1:
      raise ValueError("Number of environments in batch must be equal to one")
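A minimal sketch (not part of the diff) of driving a batch environment of size 1 through the FlatBatchEnv wrapper, e.g. a SimulatedBatchGymEnv used by the player; `batch_env` and `action` are placeholders:

# Hypothetical usage; assumes the standard Gym step/reset signatures of the Env base class.
flat_env = FlatBatchEnv(batch_env)  # raises ValueError unless batch_env.batch_size == 1
observation = flat_env.reset()
observation, reward, done, info = flat_env.step(action)
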
