Skip to content

Commit

Permalink
Merge pull request #2598 from pfnet/fuse_args
Browse files Browse the repository at this point in the history
Fix fuse with *args without input_num
  • Loading branch information
okuta committed Apr 19, 2017
2 parents f35678a + 713bc3e commit 7b1aa1a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
9 changes: 5 additions & 4 deletions cupy/core/fusion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
import six
from six.moves import builtins
import string
Expand Down Expand Up @@ -489,8 +488,6 @@ def _get_fix_code(data_type, fixed_type, operation):


def _get_fusion(func, nin, reduce, post_map, identity, input_types, name=None):
if nin is None:
nin = len(inspect.getargspec(func).args)
in_vars = [_FusionVar(i, t) for i, t in enumerate(input_types)]
mem = _FusionMem(in_vars)
in_refs = [_FusionRef(_, mem) for _ in in_vars]
Expand Down Expand Up @@ -611,7 +608,11 @@ def is_cupy_data(a):
types = [_.dtype for _ in args]
key = tuple(types)
if key not in self._memo:
f = _get_fusion(self.func, self.input_num, self.reduce,
if self.input_num is not None:
nin = self.input_num
else:
nin = len(args)
f = _get_fusion(self.func, nin, self.reduce,
self.post_map, self.identity, types)
self._memo[key] = f
f = self._memo[key]
Expand Down
14 changes: 12 additions & 2 deletions tests/cupy_tests/core_tests/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,13 @@ def random_real(self, lower=-1000, higher=1000):
return numpy.random.rand(10, 10) * (higher - lower) + lower

def check(self, func, n, gen, *args):
self._check(func, n, gen, args, True)
self._check(func, n, gen, args, False)

@cupy.fuse(input_num=n)
def _check(self, func, n, gen, args, omit_nin):
nin = n if not omit_nin else None

@cupy.fuse(input_num=nin)
def f(*x):
return func(*x)

Expand All @@ -569,8 +574,13 @@ def f(*x):
numpy.testing.assert_array_almost_equal(n, fc.get())

def check_reduce(self, func, n, reduce_f, gen, *args):
self._check_reduce(func, n, reduce_f, gen, args, True)
self._check_reduce(func, n, reduce_f, gen, args, False)

def _check_reduce(self, func, n, reduce_f, gen, args, omit_nin):
nin = n if not omit_nin else None

@cupy.fuse(input_num=n, reduce=reduce_f)
@cupy.fuse(input_num=nin, reduce=reduce_f)
def f(*x):
return func(*x)

Expand Down

0 comments on commit 7b1aa1a

Please sign in to comment.