Skip to content

Commit

Permalink
Merge pull request #197 from cpnota/identity_feature_network
Browse files Browse the repository at this point in the history
Support identity networks for feature networks
  • Loading branch information
jkterry1 committed Jan 11, 2021
2 parents 263e3b0 + f7f1674 commit 7a8860d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
5 changes: 4 additions & 1 deletion all/approximation/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class Approximation():
gradient to this value in order prevent large updates and
improve stability.
See torch.nn.utils.clip_grad.
device (string, optional): The device that the model is on. If none is passed,
the device will be automatically determined based on model.parameters()
loss_scaling (float, optional): Multiplies the loss by this value before
performing a backwards pass. Useful when used with multi-headed networks
with shared feature layers.
Expand All @@ -54,14 +56,15 @@ def __init__(
optimizer=None,
checkpointer=None,
clip_grad=0,
device=None,
loss_scaling=1,
name='approximation',
scheduler=None,
target=None,
writer=DummyWriter(),
):
self.model = model
self.device = next(model.parameters()).device
self.device = device if device else next(model.parameters()).device
self._target = target or TrivialTarget()
self._scheduler = scheduler
self._target.init(model)
Expand Down
5 changes: 3 additions & 2 deletions all/approximation/feature_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def reinforce(self):
Backward pass of the model.
'''
graphs, grads = self._dequeue()
graphs.backward(grads)
self.step()
if graphs.requires_grad:
graphs.backward(grads)
self.step()

def _enqueue(self, features, out):
self._cache.append(features)
Expand Down
15 changes: 15 additions & 0 deletions all/approximation/feature_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ def assert_state_equal(self, actual, expected):
tt.assert_almost_equal(actual.observation, expected.observation, decimal=2)
tt.assert_equal(actual.mask, expected.mask)

def test_identity_features(self):
model = nn.Sequential(nn.Identity())
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
features = FeatureNetwork(model, optimizer, device='cpu')

# forward pass
x = State({'observation': torch.tensor([1., 2., 3.])})
y = features(x)
tt.assert_equal(y.observation, x.observation)

# backward pass shouldn't raise exception
loss = y.observation.sum()
loss.backward()
features.reinforce()


if __name__ == "__main__":
unittest.main()

0 comments on commit 7a8860d

Please sign in to comment.