fix scaling bug in SoftDeterministicPolicy (#140)
cpnota committed May 16, 2020
1 parent 8628699 commit edddd8b
Showing 2 changed files with 89 additions and 4 deletions.
25 changes: 21 additions & 4 deletions all/policies/soft_deterministic.py
@@ -39,11 +39,28 @@ def _normal(self, outputs):
 
     def _sample(self, normal):
         raw = normal.rsample()
-        action = self._squash(raw)
+        log_prob = self._log_prob(normal, raw)
+        return self._squash(raw), log_prob
+
+    def _log_prob(self, normal, raw):
+        '''
+        Compute the log probability of a raw action after the action is squashed.
+        Both inputs act on the raw underlying distribution.
+        Because tanh_mean does not affect the density, we can ignore it.
+        However, tanh_scale will affect the relative contribution of each component.
+        See Appendix C in the Soft Actor-Critic paper.
+
+        Args:
+            normal (torch.distributions.normal.Normal): The "raw" normal distribution.
+            raw (torch.Tensor): The "raw" action.
+
+        Returns:
+            torch.Tensor: The log probability of the raw action, accounting for the effects of tanh.
+        '''
         log_prob = normal.log_prob(raw)
-        log_prob -= torch.log(1 - action.pow(2) + 1e-6)
-        log_prob = log_prob.sum(1)
-        return action, log_prob
+        log_prob -= torch.log(1 - torch.tanh(raw).pow(2) + 1e-6)
+        log_prob /= self._tanh_scale
+        return log_prob.sum(1)
 
     def _squash(self, x):
         return torch.tanh(x) * self._tanh_scale + self._tanh_mean
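The corrected term follows the change-of-variables formula from Appendix C of the Soft Actor-Critic paper: for a squashed action a = tanh(u), with u drawn from the underlying Gaussian, log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2). The old code applied this correction to the fully squashed action (already scaled by tanh_scale and shifted by tanh_mean); the fix applies it to the bare tanh of the raw sample and accounts for tanh_scale separately. A minimal standalone sketch of the unit-scale, zero-mean case, where old and new expressions agree (names here are illustrative, not the library's API):

import torch
from torch.distributions import Normal

# Squashed-Gaussian log probability for a = tanh(u), per SAC Appendix C.
normal = Normal(torch.zeros(1, 3), torch.ones(1, 3))
raw = normal.rsample()                # u, shape (1, 3), differentiable sample
action = torch.tanh(raw)              # a = tanh(u), bounded in (-1, 1)

# log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2);
# the 1e-6 guards against log(0) when tanh saturates.
log_prob = normal.log_prob(raw)
log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-6)
log_prob = log_prob.sum(1)            # one log probability per batch row

With unit scale and zero mean the bug is invisible; it only surfaces for non-trivial Box bounds, which is exactly what the new test_scaling below exercises.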
68 changes: 68 additions & 0 deletions all/policies/soft_deterministic_test.py
@@ -0,0 +1,68 @@
import unittest
import torch
import numpy as np
import torch_testing as tt
from gym.spaces import Box
from all import nn
from all.environments import State
from all.policies import SoftDeterministicPolicy

STATE_DIM = 2
ACTION_DIM = 3

class TestSoftDeterministic(unittest.TestCase):
def setUp(self):
torch.manual_seed(2)
self.model = nn.Sequential(
nn.Linear0(STATE_DIM, ACTION_DIM * 2)
)
self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.01)
self.space = Box(np.array([-1, -1, -1]), np.array([1, 1, 1]))
self.policy = SoftDeterministicPolicy(
self.model,
self.optimizer,
self.space
)

def test_output_shape(self):
state = State(torch.randn(1, STATE_DIM))
action, log_prob = self.policy(state)
self.assertEqual(action.shape, (1, ACTION_DIM))
self.assertEqual(log_prob.shape, torch.Size([1]))

state = State(torch.randn(5, STATE_DIM))
action, log_prob = self.policy(state)
self.assertEqual(action.shape, (5, ACTION_DIM))
self.assertEqual(log_prob.shape, torch.Size([5]))

def test_step_one(self):
state = State(torch.randn(1, STATE_DIM))
self.policy(state)
self.policy.step()

def test_converge(self):
state = State(torch.randn(1, STATE_DIM))
target = torch.tensor([0.25, 0.5, -0.5])

for _ in range(0, 200):
action, _ = self.policy(state)
loss = ((target - action) ** 2).mean()
loss.backward()
self.policy.step()

self.assertLess(loss, 0.2)

def test_scaling(self):
self.space = Box(np.array([-10, -5, 100]), np.array([10, -2, 200]))
self.policy = SoftDeterministicPolicy(
self.model,
self.optimizer,
self.space
)
state = State(torch.randn(1, STATE_DIM))
action, log_prob = self.policy(state)
tt.assert_allclose(action, torch.tensor([[-3.09055, -4.752777, 188.98222]]))
tt.assert_allclose(log_prob, torch.tensor([-0.397002]), rtol=1e-4)

if __name__ == '__main__':
unittest.main()
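The expected values in test_scaling depend on how SoftDeterministicPolicy derives tanh_mean and tanh_scale from the Box bounds. That constructor logic is not part of this diff; given the _squash implementation above, a plausible reading is the usual affine mapping from tanh's (-1, 1) range onto [low, high], sketched here under that assumption:

import numpy as np
from gym.spaces import Box

# Assumed mapping (constructor not shown in this diff): the squash parameters
# are the center and half-width of the Box, so tanh output covers [low, high].
space = Box(np.array([-10., -5., 100.]), np.array([10., -2., 200.]))
tanh_scale = (space.high - space.low) / 2   # [10.0, 1.5, 50.0]
tanh_mean = (space.high + space.low) / 2    # [0.0, -3.5, 150.0]

Under this mapping, the asserted action [-3.09055, -4.752777, 188.98222] falls inside each dimension's bounds, and dimensions with larger tanh_scale are down-weighted by the new division in _log_prob, matching the docstring's note about the relative contribution of each component.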
