From 24a78e17250bb6fe8b4a11c7a29764addcdd1fdd Mon Sep 17 00:00:00 2001 From: muupan Date: Mon, 26 Aug 2019 14:27:34 +0900 Subject: [PATCH] Use cupyx.scatter_add instead of cupy.scatter_add as the latter is deprecated by cupy from v4 --- chainerrl/agents/categorical_dqn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chainerrl/agents/categorical_dqn.py b/chainerrl/agents/categorical_dqn.py index 0a0ed0dae..ebc4fa404 100644 --- a/chainerrl/agents/categorical_dqn.py +++ b/chainerrl/agents/categorical_dqn.py @@ -51,7 +51,8 @@ def _apply_categorical_projection(y, y_probs, z): assert u.shape == (batch_size, n_atoms) if cuda.available and xp is cuda.cupy: - scatter_add = xp.scatter_add + import cupyx + scatter_add = cupyx.scatter_add else: scatter_add = np.add.at