Skip to content

Commit

Permalink
Add VecNormalize with eval mode and fix previous bug
Browse files Browse the repository at this point in the history
  • Loading branch information
timmeinhardt committed Sep 17, 2018
1 parent 841fad1 commit 19f6bfa
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 42 deletions.
34 changes: 7 additions & 27 deletions enjoy.py
@@ -1,13 +1,12 @@
import argparse
import os
import types

import numpy as np
import torch

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize
from envs import VecPyTorch, make_vec_envs
from utils import get_render_func, get_vec_normalize


parser = argparse.ArgumentParser(description='RL')
parser.add_argument('--seed', type=int, default=1,
Expand All @@ -26,35 +25,16 @@
None, None, args.add_timestep, device='cpu')

# Get a render function
render_func = None
tmp_env = env
while True:
if hasattr(tmp_env, 'envs'):
render_func = tmp_env.envs[0].render
break
elif hasattr(tmp_env, 'venv'):
tmp_env = tmp_env.venv
elif hasattr(tmp_env, 'env'):
tmp_env = tmp_env.env
else:
break
render_func = get_render_func(env)

# We need to use the same statistics for normalization as used in training
actor_critic, ob_rms = \
torch.load(os.path.join(args.load_dir, args.env_name + ".pt"))

if isinstance(env.venv, VecNormalize):
env.venv.ob_rms = ob_rms

# An ugly hack to remove updates
def _obfilt(self, obs):
if self.ob_rms:
obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
return obs
else:
return obs

env.venv._obfilt = types.MethodType(_obfilt, env.venv)
# Locate the VecNormalize wrapper (if any) in this script's environment.
# The environment is bound to `env` in this file; `envs` is not defined
# here and would raise a NameError.
vec_norm = get_vec_normalize(env)
if vec_norm is not None:
    # Freeze the running statistics and reuse the ones loaded from the
    # checkpoint so observations are normalized as they were in training.
    vec_norm.eval()
    vec_norm.ob_rms = ob_rms

recurrent_hidden_states = torch.zeros(1, actor_critic.recurrent_hidden_state_size)
masks = torch.zeros(1, 1)
Expand Down
25 changes: 24 additions & 1 deletion envs.py
Expand Up @@ -10,7 +10,8 @@
from baselines.common.vec_env import VecEnvWrapper
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize
from baselines.common.vec_env.vec_normalize import VecNormalize as VecNormalize_


try:
import dm_control2gym
Expand Down Expand Up @@ -147,6 +148,28 @@ def step_wait(self):
return obs, reward, done, info


class VecNormalize(VecNormalize_):
    """VecNormalize variant whose observation statistics can be frozen.

    A ``training`` flag controls whether ``_obfilt`` updates the running
    observation statistics (``ob_rms``) before normalizing. Use ``train()``
    and ``eval()`` to switch the flag.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base class and start in training mode."""
        super(VecNormalize, self).__init__(*args, **kwargs)
        self.training = True

    def _obfilt(self, obs):
        """Return ``obs`` normalized and clipped using ``ob_rms``.

        If ``ob_rms`` is falsy the observations are returned untouched.
        The statistics are only updated while ``training`` is True.
        """
        # Guard clause: nothing to normalize with.
        if not self.ob_rms:
            return obs

        if self.training:
            self.ob_rms.update(obs)

        # NOTE(review): `epsilon` and `clipob` are presumably set by the base
        # class constructor -- confirm against the baselines implementation.
        centered = obs - self.ob_rms.mean
        scale = np.sqrt(self.ob_rms.var + self.epsilon)
        return np.clip(centered / scale, -self.clipob, self.clipob)

    def train(self):
        """Enable updates of the observation statistics."""
        self.training = True

    def eval(self):
        """Disable updates of the observation statistics."""
        self.training = False


# Derived from
# https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_frame_stack.py
class VecPyTorchFrameStack(VecEnvWrapper):
Expand Down
20 changes: 6 additions & 14 deletions main.py
Expand Up @@ -2,7 +2,6 @@
import glob
import os
import time
import types
from collections import deque

import gym
Expand All @@ -17,6 +16,7 @@
from envs import make_vec_envs
from model import Policy
from storage import RolloutStorage
from utils import get_vec_normalize
from visualize import visdom_plot

args = get_args()
Expand Down Expand Up @@ -135,7 +135,7 @@ def main():
save_model = copy.deepcopy(actor_critic).cpu()

save_model = [save_model,
hasattr(envs.venv, 'ob_rms') and envs.venv.ob_rms or None]
getattr(get_vec_normalize(envs), 'ob_rms', None)]

torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))

Expand All @@ -160,18 +160,10 @@ def main():
args.env_name, args.seed + args.num_processes, args.num_processes,
args.gamma, eval_log_dir, args.add_timestep, device, True)

if eval_envs.venv.__class__.__name__ == "VecNormalize":
eval_envs.venv.ob_rms = envs.venv.ob_rms

# An ugly hack to remove updates
def _obfilt(self, obs):
if self.ob_rms:
obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
return obs
else:
return obs

eval_envs.venv._obfilt = types.MethodType(_obfilt, envs.venv)
vec_norm = get_vec_normalize(eval_envs)
if vec_norm is not None:
vec_norm.eval()
vec_norm.ob_rms = get_vec_normalize(envs).ob_rms

eval_episode_rewards = []

Expand Down
23 changes: 23 additions & 0 deletions utils.py
@@ -1,6 +1,29 @@
import torch
import torch.nn as nn

from envs import VecNormalize


# Get a render function
def get_render_func(venv):
    """Return the ``render`` method of the first underlying environment.

    Unwraps ``venv`` layer by layer: an object exposing ``envs`` yields
    ``envs[0].render``; otherwise the search continues through ``venv`` and
    then ``env`` attributes. Returns ``None`` when none of these attributes
    is present on the current layer.
    """
    current = venv
    while True:
        if hasattr(current, 'envs'):
            return current.envs[0].render
        if hasattr(current, 'venv'):
            current = current.venv
        elif hasattr(current, 'env'):
            current = current.env
        else:
            return None


def get_vec_normalize(venv):
    """Return the first ``VecNormalize`` layer found in a wrapper chain.

    Starting at ``venv``, follows ``.venv`` attributes until an instance of
    ``VecNormalize`` is found. Returns ``None`` if the chain ends (a layer
    without a ``venv`` attribute) before one is found.
    """
    layer = venv
    while not isinstance(layer, VecNormalize):
        if not hasattr(layer, 'venv'):
            return None
        layer = layer.venv
    return layer


# Necessary for my KFAC implementation.
class AddBias(nn.Module):
Expand Down

0 comments on commit 19f6bfa

Please sign in to comment.