Refactor/feature network (#269)
* remove gradient accumulation from feature network

* update feature network tests

* whitespace

* optionally log a loss metric as part of approximation.step()

* update algorithms relying on FeatureNetwork

* update FeatureNetwork docstring

* make loss check more explicit

Co-authored-by: Nota, Christopher <cnota@irobot.com>
cpnota and Nota, Christopher committed Apr 13, 2022
1 parent 7df8e05 commit 509af6b
Showing 7 changed files with 37 additions and 89 deletions.
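Taken together, these diffs replace FeatureNetwork's gradient-accumulation path with the ordinary training flow: each agent sums its losses, runs a single backward() pass through the shared feature network, and calls step() on every approximation, passing the loss so that step() can log it. Below is a self-contained sketch of the equivalent flow in plain PyTorch; the Linear modules, optimizers, and random tensors are stand-ins for illustration, not the library's Approximation objects.

import torch
from torch import nn
from torch.nn.functional import mse_loss

# Stand-in networks: a shared encoder ("features") with two heads ("v" and "policy").
features = nn.Linear(4, 8)
v = nn.Linear(8, 1)
policy = nn.Linear(8, 2)
optimizers = [torch.optim.Adam(m.parameters(), lr=1e-3) for m in (features, v, policy)]

# Stand-in batch data.
states = torch.randn(16, 4)
targets = torch.randn(16, 1)
advantages = torch.randn(16)
actions = torch.randint(0, 2, (16,))

# Forward pass through the shared features, then each head.
phi = features(states)
value_loss = mse_loss(v(phi), targets)
distribution = torch.distributions.Categorical(logits=policy(phi))
policy_loss = -(advantages * distribution.log_prob(actions)).mean()

# The new pattern: sum the losses, one backward pass, then step each optimizer.
loss = value_loss + policy_loss
loss.backward()
for optimizer in optimizers:
    optimizer.step()
    optimizer.zero_grad()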
14 changes: 8 additions & 6 deletions all/agents/a2c.py
@@ -81,15 +81,17 @@ def _train(self, next_states):
policy_gradient_loss = -(distribution.log_prob(actions) * advantages).mean()
entropy_loss = -distribution.entropy().mean()
policy_loss = policy_gradient_loss + self.entropy_loss_scaling * entropy_loss
+ loss = value_loss + policy_loss

# backward pass
- self.v.reinforce(value_loss)
- self.policy.reinforce(policy_loss)
- self.features.reinforce()
+ loss.backward()
+ self.v.step(loss=value_loss)
+ self.policy.step(loss=policy_loss)
+ self.features.step()

- # debugging
- self.writer.add_loss('policy_gradient', policy_gradient_loss.detach())
- self.writer.add_loss('entropy', entropy_loss.detach())
+ # record metrics
+ self.writer.add_scalar('entropy', -entropy_loss)
+ self.writer.add_scalar('normalized_value_error', value_loss / targets.var())

def _make_buffer(self):
return NStepAdvantageBuffer(
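The new normalized_value_error scalar divides the critic's mean squared error by the variance of the return targets, so a value near 1.0 indicates the critic is doing little better than predicting the mean target (roughly 1 - R^2). A small stand-alone illustration with made-up numbers:

import torch
from torch.nn.functional import mse_loss

targets = torch.tensor([1.0, 2.0, 3.0, 4.0])   # stand-in return targets
values = torch.tensor([1.1, 1.9, 3.2, 3.8])    # stand-in critic predictions

# Mirrors the metric logged above: the critic's MSE scaled by the spread of the targets.
normalized_value_error = mse_loss(values, targets) / targets.var()
print(normalized_value_error.item())  # ~0.015, i.e. the critic tracks the targets closely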
12 changes: 7 additions & 5 deletions all/agents/ppo.py
@@ -119,15 +119,17 @@ def _train_minibatch(self, states, actions, pi_0, advantages, targets):
policy_gradient_loss = self._clipped_policy_gradient_loss(pi_0, pi_i, advantages)
entropy_loss = -distribution.entropy().mean()
policy_loss = policy_gradient_loss + self.entropy_loss_scaling * entropy_loss
+ loss = value_loss + policy_loss

# backward pass
- self.v.reinforce(value_loss)
- self.policy.reinforce(policy_loss)
- self.features.reinforce()
+ loss.backward()
+ self.v.step(loss=value_loss)
+ self.policy.step(loss=policy_loss)
+ self.features.step()

- # debugging
- self.writer.add_loss('policy_gradient', policy_gradient_loss.detach())
- self.writer.add_loss('entropy', entropy_loss.detach())
+ self.writer.add_scalar('entropy', -entropy_loss)
+ self.writer.add_scalar('normalized_value_error', value_loss / targets.var())

def _clipped_policy_gradient_loss(self, pi_0, pi_i, advantages):
ratios = torch.exp(pi_i - pi_0)
8 changes: 5 additions & 3 deletions all/agents/vac.py
@@ -53,11 +53,13 @@ def _train(self, state, reward):
# compute losses
value_loss = mse_loss(values, targets)
policy_loss = -(advantages * self._distribution.log_prob(self._action)).mean()
+ loss = value_loss + policy_loss

# backward pass
- self.v.reinforce(value_loss)
- self.policy.reinforce(policy_loss)
- self.features.reinforce()
+ loss.backward()
+ self.v.step(loss=value_loss)
+ self.policy.step(loss=policy_loss)
+ self.features.step()


VACTestAgent = A2CTestAgent
8 changes: 5 additions & 3 deletions all/agents/vpg.py
@@ -112,11 +112,13 @@ def _train(self):
# compute losses
value_loss = mse_loss(values, targets)
policy_loss = -(advantages * log_pis).mean()
+ loss = value_loss + policy_loss

# backward pass
- self.v.reinforce(value_loss)
- self.policy.reinforce(policy_loss)
- self.features.reinforce()
+ loss.backward()
+ self.v.step(loss=value_loss)
+ self.policy.step(loss=policy_loss)
+ self.features.step()

# cleanup
self._trajectories = []
12 changes: 8 additions & 4 deletions all/approximation/approximation.py
@@ -129,21 +129,25 @@ def reinforce(self, loss):
self: The current Approximation object
'''
loss = self._loss_scaling * loss
- self._writer.add_loss(self._name, loss.detach())
loss.backward()
- self.step()
+ self.step(loss=loss)
return self

- def step(self):
+ def step(self, loss=None):
'''
- Given that a backward pass has been made, run an optimization step
+ Given that a backward pass has been made, run an optimization step.
Internally, this will perform most of the activities associated with a control loop
in standard machine learning environments, depending on the configuration of the object:
Gradient clipping, learning rate schedules, logging, checkpointing, etc.
+ Args:
+ loss (torch.Tensor, optional): The loss to log for this update step.
Returns:
self: The current Approximation object
'''
+ if loss is not None:
+ self._writer.add_loss(self._name, loss.detach())
self._clip_grad_norm()
self._optimizer.step()
self._optimizer.zero_grad()
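With this change, reinforce() scales the loss, runs backward(), and then delegates to step(loss=loss), so the loss is logged in a single place whichever entry point an agent uses. The sketch below shows both call paths; the import path and constructor arguments (model, optimizer, name=..., device=...) are inferred from the files in this commit and should be treated as assumptions.

import torch
from torch import nn
from torch.nn.functional import mse_loss
from all.approximation.approximation import Approximation  # path assumed from the file layout

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
v = Approximation(model, optimizer, name='v', device='cpu')

# Path 1: reinforce() performs the backward pass and calls step(loss=loss) itself.
loss = mse_loss(v(torch.randn(8, 4)), torch.zeros(8, 1))
v.reinforce(loss)

# Path 2: run the backward pass externally and pass the loss to step() only for logging.
loss = mse_loss(v(torch.randn(8, 4)), torch.zeros(8, 1))
loss.backward()
v.step(loss=loss)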
54 changes: 2 additions & 52 deletions all/approximation/feature_network.py
@@ -4,62 +4,12 @@

class FeatureNetwork(Approximation):
'''
- A special type of Approximation that accumulates gradients before backpropagating them.
- This is useful when features are shared between network heads.
- The __call__ function caches the computation graph and detaches the output.
- Then, various functions approximators may backpropagate to the output.
- The reinforce() function will then backpropagate the accumulated gradients on the output
- through the original computation graph.
+ An Approximation that accepts a State and updates its observation key
+ based on the given model.
'''

def __init__(self, model, optimizer=None, name='feature', **kwargs):
model = FeatureModule(model)
super().__init__(model, optimizer, name=name, **kwargs)
- self._cache = []
- self._out = []

- def __call__(self, states):
- '''
- Run a forward pass of the model and return the detached output.
- Args:
- state (all.environment.State): An environment State
- Returns:
- all.environment.State: An environment State with the computed features
- '''
- features = self.model(states)
- graphs = features.observation
- observation = graphs.detach()
- observation.requires_grad = True
- features['observation'] = observation
- self._enqueue(graphs, observation)
- return features
-
- def reinforce(self):
- '''
- Backward pass of the model.
- '''
- graphs, grads = self._dequeue()
- if graphs.requires_grad:
- graphs.backward(grads)
- self.step()
-
- def _enqueue(self, features, out):
- self._cache.append(features)
- self._out.append(out)
-
- def _dequeue(self):
- graphs = []
- grads = []
- for graph, out in zip(self._cache, self._out):
- if out.grad is not None:
- graphs.append(graph)
- grads.append(out.grad)
- self._cache = []
- self._out = []
- return torch.cat(graphs), torch.cat(grads)


class FeatureModule(torch.nn.Module):
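With the caching removed, FeatureNetwork behaves like any other Approximation: its output stays attached to the computation graph, so a single reinforce(loss) (or an external backward() followed by step()) updates the encoder directly. A minimal sketch mirroring the updated test below; the Linear encoder, SGD optimizer, and random observation are placeholders.

import torch
from torch import nn
from all.core import State
from all.approximation.feature_network import FeatureNetwork

model = nn.Sequential(nn.Linear(2, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
features = FeatureNetwork(model, optimizer, device='cpu')

states = State({'observation': torch.randn(2)})
encoded = features(states)           # observation is replaced by the model's features
loss = encoded.observation.sum()     # stand-in for a downstream value/policy head loss
features.reinforce(loss)             # backward through the graph, then an optimizer step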
18 changes: 2 additions & 16 deletions all/approximation/feature_network_test.py
@@ -5,6 +5,7 @@
from all.core import State
from all.approximation.feature_network import FeatureNetwork


STATE_DIM = 2


@@ -38,8 +38,7 @@ def test_backward(self):
states = self.features(self.states)
loss = torch.tensor(0)
loss = torch.sum(states.observation)
- loss.backward()
- self.features.reinforce()
+ self.features.reinforce(loss)
features = self.features(self.states)
expected = State({
'observation': torch.tensor([
@@ -60,20 +60,6 @@ def assert_state_equal(self, actual, expected):
tt.assert_almost_equal(actual.observation, expected.observation, decimal=2)
tt.assert_equal(actual.mask, expected.mask)

- def test_identity_features(self):
- model = nn.Sequential(nn.Identity())
- features = FeatureNetwork(model, None, device='cpu')
-
- # forward pass
- x = State({'observation': torch.tensor([1., 2., 3.])})
- y = features(x)
- tt.assert_equal(y.observation, x.observation)
-
- # backward pass shouldn't raise exception
- loss = y.observation.sum()
- loss.backward()
- features.reinforce()


if __name__ == "__main__":
unittest.main()
