Skip to content

Commit

Permalink
Merge pull request #988 from vilyaair/convolution-groups
Browse files Browse the repository at this point in the history
Change `group` argument name of create_convolution_descriptor
  • Loading branch information
niboshi committed Mar 2, 2018
2 parents 0dc2481 + b08a970 commit 1c31561
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions cupy/cudnn.pyx
Expand Up @@ -114,7 +114,7 @@ cpdef _create_filter_descriptor(


cpdef _create_convolution_descriptor(
desc, pad, stride, dtype, mode, dilation, int group,
desc, pad, stride, dtype, mode, dilation, int groups,
bint use_tensor_core):
cdef int d0, d1, p0, p1, s0, s1
cdef vector.vector[int] c_pad, c_stride, c_dilation
Expand Down Expand Up @@ -149,8 +149,8 @@ cpdef _create_convolution_descriptor(
if use_tensor_core:
math_type = cudnn.CUDNN_TENSOR_OP_MATH
cudnn.setConvolutionMathType(desc, math_type)
if group > 1:
cudnn.setConvolutionGroupCount(desc, group)
if groups > 1:
cudnn.setConvolutionGroupCount(desc, groups)
else:
cudnn.setConvolution2dDescriptor_v4(desc, p0, p1, s0, s1, 1, 1, mode)

Expand Down Expand Up @@ -203,11 +203,12 @@ def create_convolution_descriptor(pad, stride, dtype,
mode=cudnn.CUDNN_CROSS_CORRELATION,
dilation=(1, 1),
use_tensor_core=False,
group=1):
groups=1):
desc = Descriptor(cudnn.createConvolutionDescriptor(),
py_cudnn.destroyConvolutionDescriptor)
_create_convolution_descriptor(
desc.value, pad, stride, dtype, mode, dilation, group, use_tensor_core)
desc.value, pad, stride, dtype, mode, dilation, groups,
use_tensor_core)
return desc


Expand Down

0 comments on commit 1c31561

Please sign in to comment.