Skip to content

Commit

Permalink
correct scaling of log probability in soft policy (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpnota committed Aug 5, 2021
1 parent e8f3f16 commit 6265d5d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion all/policies/soft_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _log_prob(self, normal, raw):
'''
log_prob = normal.log_prob(raw)
log_prob -= torch.log(1 - torch.tanh(raw).pow(2) + 1e-6)
log_prob /= self._tanh_scale
log_prob -= torch.log(self._tanh_scale)
return log_prob.sum(-1)

def _squash(self, x):
Expand Down
24 changes: 18 additions & 6 deletions all/policies/soft_deterministic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,28 @@ def test_converge(self):
self.assertLess(loss, 0.2)

def test_scaling(self):
self.space = Box(np.array([-10, -5, 100]), np.array([10, -2, 200]))
self.policy = SoftDeterministicPolicy(
torch.manual_seed(0)
state = State(torch.randn(1, STATE_DIM))
policy1 = SoftDeterministicPolicy(
self.model,
self.optimizer,
self.space
Box(np.array([-1., -1., -1.]), np.array([1., 1., 1.]))
)
action1, log_prob1 = policy1(state)

# reset seed and sample same thing, but with different scaling
torch.manual_seed(0)
state = State(torch.randn(1, STATE_DIM))
action, log_prob = self.policy(state)
tt.assert_allclose(action, torch.tensor([[-3.09055, -4.752777, 188.98222]]))
tt.assert_allclose(log_prob, torch.tensor([-0.397002]), rtol=1e-4)
policy2 = SoftDeterministicPolicy(
self.model,
self.optimizer,
Box(np.array([-2., -1., -1.]), np.array([2., 1., 1.]))
)
action2, log_prob2 = policy2(state)

# check scaling was correct
tt.assert_allclose(action1 * torch.tensor([2, 1, 1]), action2)
tt.assert_allclose(log_prob1 - np.log(2), log_prob2)


if __name__ == '__main__':
Expand Down

0 comments on commit 6265d5d

Please sign in to comment.