14 changes: 14 additions & 0 deletions rlpyt/agents/pg/mujoco.py
@@ -4,20 +4,34 @@
RecurrentGaussianPgAgent, AlternatingRecurrentGaussianPgAgent)
from rlpyt.models.pg.mujoco_ff_model import MujocoFfModel
from rlpyt.models.pg.mujoco_lstm_model import MujocoLstmModel
from rlpyt.utils.buffer import buffer_to


class MujocoMixin:
"""
Mixin class defining which environment interface properties
are given to the model.
Also supports observation normalization, including under multi-GPU training.
"""
_ddp = False  # Set to True in data_parallel(), for normalized obs under DDP.

def make_env_to_model_kwargs(self, env_spaces):
"""Extract observation_shape and action_size."""
assert len(env_spaces.action.shape) == 1
return dict(observation_shape=env_spaces.observation.shape,
action_size=env_spaces.action.shape[0])

def update_obs_rms(self, observation):
observation = buffer_to(observation, device=self.device)
if self._ddp:
self.model.module.update_obs_rms(observation)
else:
self.model.update_obs_rms(observation)

def data_parallel(self, *args, **kwargs):
super().data_parallel(*args, **kwargs)
self._ddp = True


class MujocoFfAgent(MujocoMixin, GaussianPgAgent):

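For orientation, a minimal usage sketch of the new option (the kwarg names follow the model signatures below; the sampler/runner wiring is omitted, and this is an illustration rather than code from the PR):

from rlpyt.agents.pg.mujoco import MujocoFfAgent

# Build the agent so its model maintains running observation statistics.
agent = MujocoFfAgent(model_kwargs=dict(
    normalize_observation=True,
    norm_obs_clip=10,
    norm_obs_var_clip=1e-6,
))
# The PG algorithms call agent.update_obs_rms(obs) once per optimization step
# (see the hasattr hooks in a2c.py / ppo.py below).  Under multi-GPU,
# data_parallel() sets _ddp=True so the update reaches self.model.module.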
3 changes: 3 additions & 0 deletions rlpyt/algos/pg/a2c.py
@@ -42,6 +42,9 @@ def optimize_agent(self, itr, samples):
"""
Train the agent on input samples, by one gradient step.
"""
if hasattr(self.agent, "update_obs_rms"):
# NOTE: suboptimal--obs sent to device here and in agent(*inputs).
self.agent.update_obs_rms(samples.env.observation)
self.optimizer.zero_grad()
loss, entropy, perplexity = self.loss(samples)
loss.backward()
2 changes: 2 additions & 0 deletions rlpyt/algos/pg/ppo.py
@@ -70,6 +70,8 @@ def optimize_agent(self, itr, samples):
prev_reward=samples.env.prev_reward,
)
agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
if hasattr(self.agent, "update_obs_rms"):
self.agent.update_obs_rms(agent_inputs.observation)
return_, advantage, valid = self.process_returns(samples)
loss_inputs = LossInputs( # So can slice all.
agent_inputs=agent_inputs,
2 changes: 2 additions & 0 deletions rlpyt/envs/gym.py
@@ -72,6 +72,8 @@ def step(self, action):
else:
info["timeout"] = False
info = info_to_nt(info)
if isinstance(r, float):
r = np.dtype("float32").type(r) # Scalar float32.
return EnvStep(obs, r, d, info)

def reset(self):
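A small illustration of the dtype issue the cast addresses (assuming the sample buffers store rewards as float32, while gym returns a built-in Python float):

import numpy as np

r = 1.0                            # built-in 64-bit float, as returned by many gym envs
r32 = np.dtype("float32").type(r)  # scalar np.float32, matching the buffer dtype
print(type(r32), r32.dtype)        # <class 'numpy.float32'> float32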
2 changes: 1 addition & 1 deletion rlpyt/experiments/configs/mujoco/pg/mujoco_a2c.py
@@ -14,7 +14,7 @@
normalize_advantage=True,
),
env=dict(id="Hopper-v3"),
model=dict(),
model=dict(normalize_observation=False),
optim=dict(),
runner=dict(
n_steps=1e6,
4 changes: 2 additions & 2 deletions rlpyt/experiments/configs/mujoco/pg/mujoco_ppo.py
@@ -16,10 +16,10 @@
ratio_clip=0.2,
normalize_advantage=True,
linear_lr_schedule=True,
bootstrap_timelimit=False,
# bootstrap_timelimit=False,
),
env=dict(id="Hopper-v3"),
model=dict(),
model=dict(normalize_observation=False),
optim=dict(),
runner=dict(
n_steps=1e6,
@@ -22,6 +22,12 @@
keys = [("env", "id")]
variant_levels.append(VariantLevel(keys, values, dir_names))

norm_obs = [True]
values = list(zip(norm_obs))
dir_names = ["TrueNormObs"]
keys = [("model", "normalize_observation")]
variant_levels.append(VariantLevel(keys, values, dir_names))

variants, log_dirs = make_variants(*variant_levels)

run_experiments(
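Roughly, this variant level composes with the env-id level above it; a sketch of one resulting (variant, log_dir) pair, assuming rlpyt's usual cross-product and path-join behavior (the same block is repeated in the launchers that follow):

# Hypothetical output of make_variants(*variant_levels) for id="Hopper-v3":
variant = {"env": {"id": "Hopper-v3"},
           "model": {"normalize_observation": True}}
log_dir = "Hopper-v3/TrueNormObs"
# update_config(config, variant) then flips config["model"]["normalize_observation"] to True.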
@@ -22,6 +22,12 @@
keys = [("env", "id")]
variant_levels.append(VariantLevel(keys, values, dir_names))

norm_obs = [True]
values = list(zip(norm_obs))
dir_names = ["TrueNormObs"]
keys = [("model", "normalize_observation")]
variant_levels.append(VariantLevel(keys, values, dir_names))

variants, log_dirs = make_variants(*variant_levels)

run_experiments(
@@ -7,7 +7,7 @@
affinity_code = encode_affinity(
n_cpu_core=2,
n_gpu=0,
hyperthread_offset=2,
hyperthread_offset=4,
n_socket=1,
cpu_per_run=2,
)
@@ -22,6 +22,12 @@
keys = [("env", "id")]
variant_levels.append(VariantLevel(keys, values, dir_names))

norm_obs = [True]
values = list(zip(norm_obs))
dir_names = ["TrueNormObs"]
keys = [("model", "normalize_observation")]
variant_levels.append(VariantLevel(keys, values, dir_names))

variants, log_dirs = make_variants(*variant_levels)

run_experiments(
@@ -2,8 +2,8 @@
import sys

from rlpyt.utils.launching.affinity import affinity_from_code
from rlpyt.samplers.cpu.parallel_sampler import CpuParallelSampler
from rlpyt.samplers.cpu.collectors import ResetCollector
from rlpyt.samplers.parallel.cpu.sampler import CpuSampler
from rlpyt.samplers.parallel.cpu.collectors import CpuResetCollector
from rlpyt.envs.gym import make as gym_make
from rlpyt.algos.pg.a2c import A2C
from rlpyt.agents.pg.mujoco import MujocoFfAgent
@@ -20,10 +20,10 @@ def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
variant = load_variant(log_dir)
config = update_config(config, variant)

sampler = CpuParallelSampler(
sampler = CpuSampler(
EnvCls=gym_make,
env_kwargs=config["env"],
CollectorCls=ResetCollector,
CollectorCls=CpuResetCollector,
**config["sampler"]
)
algo = A2C(optim_kwargs=config["optim"], **config["algo"])
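The folded remainder presumably builds the agent and runner; a hedged sketch of how config["model"] (which now carries normalize_observation) would reach the model, following the usual rlpyt script pattern (names assumed, not shown in this diff; the real call may pass further agent kwargs):

agent = MujocoFfAgent(model_kwargs=config["model"])  # normalize_observation flows into MujocoFfModel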
20 changes: 20 additions & 0 deletions rlpyt/models/pg/mujoco_ff_model.py
@@ -4,6 +4,7 @@

from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims
from rlpyt.models.mlp import MlpModel
from rlpyt.models.running_mean_std import RunningMeanStdModel


class MujocoFfModel(torch.nn.Module):
@@ -21,6 +22,9 @@ def __init__(
hidden_nonlinearity=torch.nn.Tanh, # Module form.
mu_nonlinearity=torch.nn.Tanh, # Module form.
init_log_std=0.,
normalize_observation=False,
norm_obs_clip=10,
norm_obs_var_clip=1e-6,
):
"""Instantiate neural net modules according to inputs."""
super().__init__()
@@ -44,6 +48,11 @@ def __init__(
nonlinearity=hidden_nonlinearity,
)
self.log_std = torch.nn.Parameter(init_log_std * torch.ones(action_size))
if normalize_observation:
self.obs_rms = RunningMeanStdModel(observation_shape)
self.norm_obs_clip = norm_obs_clip
self.norm_obs_var_clip = norm_obs_var_clip
self.normalize_observation = normalize_observation

def forward(self, observation, prev_action, prev_reward):
"""
@@ -56,6 +65,13 @@ def forward(self, observation, prev_action, prev_reward):
# Infer (presence of) leading dimensions: [T,B], [B], or [].
lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)

if self.normalize_observation:
obs_var = self.obs_rms.var
if self.norm_obs_var_clip is not None:
obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
observation = torch.clamp((observation - self.obs_rms.mean) /
obs_var.sqrt(), -self.norm_obs_clip, self.norm_obs_clip)

obs_flat = observation.view(T * B, -1)
mu = self.mu(obs_flat)
v = self.v(obs_flat).squeeze(-1)
@@ -65,3 +81,7 @@
mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)

return mu, log_std, v

def update_obs_rms(self, observation):
if self.normalize_observation:
self.obs_rms.update(observation)
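In equation form, the normalization applied in forward() above, with the defaults norm_obs_clip=10 and norm_obs_var_clip=1e-6:

\tilde{o} = \operatorname{clip}\left( \frac{o - \mu_{\mathrm{rms}}}{\sqrt{\max(\sigma^2_{\mathrm{rms}},\ 10^{-6})}},\ -10,\ 10 \right)

The statistics are only read here; they advance through update_obs_rms(), which the algorithm calls once per optimization step.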
28 changes: 22 additions & 6 deletions rlpyt/models/pg/mujoco_lstm_model.py
@@ -4,6 +4,7 @@

from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims
from rlpyt.models.mlp import MlpModel
from rlpyt.models.running_mean_std import RunningMeanStdModel
from rlpyt.utils.collections import namedarraytuple

RnnState = namedarraytuple("RnnState", ["h", "c"])
@@ -21,6 +22,9 @@ def __init__(
hidden_sizes=None, # None for default (see below).
lstm_size=256,
nonlinearity=torch.nn.ReLU,
normalize_observation=False,
norm_obs_clip=10,
norm_obs_var_clip=1e-6,
):
super().__init__()
self._obs_n_dim = len(observation_shape)
@@ -36,6 +40,13 @@ def __init__(
mlp_output_size = hidden_sizes[-1] if hidden_sizes else mlp_input_size
self.lstm = torch.nn.LSTM(mlp_output_size + action_size + 1, lstm_size)
self.head = torch.nn.Linear(lstm_size, action_size * 2 + 1)
if normalize_observation:
self.obs_rms = RunningMeanStdModel(observation_shape)
self.norm_obs_clip = norm_obs_clip
self.norm_obs_var_clip = norm_obs_var_clip
self.normalize_observation = normalize_observation



def forward(self, observation, prev_action, prev_reward, init_rnn_state):
"""
@@ -46,15 +57,16 @@ def forward(self, observation, prev_action, prev_reward, init_rnn_state):
not given. Used both in sampler and in algorithm (both via the agent).
Also returns the next RNN state.
"""



"""Feedforward layers process as [T*B,H]. Return same leading dims as
input, can be [T,B], [B], or []."""

# Infer (presence of) leading dimensions: [T,B], [B], or [].
lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_n_dim)

if self.normalize_observation:
obs_var = self.obs_rms.var
if self.norm_obs_var_clip is not None:
obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
observation = torch.clamp((observation - self.obs_rms.mean) /
obs_var.sqrt(), -self.norm_obs_clip, self.norm_obs_clip)

mlp_out = self.mlp(observation.view(T * B, -1))
lstm_input = torch.cat([
mlp_out.view(T, B, -1),
@@ -74,3 +86,7 @@
next_rnn_state = RnnState(h=hn, c=cn)

return mu, log_std, v, next_rnn_state

def update_obs_rms(self, observation):
if self.normalize_observation:
self.obs_rms.update(observation)
45 changes: 45 additions & 0 deletions rlpyt/models/running_mean_std.py
@@ -0,0 +1,45 @@

import torch
import torch.distributed as dist
from rlpyt.utils.tensor import infer_leading_dims


class RunningMeanStdModel(torch.nn.Module):

"""Adapted from OpenAI baselines. Maintains a running estimate of mean
and variance of data along each dimension, accessible in the `mean` and
`var` attributes. Supports multi-GPU training by all-reducing statistics
across GPUs."""

def __init__(self, shape):
super().__init__()
self.register_buffer("mean", torch.zeros(shape))
self.register_buffer("var", torch.ones(shape))
self.register_buffer("count", torch.zeros(()))
self.shape = shape

def update(self, x):
_, T, B, _ = infer_leading_dims(x, len(self.shape))
x = x.view(T * B, *self.shape)
batch_mean = x.mean(dim=0)
batch_var = x.var(dim=0, unbiased=False)
batch_count = T * B
if dist.is_initialized(): # Assume need all-reduce.
mean_var = torch.stack([batch_mean, batch_var])
dist.all_reduce(mean_var)
world_size = dist.get_world_size()
mean_var /= world_size
batch_count *= world_size
batch_mean, batch_var = mean_var[0], mean_var[1]
if self.count == 0:
self.mean[:] = batch_mean
self.var[:] = batch_var
else:
delta = batch_mean - self.mean
total = self.count + batch_count
self.mean[:] = self.mean + delta * batch_count / total
m_a = self.var * self.count
m_b = batch_var * batch_count
M2 = m_a + m_b + delta ** 2 * self.count * batch_count / total
self.var[:] = M2 / total
self.count += batch_count
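The merge in update() is the standard parallel combination of means and variances (Chan et al.), which the code above implements: with running statistics (\mu_a, \sigma_a^2, n_a), batch statistics (\mu_b, \sigma_b^2, n_b), \delta = \mu_b - \mu_a and n = n_a + n_b,

\mu \leftarrow \mu_a + \delta \frac{n_b}{n}, \qquad M_2 = n_a \sigma_a^2 + n_b \sigma_b^2 + \delta^2 \frac{n_a n_b}{n}, \qquad \sigma^2 \leftarrow \frac{M_2}{n}.

Under torch.distributed, per-worker batch means and variances are first averaged and batch_count is scaled by the world size; averaging the variances ignores between-worker mean spread, so the multi-GPU statistics are a close approximation rather than exact.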
5 changes: 3 additions & 2 deletions rlpyt/runners/minibatch_rl.py
@@ -112,8 +112,9 @@ def get_n_itr(self):
# Log at least as often as requested (round down itrs):
log_interval_itrs = max(self.log_interval_steps //
self.itr_batch_size, 1)
# FIXME: To run at least as many steps as requested, round up log interval?
n_itr = math.ceil(self.n_steps / self.log_interval_steps) * log_interval_itrs
n_itr = self.n_steps // self.itr_batch_size
if n_itr % log_interval_itrs > 0: # Keep going to next log itr.
n_itr += log_interval_itrs - (n_itr % log_interval_itrs)
self.log_interval_itrs = log_interval_itrs
self.n_itr = n_itr
logger.log(f"Running {n_itr} iterations of minibatch RL.")
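A worked example of the new rounding, with illustrative sizes (not taken from the PR):

n_steps, itr_batch_size, log_interval_steps = int(1e6), 2048, int(1e5)
log_interval_itrs = max(log_interval_steps // itr_batch_size, 1)  # 48
n_itr = n_steps // itr_batch_size                                 # 488
if n_itr % log_interval_itrs > 0:                                 # 488 % 48 == 8
    n_itr += log_interval_itrs - (n_itr % log_interval_itrs)      # -> 528
# The replaced line gave math.ceil(n_steps / log_interval_steps) * log_interval_itrs
# == 480 itrs here (~983k steps), which can fall well short of n_steps; the new
# rounding tracks itr_batch_size and only rounds up to the next logging iteration.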