Feature/batched calculations (#259)
* updated default save freq

* added batched calculation code

* ran formatter

* fixed typo

* Revert "updated default save freq"

This reverts commit 32d5e91.

* removed flatten_for_execution

* fix failing unit test

Co-authored-by: Chris Nota <cpnota@gmail.com>
benblack769 and cpnota committed Aug 24, 2021
1 parent aaa5403 commit 3804586
Showing 4 changed files with 85 additions and 12 deletions.
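This commit adds a compute_batch_size option to PPO and the GAE buffer, plus a StateArray.batch_execute helper that applies a function to a StateArray in fixed-size chunks and concatenates the results, so large no-grad computations (features, values, log probabilities) need not fit in memory at once. A minimal sketch of the tensor-returning path, modeled on the new unit test further down (the all.core import path is assumed):

import torch
from all.core import State, StateArray

# a batch of three states, built the same way as in the new state_test.py test
states = StateArray.array([
    State(torch.zeros((3, 4))),
    State(torch.zeros((3, 4))),
    State(torch.zeros((3, 4))),
])

# apply a function over the batch in chunks of 2; tensor results are
# concatenated back together along the batch dimension
out = states.batch_execute(2, lambda s: s.observation + 1)
print(out.shape)  # torch.Size([3, 3, 4])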
13 changes: 9 additions & 4 deletions all/agents/ppo.py
@@ -24,6 +24,7 @@ class PPO(ParallelAgent):
epochs (int): Number of times to reuse each sample.
lam (float): The Generalized Advantage Estimate (GAE) decay parameter.
minibatches (int): The number of minibatches to split each batch into.
compute_batch_size (int): The batch size to use for computations that do not need backpropagation.
n_envs (int): Number of parallel actors/environments.
n_steps (int): Number of timesteps per rollout. Updates are performed once per rollout.
writer (Writer): Used for logging.
@@ -40,6 +41,7 @@ def __init__(
epsilon=0.2,
lam=0.95,
minibatches=4,
compute_batch_size=256,
n_envs=None,
n_steps=4,
writer=DummyWriter()
@@ -58,6 +60,7 @@ def __init__(
self.epsilon = epsilon
self.lam = lam
self.minibatches = minibatches
self.compute_batch_size = compute_batch_size
self.n_envs = n_envs
self.n_steps = n_steps
# private
@@ -82,9 +85,10 @@ def _train(self, next_states):
states, actions, advantages = self._buffer.advantages(next_states)

# compute target values
features = self.features.no_grad(states)
pi_0 = self.policy.no_grad(features).log_prob(actions)
targets = self.v.no_grad(features) + advantages
features = states.batch_execute(self.compute_batch_size, self.features.no_grad)
features['actions'] = actions
pi_0 = features.batch_execute(self.compute_batch_size, lambda s: self.policy.no_grad(s).log_prob(s['actions']))
targets = features.batch_execute(self.compute_batch_size, self.v.no_grad) + advantages

# train for several epochs
for _ in range(self.epochs):
@@ -139,7 +143,8 @@ def _make_buffer(self):
self.n_steps,
self.n_envs,
discount_factor=self.discount_factor,
lam=self.lam
lam=self.lam,
compute_batch_size=self.compute_batch_size
)


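The updated _train above stores the sampled actions inside the features StateArray before calling batch_execute, so every chunk carries matching state/action pairs. A hypothetical sketch of that pattern, with a toy log-softmax in place of self.policy.no_grad(s).log_prob(...) and an assumed all.core import path:

import torch
from all.core import State, StateArray

# six states with random 3-dimensional observations (hypothetical numbers)
states = StateArray.array([State(torch.randn(3)) for _ in range(6)])

# storing 'actions' in the StateArray means batch_execute slices states and
# actions together
states['actions'] = torch.randint(0, 3, (6,))

def toy_log_prob(s):
    # stand-in for self.policy.no_grad(s).log_prob(s['actions'])
    logits = torch.log_softmax(s.observation, dim=-1)
    return logits.gather(1, s['actions'].unsqueeze(1)).squeeze(1)

log_probs = states.batch_execute(2, toy_log_prob)
print(log_probs.shape)  # torch.Size([6])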
43 changes: 39 additions & 4 deletions all/core/state.py
@@ -358,13 +358,13 @@ def mask(self):
return self['mask']

def __getitem__(self, key):
if isinstance(key, slice):
if isinstance(key, slice) or isinstance(key, int):
shape = self['mask'][key].shape
if len(shape) == 0:
return State({k: v[key] for (k, v) in self.items()}, device=self.device)
return StateArray({k: v[key] for (k, v) in self.items()}, shape, device=self.device)
if isinstance(key, int):
return State({k: v[key] for (k, v) in self.items()}, device=self.device)
if torch.is_tensor(key):
# some things may get los
# some things may get lost
d = {}
shape = self['mask'][key].shape
for (k, v) in self.items():
@@ -387,6 +387,41 @@ def shape(self):
def __len__(self):
return self.shape[0]

@classmethod
def cat(cls, state_array_list, axis=0):
'''Concatenates along the batch dimension'''
if len(state_array_list) == 0:
raise ValueError("cat accepts a non-zero size list of StateArrays")

d = {}
state_size = sum(state_array.shape[axis] for state_array in state_array_list)
new_shape = list(state_array_list[0].shape)
new_shape[axis] = state_size
new_shape = tuple(new_shape)
keys = list(state_array_list[0].keys())
for key in keys:
d[key] = torch.cat([state_array[key] for state_array in state_array_list], axis=axis)
return StateArray(d, new_shape, device=state_array_list[0].device)

def batch_execute(self, minibatch_size, fn):
'''
execute in batches to reduce memory consumption
'''
data = self
batch_size = self.shape[0]
results = []
last = 0
while last < batch_size:
# load the indexes for the minibatch
first = last
last = min(first + minibatch_size, batch_size)
results.append(fn(data[first:last]))

if isinstance(results[0], StateArray):
return StateArray.cat(results)
else:
return torch.cat(results, axis=0)


class MultiagentState(State):
def __init__(self, x, device='cpu', **kwargs):
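The __getitem__ change above folds integer indexing into the slice path: when the indexed mask is 0-dimensional the result is a single State, otherwise a StateArray. A small sketch of both cases (import path assumed):

import torch
from all.core import State, StateArray

states = StateArray.array([State(torch.zeros((3, 4))) for _ in range(5)])

single = states[2]      # 0-dim mask, so a State is returned
window = states[1:4]    # 1-dim mask, so a StateArray of shape (3,)
print(type(single).__name__, type(window).__name__, window.shape)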
28 changes: 28 additions & 0 deletions all/core/state_test.py
@@ -171,6 +171,34 @@ def test_view(self):
self.assertEqual(state.shape, (2, 3))
self.assertEqual(state.observation.shape, (2, 3, 3, 4))

def test_batch_exec(self):
zeros = StateArray.array([
State(torch.zeros((3, 4))),
State(torch.zeros((3, 4))),
State(torch.zeros((3, 4)))
])
ones_state = zeros.batch_execute(2, lambda x: StateArray({'observation': x.observation + 1}, x.shape, x.device))
ones_tensor = zeros.batch_execute(2, lambda x: x.observation + 1)
self.assertEqual(ones_state.shape, (3,))
self.assertTrue(torch.equal(ones_state.observation, torch.ones((3, 3, 4))))
self.assertTrue(torch.equal(ones_tensor, torch.ones((3, 3, 4))))

def test_cat(self):
i1 = StateArray({'observation': torch.zeros((2, 3, 4)), 'reward': torch.ones((2,))}, shape=(2,))
i2 = StateArray({'observation': torch.zeros((1, 3, 4)), 'reward': torch.ones((1,))}, shape=(1,))
cat = StateArray.cat([i1, i2])
self.assertEqual(cat.shape, (3,))
self.assertEqual(cat.observation.shape, (3, 3, 4))
self.assertEqual(cat.reward.shape, (3,))

def test_cat_axis1(self):
i1 = StateArray({'observation': torch.zeros((2, 3, 4)), 'reward': torch.ones((2, 3))}, shape=(2, 3))
i2 = StateArray({'observation': torch.zeros((2, 2, 4)), 'reward': torch.ones((2, 2))}, shape=(2, 2))
cat = StateArray.cat([i1, i2], axis=1)
self.assertEqual(cat.shape, (2, 5))
self.assertEqual(cat.observation.shape, (2, 5, 4))
self.assertEqual(cat.reward.shape, (2, 5))

def test_key_error(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
13 changes: 9 additions & 4 deletions all/memory/generalized_advantage.py
@@ -11,7 +11,8 @@ def __init__(
n_steps,
n_envs,
discount_factor=1,
lam=1
lam=1,
compute_batch_size=256,
):
self.v = v
self.features = features
@@ -20,6 +21,7 @@ def __init__(
self.gamma = discount_factor
self.lam = lam
self._batch_size = self.n_steps * self.n_envs
self.compute_batch_size = compute_batch_size
self._states = []
self._actions = []
self._rewards = []
@@ -41,20 +43,23 @@ def store(self, states, actions, rewards):
else:
raise Exception("Buffer length exceeded: " + str(self.n_steps))

def advantages(self, states):
def advantages(self, next_states):
if len(self) < self._batch_size:
raise Exception("Not enough states received!")

self._states.append(states)
self._states.append(next_states)
states = State.array(self._states[0:self.n_steps + 1])
actions = torch.cat(self._actions[:self.n_steps], dim=0)
rewards = torch.stack(self._rewards[:self.n_steps])
_values = self.v.target(self.features.target(states))

_values = states.flatten().batch_execute(self.compute_batch_size, lambda s: self.v.target(self.features.target(s))).view(states.shape)
values = _values[0:self.n_steps]
next_values = _values[1:]

td_errors = rewards + self.gamma * next_values - values
advantages = self._compute_advantages(td_errors)
self._clear_buffers()

return (
states[0:-1].flatten(),
actions,
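The rewritten advantages() above flattens the (n_steps + 1, n_envs) rollout, evaluates values in compute_batch_size chunks, and views the result back into the rollout shape. A hypothetical sketch of that pattern, with a toy value function in place of self.v.target(self.features.target(...)); it assumes the StateArray constructor fills in default done/mask entries, as the new tests suggest:

import torch
from all.core import StateArray

n_steps, n_envs, obs_dim = 4, 2, 8
rollout = StateArray(
    {'observation': torch.randn(n_steps + 1, n_envs, obs_dim)},
    shape=(n_steps + 1, n_envs),
)

def toy_value(s):
    # one value per state; stands in for v.target(features.target(s))
    return s.observation.sum(dim=-1)

# flatten, evaluate in memory-friendly chunks, then restore the rollout layout
_values = rollout.flatten().batch_execute(3, toy_value).view(rollout.shape)
values, next_values = _values[:n_steps], _values[1:]
print(values.shape, next_values.shape)  # torch.Size([4, 2]) torch.Size([4, 2])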
