Reimplement MDNRNN using new gym. (#253)
Summary:
Using our new gym, test MDNRNN feature importance/sensitivity.
Also, train a DQN to play the POMDP string game with states embedded
by the MDNRNN. This is in preparation for nuking the old gym folder.
Pull Request resolved: #253

Differential Revision: D21385499

Pulled By: kaiwenw

fbshipit-source-id: a4fa462ecdd5352e4cbb7cbb956517fcdf0f1502
kaiwenw authored and facebook-github-bot committed May 6, 2020
1 parent cd71aca commit c5b5666
Showing 48 changed files with 1,292 additions and 1,505 deletions.
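
The summary describes embedding POMDP observations with the MDNRNN before handing them to a DQN. As a rough, self-contained illustration of that idea in plain PyTorch (hypothetical class and variable names, not ReAgent's actual modules), a recurrent memory network can turn the observation/action history into a hidden state that a Q-network then consumes:

```python
# Illustrative sketch only: a recurrent "memory network" embeds the
# observation/action history of a POMDP, and a DQN-style Q-network reads
# the resulting hidden state instead of the raw observation.
import torch
import torch.nn as nn


class MemoryEmbedder(nn.Module):
    """Stand-in for an MDNRNN-style world model, used here only for embedding."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int):
        super().__init__()
        # seq-len first, matching the convention used in the diff below
        self.rnn = nn.LSTM(obs_dim + action_dim, hidden_dim)

    def forward(self, obs_seq, action_seq, memory=None):
        # obs_seq, action_seq: (seq_len, batch, dim)
        out, memory = self.rnn(torch.cat([obs_seq, action_seq], dim=-1), memory)
        return out, memory


class QNetwork(nn.Module):
    def __init__(self, embed_dim: int, num_actions: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, num_actions)
        )

    def forward(self, embedded_state):
        return self.net(embedded_state)


if __name__ == "__main__":
    seq_len, batch, obs_dim, action_dim = 5, 4, 2, 2
    embedder = MemoryEmbedder(obs_dim, action_dim, hidden_dim=16)
    q_net = QNetwork(embed_dim=16, num_actions=2)
    obs = torch.randn(seq_len, batch, obs_dim)
    acts = torch.randn(seq_len, batch, action_dim)
    embedded, _ = embedder(obs, acts)
    q_values = q_net(embedded[-1])  # Q-values from the last hidden state
    print(q_values.shape)  # torch.Size([4, 2])
```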
3 changes: 2 additions & 1 deletion reagent/core/dataclasses.py
@@ -41,8 +41,9 @@
pass


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.setLevel(logging.INFO)

logger.info(f"USE_VANILLA_DATACLASS: {USE_VANILLA_DATACLASS}")
logger.info(f"ARBITRARY_TYPES_ALLOWED: {ARBITRARY_TYPES_ALLOWED}")
119 changes: 46 additions & 73 deletions reagent/evaluation/world_model_evaluator.py
@@ -4,14 +4,11 @@
from typing import Dict, List

import torch
from reagent.models.mdn_rnn import transpose
from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer
from reagent.types import (
ExtraData,
PreprocessedFeatureVector,
PreprocessedMemoryNetworkInput,
PreprocessedStateAction,
PreprocessedTrainingBatch,
)


@@ -25,7 +22,7 @@ def __init__(self, trainer: MDNRNNTrainer, state_dim: int) -> None:
self.trainer = trainer
self.state_dim = state_dim

def evaluate(self, tdp: PreprocessedTrainingBatch) -> Dict:
def evaluate(self, tdp: PreprocessedMemoryNetworkInput) -> Dict[str, float]:
self.trainer.mdnrnn.mdnrnn.eval()
losses = self.trainer.get_loss(tdp, state_dim=self.state_dim, batch_first=True)
detached_losses = {
@@ -65,22 +62,20 @@ def __init__(
self.sorted_action_feature_start_indices = sorted_action_feature_start_indices
self.sorted_state_feature_start_indices = sorted_state_feature_start_indices

def evaluate(self, tdp: PreprocessedTrainingBatch):
def evaluate(self, batch: PreprocessedMemoryNetworkInput):
""" Calculate feature importance: setting each state/action feature to
the mean value and observe loss increase. """
assert isinstance(tdp.training_input, PreprocessedMemoryNetworkInput)

self.trainer.mdnrnn.mdnrnn.eval()

state_features = tdp.training_input.state.float_features
action_features = tdp.training_input.action # type: ignore
batch_size, seq_len, state_dim = state_features.size() # type: ignore
self.trainer.memory_network.mdnrnn.eval()
state_features = batch.state.float_features
action_features = batch.action # type: ignore
seq_len, batch_size, state_dim = state_features.size() # type: ignore
action_dim = action_features.size()[2] # type: ignore
action_feature_num = self.action_feature_num
state_feature_num = self.state_feature_num
feature_importance = torch.zeros(action_feature_num + state_feature_num)

orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True)
orig_losses = self.trainer.get_loss(batch, state_dim=state_dim)
orig_loss = orig_losses["loss"].cpu().detach().item()
del orig_losses

@@ -90,7 +85,7 @@ def evaluate(self, tdp: PreprocessedTrainingBatch):
state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]

for i in range(action_feature_num):
action_features = tdp.training_input.action.reshape( # type: ignore
action_features = batch.action.reshape( # type: ignore
(batch_size * seq_len, action_dim)
).data.clone()

@@ -115,28 +110,24 @@ def evaluate(self, tdp: PreprocessedTrainingBatch):
)

action_features = action_features.reshape( # type: ignore
(batch_size, seq_len, action_dim)
(seq_len, batch_size, action_dim)
) # type: ignore
new_tdp = PreprocessedTrainingBatch(
training_input=PreprocessedMemoryNetworkInput( # type: ignore
state=tdp.training_input.state,
action=action_features,
next_state=tdp.training_input.next_state,
reward=tdp.training_input.reward,
time_diff=torch.ones_like(tdp.training_input.reward).float(),
not_terminal=tdp.training_input.not_terminal, # type: ignore
step=None,
),
extras=ExtraData(),
)
losses = self.trainer.get_loss(
new_tdp, state_dim=state_dim, batch_first=True

new_batch = PreprocessedMemoryNetworkInput(
state=batch.state,
action=action_features,
next_state=batch.next_state,
reward=batch.reward,
time_diff=torch.ones_like(batch.reward).float(),
not_terminal=batch.not_terminal, # type: ignore
step=None,
)
losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
del losses

for i in range(state_feature_num):
state_features = tdp.training_input.state.float_features.reshape( # type: ignore
state_features = batch.state.float_features.reshape( # type: ignore
(batch_size * seq_len, state_dim)
).data.clone()
boundary_start, boundary_end = (
@@ -149,29 +140,24 @@ def evaluate(self, tdp: PreprocessedTrainingBatch):
state_features[:, boundary_start:boundary_end] # type: ignore
)
state_features = state_features.reshape( # type: ignore
(batch_size, seq_len, state_dim)
(seq_len, batch_size, state_dim)
) # type: ignore
new_tdp = PreprocessedTrainingBatch(
training_input=PreprocessedMemoryNetworkInput( # type: ignore
state=PreprocessedFeatureVector(float_features=state_features),
action=tdp.training_input.action, # type: ignore
next_state=tdp.training_input.next_state,
reward=tdp.training_input.reward,
time_diff=torch.ones_like(tdp.training_input.reward).float(),
not_terminal=tdp.training_input.not_terminal, # type: ignore
step=None,
),
extras=ExtraData(),
)
losses = self.trainer.get_loss(
new_tdp, state_dim=state_dim, batch_first=True
new_batch = PreprocessedMemoryNetworkInput(
state=PreprocessedFeatureVector(float_features=state_features),
action=batch.action, # type: ignore
next_state=batch.next_state,
reward=batch.reward,
time_diff=torch.ones_like(batch.reward).float(),
not_terminal=batch.not_terminal, # type: ignore
step=None,
)
losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
feature_importance[i + action_feature_num] = (
losses["loss"].cpu().detach().item() - orig_loss
)
del losses

self.trainer.mdnrnn.mdnrnn.train()
self.trainer.memory_network.mdnrnn.train()
logger.info(
"**** Debug tool feature importance ****: {}".format(feature_importance)
)
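
The FeatureImportanceEvaluator in the hunk above scores each feature by imputing it with its mean value and measuring the resulting loss increase. A minimal sketch of that technique in plain PyTorch (hypothetical names, not ReAgent's trainer API):

```python
# Mean-imputation feature importance: replace one feature at a time with its
# batch mean and record how much a fixed model's loss increases.
import torch


def feature_importance(loss_fn, features: torch.Tensor) -> torch.Tensor:
    """features: (num_samples, num_features); loss_fn returns a scalar tensor."""
    baseline = loss_fn(features).item()
    importance = torch.zeros(features.shape[1])
    means = features.mean(dim=0)
    for i in range(features.shape[1]):
        perturbed = features.clone()
        perturbed[:, i] = means[i]  # impute feature i with its mean value
        importance[i] = loss_fn(perturbed).item() - baseline
    return importance


if __name__ == "__main__":
    torch.manual_seed(0)
    weights = torch.tensor([2.0, 0.0, -1.0])  # feature 1 is irrelevant
    x = torch.randn(256, 3)
    y = x @ weights

    def mse(feats):
        return torch.mean((feats @ weights - y) ** 2)

    # Roughly [4, 0, 1]: importance tracks w_i^2 * Var(x_i) for this linear model.
    print(feature_importance(mse, x))
```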
@@ -207,44 +193,31 @@ def __init__(
self.state_feature_num = state_feature_num
self.sorted_state_feature_start_indices = sorted_state_feature_start_indices

def evaluate(self, tdp: PreprocessedTrainingBatch):
def evaluate(self, batch: PreprocessedMemoryNetworkInput):
""" Calculate state feature sensitivity due to actions:
randomly permuting actions and seeing how much the predicted next
state features deviate. """
mdnrnn_training_input = tdp.training_input
assert isinstance(mdnrnn_training_input, PreprocessedMemoryNetworkInput)
assert isinstance(batch, PreprocessedMemoryNetworkInput)

self.trainer.mdnrnn.mdnrnn.eval()
self.trainer.memory_network.mdnrnn.eval()

batch_size, seq_len, state_dim = (
mdnrnn_training_input.next_state.float_features.size()
)
seq_len, batch_size, state_dim = batch.next_state.float_features.size()
state_feature_num = self.state_feature_num
feature_sensitivity = torch.zeros(state_feature_num)

state, action, next_state, reward, not_terminal = transpose(
mdnrnn_training_input.state.float_features,
mdnrnn_training_input.action,
mdnrnn_training_input.next_state.float_features,
mdnrnn_training_input.reward,
mdnrnn_training_input.not_terminal,
)
mdnrnn_input = PreprocessedStateAction(
state=PreprocessedFeatureVector(float_features=state),
action=PreprocessedFeatureVector(float_features=action),
)
# the input of mdnrnn has seq-len as the first dimension
mdnrnn_output = self.trainer.mdnrnn(mdnrnn_input)
state = batch.state.float_features
action = batch.action
mdnrnn_input = PreprocessedStateAction.from_tensors(state, action)

# the input of world_model has seq-len as the first dimension
mdnrnn_output = self.trainer.memory_network(mdnrnn_input)
predicted_next_state_means = mdnrnn_output.mus

shuffled_mdnrnn_input = PreprocessedStateAction(
state=PreprocessedFeatureVector(float_features=state),
# shuffle the actions
action=PreprocessedFeatureVector(
float_features=action[:, torch.randperm(batch_size), :]
),
# shuffle the actions
shuffled_mdnrnn_input = PreprocessedStateAction.from_tensors(
state, action[:, torch.randperm(batch_size), :]
)
shuffled_mdnrnn_output = self.trainer.mdnrnn(shuffled_mdnrnn_input)
shuffled_mdnrnn_output = self.trainer.memory_network(shuffled_mdnrnn_input)
shuffled_predicted_next_state_means = shuffled_mdnrnn_output.mus

assert (
@@ -274,7 +247,7 @@ def evaluate(self, tdp: PreprocessedTrainingBatch):
)
feature_sensitivity[i] = abs_diff.cpu().detach().item()

self.trainer.mdnrnn.mdnrnn.train()
self.trainer.memory_network.mdnrnn.train()
logger.info(
"**** Debug tool feature sensitivity ****: {}".format(feature_sensitivity)
)
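
The FeatureSensitivityEvaluator above checks how much next-state predictions depend on actions by shuffling the actions across the batch. A simplified sketch of the same idea (hypothetical names, with a toy linear model standing in for the MDNRNN):

```python
# Action-permutation sensitivity: shuffle actions across the batch and measure
# the mean absolute change in the model's next-state prediction per feature.
import torch
import torch.nn as nn


def action_sensitivity(dynamics: nn.Module, state: torch.Tensor, action: torch.Tensor):
    """state: (batch, state_dim); action: (batch, action_dim)."""
    with torch.no_grad():
        pred = dynamics(torch.cat([state, action], dim=-1))
        shuffled = action[torch.randperm(action.shape[0])]  # permute actions
        pred_shuffled = dynamics(torch.cat([state, shuffled], dim=-1))
    return (pred - pred_shuffled).abs().mean(dim=0)


if __name__ == "__main__":
    torch.manual_seed(0)
    state_dim, action_dim, batch = 4, 2, 128
    dynamics = nn.Linear(state_dim + action_dim, state_dim)  # toy next-state model
    sens = action_sensitivity(
        dynamics, torch.randn(batch, state_dim), torch.randn(batch, action_dim)
    )
    print(sens)  # one sensitivity value per predicted state feature
```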
56 changes: 32 additions & 24 deletions reagent/gym/agents/post_step.py
@@ -2,13 +2,14 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import inspect
import logging
from typing import Any, Optional, Union
from typing import Optional, Union

import gym
import numpy as np
import reagent.types as rlt
import torch
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
from reagent.gym.types import PostStep
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
from reagent.training.rl_dataset import RLDataset
@@ -18,14 +19,38 @@
logger = logging.getLogger(__name__)


def add_replay_buffer_post_step(replay_buffer: ReplayBuffer):
"""
Simply add transitions to replay_buffer.
"""

def post_step(
obs: np.ndarray,
actor_output: rlt.ActorOutput,
reward: float,
terminal: bool,
possible_actions_mask: Optional[torch.Tensor],
) -> None:
action = actor_output.action.numpy()
log_prob = actor_output.log_prob.numpy()
if possible_actions_mask is None:
possible_actions_mask = torch.ones_like(actor_output.action).to(torch.bool)
possible_actions_mask = possible_actions_mask.numpy()
replay_buffer.add(
obs, action, reward, terminal, possible_actions_mask, log_prob.item()
)

return post_step


def train_with_replay_buffer_post_step(
replay_buffer: ReplayBuffer,
trainer: Trainer,
training_freq: int,
batch_size: int,
replay_burnin: Optional[int] = None,
trainer_preprocessor=None,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = "cpu",
) -> PostStep:
""" Called in post_step of agent to train based on replay buffer (RB).
Args:
@@ -36,7 +61,7 @@ def train_with_replay_buffer_post_step(
replay_burnin: optional requirement for minimum size of RB before
training begins. (i.e. burn in this many frames)
"""
if device is not None and isinstance(device, str):
if isinstance(device, str):
device = torch.device(device)

_num_steps = 0
@@ -45,27 +70,10 @@
size_req = max(size_req, replay_burnin)

if trainer_preprocessor is None:
sig = inspect.signature(trainer.train)
logger.info(f"Deriving trainer_preprocessor from {sig.parameters}")
# Assuming training_batch is in the first position (excluding self)
assert (
list(sig.parameters.keys())[0] == "training_batch"
), f"{sig.parameters} doesn't have training batch in first position."
training_batch_type = sig.parameters["training_batch"].annotation
assert training_batch_type != inspect.Parameter.empty
if not hasattr(training_batch_type, "from_replay_buffer"):
raise NotImplementedError(
f"{training_batch_type} does not implement from_replay_buffer"
)

def trainer_preprocessor(batch):
retval = training_batch_type.from_replay_buffer(batch)
if device is not None:
retval = retval.to(device)
return retval
trainer_preprocessor = make_replay_buffer_trainer_preprocessor(trainer, device)

def post_step(
obs: Any,
obs: np.ndarray,
actor_output: rlt.ActorOutput,
reward: float,
terminal: bool,
@@ -98,7 +106,7 @@ def log_data_post_step(dataset: RLDataset, mdp_id: str, env: gym.Env) -> PostStep
sequence_number = 0

def post_step(
obs: Any,
obs: np.ndarray,
actor_output: rlt.ActorOutput,
reward: float,
terminal: bool,
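
The post-step helpers above either buffer transitions or buffer them and periodically train from the replay buffer. A condensed sketch of the training variant (hypothetical names, with a plain deque standing in for ReAgent's ReplayBuffer):

```python
# Post-step pattern: store each transition, then run a training step every
# `training_freq` environment steps once the buffer has enough samples.
import random
from collections import deque


def make_train_post_step(trainer, training_freq: int, batch_size: int, burn_in: int):
    buffer = deque(maxlen=100_000)
    step_count = 0

    def post_step(obs, action, reward, terminal) -> None:
        nonlocal step_count
        buffer.append((obs, action, reward, terminal))
        step_count += 1
        if len(buffer) >= max(batch_size, burn_in) and step_count % training_freq == 0:
            batch = random.sample(list(buffer), batch_size)
            trainer.train(batch)  # assumed: trainer accepts a list of transitions

    return post_step
```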
18 changes: 18 additions & 0 deletions reagent/gym/envs/__init__.py
@@ -1,7 +1,25 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from .dynamics.linear_dynamics import LinDynaEnv # noqa
from .env_factory import EnvFactory
from .pomdp.pocman import PocManEnv # noqa
from .pomdp.string_game import StringGameEnv # noqa
from .utils import register_if_not_exists


__all__ = ["EnvFactory"]


######### Register classes below ##########

CUR_MODULE = "reagent.gym.envs"
ENV_CLASSES = [
("Pocman-v0", ".pomdp.pocman:PocManEnv"),
("StringGame-v0", ".pomdp.string_game:StringGameEnv"),
("LinearDynamics-v0", ".dynamics.linear_dynamics:LinDynaEnv"),
]

for env_name, rel_module_path in ENV_CLASSES:
full_module_path = CUR_MODULE + rel_module_path
register_if_not_exists(id=env_name, entry_point=full_module_path)
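
register_if_not_exists is imported from .utils but not shown in this diff. A minimal sketch of such a guard, assuming the older (pre-0.21) gym registry API:

```python
# Hypothetical sketch of a guarded registration helper; the real helper lives
# in reagent.gym.envs.utils and may differ.
import gym
from gym.envs.registration import register


def register_if_not_exists(id: str, entry_point: str) -> None:
    """Register a gym environment unless one with this id is already registered."""
    if id not in gym.envs.registry.env_specs:
        register(id=id, entry_point=entry_point)
```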
@@ -67,8 +67,10 @@ def step(self, action):
# add the negative sign because we actually want to maximize the rewards,
# while an LQR solution minimizes the quadratic cost by convention
reward = -(
state.T.dot(self.Q).dot(state) + action.T.dot(self.R).dot(action)
).squeeze()
(
state.T.dot(self.Q).dot(state) + action.T.dot(self.R).dot(action)
).squeeze()
)
self.step_cnt += 1
terminal = False
if self.step_cnt >= self.max_steps:
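
The reward in the hunk above is the negative of an LQR-style quadratic cost, -(s^T Q s + a^T R a), so a lower cost means a higher reward. A tiny worked example with made-up values for Q, R, the state, and the action:

```python
import numpy as np

Q = np.diag([1.0, 2.0])            # state cost matrix
R = np.diag([0.5])                 # action cost matrix
state = np.array([[1.0], [0.5]])   # column vectors, matching the .T usage above
action = np.array([[0.2]])

reward = -(state.T.dot(Q).dot(state) + action.T.dot(R).dot(action)).squeeze()
print(reward)  # -(1.0*1.0 + 2.0*0.25 + 0.5*0.04) = -1.52
```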
@@ -216,7 +216,7 @@ def __init__(self):
self.observation_space = Box(low=0, high=1, shape=(STATE_DIM,))
self._reward_range = 100
self.step_cnt = 0
self.max_step = self.board["_max_step"]
self._max_episode_steps = self.board["_max_step"]

def seed(self, seed=None):
np.random.seed(seed)
