Skip to content

Commit

Permalink
Merge 3a29a98 into 590647a
Browse files Browse the repository at this point in the history
  • Loading branch information
muupan committed Sep 28, 2018
2 parents 590647a + 3a29a98 commit 2be1172
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
2 changes: 2 additions & 0 deletions chainerrl/agents/categorical_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def _apply_categorical_projection(y, y_probs, z):
# bj: (batch_size, n_atoms)
bj = (y - v_min) / delta_z
assert bj.shape == (batch_size, n_atoms)
# Avoid the error caused by inexact delta_z
bj = xp.clip(bj, 0, n_atoms - 1)

# l, u: (batch_size, n_atoms)
l, u = xp.floor(bj), xp.ceil(bj)
Expand Down
28 changes: 28 additions & 0 deletions tests/agents_tests/test_categorical_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,41 @@ def _test(self, xp):
proj = categorical_dqn._apply_categorical_projection(y, y_probs, z)
xp.testing.assert_allclose(proj, proj_gt, atol=1e-5)

def _test_inexact_delta_z(self, xp):
    """Check the categorical projection when the atom spacing is inexact.

    Four atoms spread over [-1, 1] are 2/3 apart, a spacing that cannot be
    represented exactly in float32. The targets used here lie on the two
    ends of the support and on its midpoint, and the projected
    probabilities are compared against hand-computed values.

    Args:
        xp: Array module used to build the inputs and run the comparison
            (the callers in this file pass NumPy or CuPy).
    """
    n_atoms = 4
    v_min, v_max = -1, 1
    # Atom support; consecutive atoms are 2/3 = 0.66666... apart.
    z = xp.linspace(v_min, v_max, num=n_atoms, dtype=np.float32)

    # Target values to be projected onto the support, one row per sample.
    y = xp.asarray(
        [
            [-1, -1, 1, 1],
            [-1, 0, 1, 1],
        ],
        dtype=np.float32,
    )
    # Probability mass attached to each target value above.
    y_probs = xp.asarray(
        [
            [0.5, 0.1, 0.1, 0.3],
            [0.5, 0.2, 0.0, 0.3],
        ],
        dtype=np.float32,
    )
    # Expected mass per atom: targets at -1 / 1 land entirely on the first
    # / last atom, and the target at 0 splits evenly between the two
    # middle atoms.
    expected = xp.asarray(
        [
            [0.6, 0.0, 0.0, 0.4],
            [0.5, 0.1, 0.1, 0.3],
        ],
        dtype=np.float32,
    )

    actual = categorical_dqn._apply_categorical_projection(y, y_probs, z)
    xp.testing.assert_allclose(actual, expected, atol=1e-5)

# Run the shared `_test` check with NumPy as the array module.
def test_cpu(self):
self._test(np)

# Run the shared `_test` check with CuPy (`chainer.cuda.cupy`) as the array
# module; the decorator below marks this test with the `gpu` attribute.
@testing.attr.gpu
def test_gpu(self):
self._test(chainer.cuda.cupy)

# Run the inexact-delta_z projection check with NumPy as the array module.
def test_inexact_delta_z_cpu(self):
self._test_inexact_delta_z(np)

# Run the inexact-delta_z projection check with CuPy (`chainer.cuda.cupy`) as
# the array module; the decorator below marks this test with the `gpu`
# attribute.
@testing.attr.gpu
def test_inexact_delta_z_gpu(self):
self._test_inexact_delta_z(chainer.cuda.cupy)


def make_distrib_ff_q_func(env):
n_atoms = 51
Expand Down

0 comments on commit 2be1172

Please sign in to comment.