
Optimizing replay memory
Summary: Pull Request resolved: #175

Reviewed By: czxttkl

Differential Revision: D17974234

fbshipit-source-id: 3f54b759e669d534bba70ff2ea6d0d332b38c15c
kittipatv authored and facebook-github-bot committed Oct 17, 2019
1 parent adfa6ab commit be985fb
Showing 6 changed files with 207 additions and 62 deletions.
247 changes: 191 additions & 56 deletions ml/rl/test/gym/open_ai_gym_memory_pool.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import dataclasses
import logging
import random
from typing import Optional
@@ -16,25 +17,142 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class MemoryBuffer:
state: torch.Tensor
action: torch.Tensor
reward: torch.Tensor
next_state: torch.Tensor
next_action: torch.Tensor
terminal: torch.Tensor
possible_next_actions: Optional[torch.Tensor]
possible_next_actions_mask: Optional[torch.Tensor]
possible_actions: Optional[torch.Tensor]
possible_actions_mask: Optional[torch.Tensor]
time_diff: torch.Tensor
policy_id: torch.Tensor

@torch.no_grad() # type: ignore
def slice(self, indices):
return MemoryBuffer(
state=self.state[indices],
action=self.action[indices],
reward=self.reward[indices],
next_state=self.next_state[indices],
next_action=self.next_action[indices],
terminal=self.terminal[indices],
possible_next_actions=self.possible_next_actions[indices]
if self.possible_next_actions is not None
else None,
possible_next_actions_mask=self.possible_next_actions_mask[indices]
if self.possible_next_actions_mask is not None
else None,
possible_actions=self.possible_actions[indices]
if self.possible_actions is not None
else None,
possible_actions_mask=self.possible_actions_mask[indices]
if self.possible_actions_mask is not None
else None,
time_diff=self.time_diff[indices],
policy_id=self.policy_id[indices],
)

@torch.no_grad() # type: ignore
def insert_at(
self,
idx: int,
state: torch.Tensor,
action: torch.Tensor,
reward: float,
next_state: torch.Tensor,
next_action: torch.Tensor,
terminal: bool,
possible_next_actions: Optional[torch.Tensor],
possible_next_actions_mask: Optional[torch.Tensor],
time_diff: float,
possible_actions: Optional[torch.Tensor],
possible_actions_mask: Optional[torch.Tensor],
policy_id: int,
):
self.state[idx] = state
self.action[idx] = action
self.reward[idx] = reward
self.next_state[idx] = next_state
self.next_action[idx] = next_action
self.terminal[idx] = terminal
if self.possible_actions is not None:
self.possible_actions[idx] = possible_actions
if self.possible_actions_mask is not None:
self.possible_actions_mask[idx] = possible_actions_mask
if self.possible_next_actions is not None:
self.possible_next_actions[idx] = possible_next_actions
if self.possible_next_actions_mask is not None:
self.possible_next_actions_mask[idx] = possible_next_actions_mask
self.time_diff[idx] = time_diff
self.policy_id[idx] = policy_id

@classmethod
def create(
cls,
max_size: int,
state_dim: int,
action_dim: int,
max_possible_actions: Optional[int],
has_possble_actions: bool,
):
return cls(
state=torch.zeros((max_size, state_dim)),
action=torch.zeros((max_size, action_dim)),
reward=torch.zeros((max_size, 1)),
next_state=torch.zeros((max_size, state_dim)),
next_action=torch.zeros((max_size, action_dim)),
terminal=torch.zeros((max_size, 1), dtype=torch.uint8),
possible_next_actions=torch.zeros(
(max_size, max_possible_actions, action_dim)
)
if has_possble_actions
else None,
possible_next_actions_mask=torch.zeros((max_size, max_possible_actions))
if max_possible_actions
else None,
possible_actions=torch.zeros((max_size, max_possible_actions, action_dim))
if has_possble_actions
else None,
possible_actions_mask=torch.zeros((max_size, max_possible_actions))
if max_possible_actions
else None,
time_diff=torch.zeros((max_size, 1)),
policy_id=torch.zeros((max_size, 1), dtype=torch.long),
)


class OpenAIGymMemoryPool:
def __init__(self, max_replay_memory_size):
def __init__(self, max_replay_memory_size: int):
"""
Creates an OpenAIGymMemoryPool object.
:param max_replay_memory_size: Upper bound on the number of transitions
to store in replay memory.
"""
self.replay_memory = []
self.max_replay_memory_size = max_replay_memory_size
self.memory_num = 0
self.skip_insert_until = self.max_replay_memory_size

# Not initializing in the beginning because we don't know the shapes
self.memory_buffer: Optional[MemoryBuffer] = None

@property
def size(self):
return len(self.replay_memory)
return min(self.memory_num, self.max_replay_memory_size)

@property
def state_dim(self):
assert self.memory_buffer is not None
return self.memory_buffer.state.shape[1]

def shuffle(self):
random.shuffle(self.replay_memory)
@property
def action_dim(self):
assert self.memory_buffer is not None
return self.memory_buffer.action.shape[1]

def sample_memories(self, batch_size, model_type, chunk=None):
"""
@@ -49,72 +167,63 @@ def sample_memories(self, batch_size, model_type, chunk=None):
:param model_type: Model type (discrete, parametric).
:param chunk: Index of chunk of data (for deterministic sampling).
"""
cols = [[], [], [], [], [], [], [], [], [], [], [], []]

if chunk is None:
indices = np.random.randint(0, len(self.replay_memory), size=batch_size)
indices = torch.randint(0, self.size, size=(batch_size,))
else:
start_idx = chunk * batch_size
end_idx = start_idx + batch_size
indices = range(start_idx, end_idx)

for idx in indices:
memory = self.replay_memory[idx]
for col, value in zip(cols, memory):
col.append(value)
memory = self.memory_buffer.slice(indices)

states = stack(cols[0])
next_states = stack(cols[3])
states = memory.state
next_states = memory.next_state

assert states.dim() == 2
assert next_states.dim() == 2

if model_type == ModelType.PYTORCH_PARAMETRIC_DQN.value:
num_possible_actions = len(cols[7][0])
num_possible_actions = memory.possible_actions_mask.shape[1]

actions = stack(cols[1])
next_actions = stack(cols[4])
actions = memory.action
next_actions = memory.next_action

tiled_states = states.repeat(1, num_possible_actions).reshape(
-1, states.shape[1]
)
possible_actions = torch.cat(cols[8])
possible_actions = memory.possible_actions.reshape(-1, actions.shape[1])
possible_actions_state_concat = torch.cat(
(tiled_states, possible_actions), dim=1
)
possible_actions_mask = stack(cols[9])
possible_actions_mask = memory.possible_actions_mask

tiled_next_states = next_states.repeat(1, num_possible_actions).reshape(
-1, next_states.shape[1]
)
possible_next_actions = torch.cat(cols[6])
possible_next_actions = memory.possible_next_actions.reshape(
-1, actions.shape[1]
)
possible_next_actions_state_concat = torch.cat(
(tiled_next_states, possible_next_actions), dim=1
)
possible_next_actions_mask = stack(cols[7])
possible_next_actions_mask = memory.possible_next_actions_mask
else:
possible_actions = None
possible_actions_state_concat = None
possible_next_actions = None
possible_next_actions_state_concat = None
if cols[7] is None or cols[7][0] is None:
possible_next_actions_mask = None
else:
possible_next_actions_mask = stack(cols[7])
if cols[9] is None or cols[9][0] is None:
possible_actions_mask = None
else:
possible_actions_mask = stack(cols[9])
possible_next_actions_mask = memory.possible_next_actions_mask
possible_actions_mask = memory.possible_actions_mask

actions = stack(cols[1])
next_actions = stack(cols[4])
actions = memory.action
next_actions = memory.next_action

assert len(actions.size()) == 2
assert len(next_actions.size()) == 2

rewards = torch.tensor(cols[2], dtype=torch.float32).reshape(-1, 1)
not_terminal = (1 - torch.tensor(cols[5], dtype=torch.int32)).reshape(-1, 1)
time_diffs = torch.tensor(cols[10], dtype=torch.int32).reshape(-1, 1)
rewards = memory.reward
not_terminal = 1 - memory.terminal
time_diffs = memory.time_diff

return TrainingDataPage(
states=states,
@@ -144,32 +253,58 @@ def insert_into_memory(
time_diff: float,
possible_actions: Optional[torch.Tensor],
possible_actions_mask: Optional[torch.Tensor],
policy_id: str,
policy_id: int,
):
"""
Inserts transition into replay memory in such a way that retrieving
transitions uniformly at random will be equivalent to reservoir sampling.
"""
item = (
state,
action,
reward,
next_state,
next_action,
terminal,
possible_next_actions,
possible_next_actions_mask,
possible_actions,
possible_actions_mask,
time_diff,
policy_id,
)

if self.memory_buffer is None:
assert state.shape == next_state.shape
assert len(state.shape) == 1
assert action.shape == next_action.shape
assert len(action.shape) == 1
if possible_actions_mask is not None:
assert possible_next_actions_mask is not None
assert possible_actions_mask.shape == possible_next_actions_mask.shape
assert len(possible_actions_mask.shape) == 1
max_possible_actions = possible_actions_mask.shape[0]
else:
max_possible_actions = None

assert (possible_actions is not None) == (possible_next_actions is not None)

self.memory_buffer = MemoryBuffer.create(
max_size=self.max_replay_memory_size,
state_dim=state.shape[0],
action_dim=action.shape[0],
max_possible_actions=max_possible_actions,
has_possble_actions=possible_actions is not None,
)

insert_idx = None
if self.memory_num < self.max_replay_memory_size:
self.replay_memory.append(item)
elif self.memory_num >= self.skip_insert_until:
p = float(self.max_replay_memory_size) / self.memory_num
self.skip_insert_until += np.random.geometric(p)
rand_index = np.random.randint(self.max_replay_memory_size)
self.replay_memory[rand_index] = item
insert_idx = self.memory_num
else:
rand_idx = torch.randint(0, self.memory_num, size=(1,)).item()
if rand_idx < self.max_replay_memory_size:
insert_idx = rand_idx # type: ignore

if insert_idx is not None:
self.memory_buffer.insert_at(
insert_idx,
state,
action,
reward,
next_state,
next_action,
terminal,
possible_next_actions,
possible_next_actions_mask,
time_diff,
possible_actions,
possible_actions_mask,
policy_id,
)
self.memory_num += 1
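
The insertion scheme above keeps the preallocated MemoryBuffer holding a uniform random sample of every transition seen so far: the first max_replay_memory_size transitions fill the buffer in order, and each later transition overwrites a random slot only if a uniform draw over all transitions seen lands inside the buffer, i.e. reservoir sampling (Algorithm R) over fixed tensors instead of a growing Python list. A minimal standalone sketch of the same idea, using illustrative names only (TinyReservoirBuffer is not part of this diff):

import torch


class TinyReservoirBuffer:
    """Fixed-size buffer whose contents stay a uniform sample of the stream (sketch only)."""

    def __init__(self, max_size: int, state_dim: int):
        self.max_size = max_size
        # Preallocate storage up front, as MemoryBuffer.create does, instead of growing a list
        self.state = torch.zeros((max_size, state_dim))
        self.reward = torch.zeros((max_size, 1))
        self.seen = 0  # transitions offered so far

    def add(self, state: torch.Tensor, reward: float) -> None:
        if self.seen < self.max_size:
            idx = self.seen  # fill phase: store every transition
        else:
            # Keep the new transition with probability max_size / (seen + 1) (Algorithm R)
            r = int(torch.randint(0, self.seen + 1, (1,)).item())
            idx = r if r < self.max_size else None
        if idx is not None:
            self.state[idx] = state
            self.reward[idx] = reward
        self.seen += 1

    def sample(self, batch_size: int):
        # Uniform over the filled slots is uniform over everything seen so far
        indices = torch.randint(0, min(self.seen, self.max_size), (batch_size,))
        return self.state[indices], self.reward[indices]

Because the buffer is always a uniform sample of the stream, drawing indices uniformly from the filled slots, as sample_memories does above, yields an unbiased minibatch; preallocating tensors and slicing them with an index tensor also replaces the per-sample Python tuple handling and stack calls of the old list-based pool.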
6 changes: 3 additions & 3 deletions ml/rl/test/gym/run_gym.py
@@ -123,8 +123,8 @@ def create_replay_buffer(
replay_buffer = OpenAIGymMemoryPool(params.max_replay_memory_size)
if path_to_pickled_transitions:
create_stored_policy_offline_dataset(replay_buffer, path_to_pickled_transitions)
replay_state_dim = replay_buffer.replay_memory[0][0].shape[0]
replay_action_dim = replay_buffer.replay_memory[0][1].shape[0]
replay_state_dim = replay_buffer.state_dim
replay_action_dim = replay_buffer.action_dim
assert replay_state_dim == env.state_dim
assert replay_action_dim == env.action_dim
elif offline_train:
@@ -490,7 +490,7 @@ def train_gym_online_rl(
if (
total_timesteps % train_every_ts == 0
and total_timesteps > train_after_ts
and len(replay_buffer.replay_memory) >= trainer.minibatch_size
and replay_buffer.size >= trainer.minibatch_size
and not (stop_training_after_solved and solved)
):
for _ in range(num_train_batches):
2 changes: 1 addition & 1 deletion ml/rl/test/gym/world_model/mdnrnn_gym.py
@@ -454,7 +454,7 @@ def concat_batch(batch):
dataset.insert(
state=state_embed,
action=torch.tensor(action_batch[i][hidden_idx + 1]), # type: ignore
reward=reward_batch[i][hidden_idx + 1], # type: ignore
reward=float(reward_batch[i][hidden_idx + 1]), # type: ignore
next_state=next_state_embed,
next_action=torch.tensor(
next_action_batch[i][next_hidden_idx + 1] # type: ignore
3 changes: 2 additions & 1 deletion ml/rl/test/gym/world_model/state_embed_gym.py
@@ -220,7 +220,8 @@ def run_gym(
for row in embed_rl_dataset.rows:
replay_buffer.insert_into_memory(**row)

state_mem = torch.cat([m[0] for m in replay_buffer.replay_memory])
assert replay_buffer.memory_buffer is not None
state_mem = replay_buffer.memory_buffer.state
state_min_value = torch.min(state_mem).item()
state_max_value = torch.max(state_mem).item()
state_embed_env = StateEmbedGymEnvironment(
6 changes: 6 additions & 0 deletions ml/rl/workflow/dqn_workflow.py
@@ -13,6 +13,7 @@
from ml.rl.json_serialize import from_json
from ml.rl.parameters import (
DiscreteActionModelParameters,
EvaluationParameters,
NormalizationParameters,
RainbowDQNParameters,
RLParameters,
@@ -117,12 +118,17 @@ def single_process_main(gpu_index, *args):
rl_parameters = from_json(params["rl"], RLParameters)
training_parameters = from_json(params["training"], TrainingParameters)
rainbow_parameters = from_json(params["rainbow"], RainbowDQNParameters)
if "evaluation" in params:
evaluation_parameters = from_json(params["evaluation"], EvaluationParameters)
else:
evaluation_parameters = EvaluationParameters()

model_params = DiscreteActionModelParameters(
actions=action_names,
rl=rl_parameters,
training=training_parameters,
rainbow=rainbow_parameters,
evaluation=evaluation_parameters,
)
state_normalization = BaseWorkflow.read_norm_file(params["state_norm_data_path"])

