Merge pull request #131 from timmeinhardt/refactor_VecPyTorchFrameStack
Refactor VecPyTorchFrameStack
ikostrikov2 committed Sep 17, 2018
2 parents 46b7805 + e7aa3e5 commit 841fad1
Showing 2 changed files with 39 additions and 24 deletions.
45 changes: 27 additions & 18 deletions envs.py
@@ -8,8 +8,8 @@
from baselines import bench
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env import VecEnvWrapper
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize

try:
@@ -63,9 +63,10 @@ def _thunk():

return _thunk

def make_vec_envs(env_name, seed, num_processes, gamma, log_dir, add_timestep, device, allow_early_resets):
def make_vec_envs(env_name, seed, num_processes, gamma, log_dir, add_timestep,
device, allow_early_resets, num_frame_stack=None):
envs = [make_env(env_name, seed, i, log_dir, add_timestep, allow_early_resets)
for i in range(num_processes)]
for i in range(num_processes)]

if len(envs) > 1:
envs = SubprocVecEnv(envs)
@@ -80,7 +81,9 @@ def make_vec_envs(env_name, seed, num_processes, gamma, log_dir, add_timestep, d

envs = VecPyTorch(envs, device)

if len(envs.observation_space.shape) == 3:
if num_frame_stack is not None:
envs = VecPyTorchFrameStack(envs, num_frame_stack, device)
elif len(envs.observation_space.shape) == 3:
envs = VecPyTorchFrameStack(envs, 4, device)

return envs
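
The refactored make_vec_envs keeps its old default (4 stacked frames for 3-dimensional image observations) but lets callers override the stack depth through the new num_frame_stack argument. A minimal usage sketch, assuming envs.py from this commit is importable; the environment id and hyperparameter values are placeholders, not part of the diff:

import torch
from envs import make_vec_envs  # assumption: this envs.py is on the import path

# Leaving num_frame_stack at None keeps the previous behaviour:
# image observations are stacked 4 frames deep.
envs = make_vec_envs('PongNoFrameskip-v4', seed=1, num_processes=8,
                     gamma=0.99, log_dir='/tmp/gym', add_timestep=False,
                     device=torch.device('cpu'), allow_early_resets=False)

# Passing num_frame_stack explicitly overrides the default, e.g. a
# single-frame stack for evaluation environments.
eval_envs = make_vec_envs('PongNoFrameskip-v4', seed=1, num_processes=8,
                          gamma=0.99, log_dir='/tmp/gym_eval', add_timestep=False,
                          device=torch.device('cpu'), allow_early_resets=True,
                          num_frame_stack=1)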
@@ -140,40 +143,46 @@ def step_async(self, actions):
def step_wait(self):
obs, reward, done, info = self.venv.step_wait()
obs = torch.from_numpy(obs).float().to(self.device)
reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
reward = torch.from_numpy(reward).unsqueeze(dim=1).float()
return obs, reward, done, info


# Derived from
# https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_frame_stack.py
class VecPyTorchFrameStack(VecEnvWrapper):
def __init__(self, venv, nstack, device):
def __init__(self, venv, nstack, device=None):
self.venv = venv
self.nstack = nstack

wos = venv.observation_space # wrapped ob space
self.shape_dim0 = wos.low.shape[0]
self.shape_dim0 = wos.shape[0]

low = np.repeat(wos.low, self.nstack, axis=0)
high = np.repeat(wos.high, self.nstack, axis=0)
self.stackedobs = np.zeros((venv.num_envs,) + low.shape)
self.stackedobs = torch.from_numpy(self.stackedobs).float()
self.stackedobs = self.stackedobs.to(device)
observation_space = gym.spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)

if device is None:
device = torch.device('cpu')
self.stacked_obs = torch.zeros((venv.num_envs,) + low.shape).to(device)

observation_space = gym.spaces.Box(
low=low, high=high, dtype=venv.observation_space.dtype)
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)

def step_wait(self):
obs, rews, news, infos = self.venv.step_wait()
self.stackedobs[:, :-self.shape_dim0] = self.stackedobs[:, self.shape_dim0:]
self.stacked_obs[:, :-self.shape_dim0] = \
self.stacked_obs[:, self.shape_dim0:]
for (i, new) in enumerate(news):
if new:
self.stackedobs[i] = 0
self.stackedobs[:, -self.shape_dim0:] = obs
return self.stackedobs, rews, news, infos
self.stacked_obs[i] = 0
self.stacked_obs[:, -self.shape_dim0:] = obs
return self.stacked_obs, rews, news, infos

def reset(self):
obs = self.venv.reset()
self.stackedobs.fill_(0)
self.stackedobs[:, -self.shape_dim0:] = obs
return self.stackedobs
self.stacked_obs.zero_()
self.stacked_obs[:, -self.shape_dim0:] = obs
return self.stacked_obs

def close(self):
self.venv.close()
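
For reference, the stacking logic the class above implements after the rename: stacked_obs is a (num_envs, C * nstack, H, W) tensor kept on the chosen device; each step shifts the previously stored frames toward the front of the channel dimension, zeroes the rows of environments whose episode just ended, and writes the newest observation into the last C channels. A standalone sketch of that rolling buffer; the shapes and the push_frame helper are illustrative, not part of the commit:

import torch

num_envs, frame_channels, nstack = 2, 1, 4
shape_dim0 = frame_channels                      # channels contributed by a single frame
stacked_obs = torch.zeros(num_envs, frame_channels * nstack, 84, 84)

def push_frame(stacked_obs, obs, dones):
    # Shift the stored frames left by one frame's worth of channels ...
    stacked_obs[:, :-shape_dim0] = stacked_obs[:, shape_dim0:]
    # ... clear environments that just finished an episode ...
    for i, done in enumerate(dones):
        if done:
            stacked_obs[i] = 0
    # ... and write the newest observation into the last slot.
    stacked_obs[:, -shape_dim0:] = obs
    return stacked_obs

obs = torch.ones(num_envs, shape_dim0, 84, 84)   # stand-in for an observation from step_wait
stacked_obs = push_frame(stacked_obs, obs, dones=[False, True])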
18 changes: 12 additions & 6 deletions main.py
@@ -88,7 +88,7 @@ def main():
rollouts.to(device)

episode_rewards = deque(maxlen=10)

start = time.time()
for j in range(num_updates):
for step in range(args.num_steps):
@@ -107,7 +107,8 @@
episode_rewards.append(info['episode']['r'])

# If done then clean the history of observations.
masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
masks = torch.FloatTensor([[0.0] if done_ else [1.0]
for done_ in done])
rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob, value, reward, masks)

with torch.no_grad():
@@ -152,9 +153,12 @@ def main():
np.max(episode_rewards), dist_entropy,
value_loss, action_loss))

if args.eval_interval is not None and len(episode_rewards) > 1 and j % args.eval_interval == 0:
eval_envs = make_vec_envs(args.env_name, args.seed + args.num_processes, args.num_processes,
args.gamma, eval_log_dir, args.add_timestep, device, True)
if (args.eval_interval is not None
and len(episode_rewards) > 1
and j % args.eval_interval == 0):
eval_envs = make_vec_envs(
args.env_name, args.seed + args.num_processes, args.num_processes,
args.gamma, eval_log_dir, args.add_timestep, device, True)

if eval_envs.venv.__class__.__name__ == "VecNormalize":
eval_envs.venv.ob_rms = envs.venv.ob_rms
@@ -183,7 +187,9 @@ def _obfilt(self, obs):

# Obser reward and next obs
obs, reward, done, infos = eval_envs.step(action)
eval_masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])

eval_masks = torch.FloatTensor([[0.0] if done_ else [1.0]
for done_ in done])
for info in infos:
if 'episode' in info.keys():
eval_episode_rewards.append(info['episode']['r'])
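
A small illustration of the mask convention used in both loops above: the vectorised done flags are turned into per-environment multipliers that are 0.0 where an episode just terminated and 1.0 otherwise. The concrete values below are arbitrary:

import torch

done = [False, True, False]        # e.g. the second environment just finished an episode
masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
print(masks)                       # tensor([[1.], [0.], [1.]])
# Multiplying recurrent hidden states or bootstrapped values by these masks
# cuts the history at episode boundaries, which is why rollouts.insert and
# the evaluation loop both consume them.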
