Skip to content

Commit

Permalink
Merge 25849f3 into e84ea5f
Browse files Browse the repository at this point in the history
  • Loading branch information
toslunar committed Oct 17, 2017
2 parents e84ea5f + 25849f3 commit 365e2ed
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
10 changes: 6 additions & 4 deletions chainerrl/action_value.py
Expand Up @@ -17,6 +17,8 @@
from chainer import functions as F
import numpy as np

from chainerrl.misc.chainer_compat import matmul_v3


class ActionValue(with_metaclass(ABCMeta, object)):
"""Struct that holds state-fixed Q-functions and its subproducts.
Expand Down Expand Up @@ -146,10 +148,10 @@ def max(self):
def evaluate_actions(self, actions):
u_minus_mu = actions - self.mu
a = - 0.5 * \
F.batch_matmul(F.batch_matmul(
u_minus_mu, self.mat, transa=True), u_minus_mu)
return (F.reshape(a, (self.batch_size,)) +
F.reshape(self.v, (self.batch_size,)))
matmul_v3(matmul_v3(
u_minus_mu[:, None, :], self.mat),
u_minus_mu[:, :, None])[:, 0, 0]
return a + F.reshape(self.v, (self.batch_size,))

def compute_advantage(self, actions):
    """Return the evaluated value of ``actions`` relative to ``self.max``.

    Args:
        actions: Passed through unchanged to ``self.evaluate_actions``.

    Returns:
        ``self.evaluate_actions(actions)`` with ``self.max`` subtracted.
    """
    action_values = self.evaluate_actions(actions)
    return action_values - self.max
Expand Down
24 changes: 24 additions & 0 deletions chainerrl/misc/chainer_compat.py
@@ -0,0 +1,24 @@
from distutils.version import StrictVersion
import pkg_resources

import chainer.functions as F


# Installed Chainer version, parsed with pkg_resources.parse_version.
# StrictVersion is deliberately avoided here: it only understands "aN"/"bN"
# pre-release tags and raises ValueError on release candidates such as
# "3.0.0rc1", which would make importing this module fail.
# NOTE(review): chainer_version is no longer a StrictVersion instance; any
# external code comparing it against a StrictVersion should be updated to use
# pkg_resources.parse_version as well.
chainer_version = pkg_resources.parse_version(
    pkg_resources.get_distribution("chainer").version)

if chainer_version < pkg_resources.parse_version('3.0.0a1'):
    # Chainer's PR #2426 changed the behavior of matmul.
    # Simulate the newer behavior by functions in Chainer v2.
    def matmul_v3(a, b, **kwargs):
        """Multiply ``a`` and ``b`` using the Chainer v2 function that fits.

        Args:
            a: Left operand; must expose ``ndim`` and ``shape``.
            b: Right operand; must expose ``ndim`` and ``shape``.
            **kwargs: Forwarded unchanged to the underlying Chainer function.

        Returns:
            ``F.batch_matmul(a, b, **kwargs)`` when both operands are 3-D,
            ``F.matmul(a, b, **kwargs)`` when both operands are 2-D.

        Raises:
            ValueError: If the operands' ranks are any other combination.
        """
        ranks = (a.ndim, b.ndim)
        if ranks == (3, 3):
            return F.batch_matmul(a, b, **kwargs)
        if ranks == (2, 2):
            return F.matmul(a, b, **kwargs)
        # ValueError is a subclass of Exception, so existing handlers that
        # catch Exception keep working.
        raise ValueError("unsupported shapes: {}, {}".format(
            a.shape, b.shape))
else:
    # From 3.0.0a1 onwards F.matmul is used directly.
    matmul_v3 = F.matmul
5 changes: 3 additions & 2 deletions chainerrl/q_functions/state_q_functions.py
Expand Up @@ -17,6 +17,7 @@
from chainerrl.functions.lower_triangular_matrix import lower_triangular_matrix
from chainerrl.links.mlp import MLP
from chainerrl.links.mlp_bn import MLPBN
from chainerrl.misc.chainer_compat import matmul_v3
from chainerrl.q_function import StateQFunction
from chainerrl.recurrent import RecurrentChainMixin

Expand Down Expand Up @@ -154,7 +155,7 @@ def __call__(self, state):
if hasattr(self, 'mat_non_diag'):
mat_non_diag = self.mat_non_diag(h)
tril = lower_triangular_matrix(mat_diag, mat_non_diag)
mat = F.batch_matmul(tril, tril, transb=True)
mat = matmul_v3(tril, tril, transb=True)
else:
mat = F.expand_dims(mat_diag ** 2, axis=2)
return QuadraticActionValue(
Expand Down Expand Up @@ -213,7 +214,7 @@ def __call__(self, state):
if hasattr(self, 'mat_non_diag'):
mat_non_diag = self.mat_non_diag(h)
tril = lower_triangular_matrix(mat_diag, mat_non_diag)
mat = F.batch_matmul(tril, tril, transb=True)
mat = matmul_v3(tril, tril, transb=True)
else:
mat = F.expand_dims(mat_diag ** 2, axis=2)
return QuadraticActionValue(
Expand Down

0 comments on commit 365e2ed

Please sign in to comment.