diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index 6d7319a..90fcbac 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -238,7 +238,7 @@ pub unsafe fn set_convolution2d_descriptor( .result() } -/// Set See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetConvolutionMathType). +/// See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetConvolutionMathType). /// # Safety /// `desc` must NOT have been freed already pub unsafe fn set_convolution_math_type( @@ -248,6 +248,16 @@ pub unsafe fn set_convolution_math_type( sys::cudnnSetConvolutionMathType(desc, math_type).result() } +/// See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetConvolutionGroupCount) +/// # Safety +/// `desc` must NOT have been freed already +pub unsafe fn set_convolution_group_count( + desc: sys::cudnnConvolutionDescriptor_t, + group_count: i32, +) -> Result<(), CudnnError> { + sys::cudnnSetConvolutionGroupCount(desc, group_count).result() +} + /// Destroys a descriptor. See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnDestroyConvolutionDescriptor). /// # Safety /// `desc` must NOT have been already freed. diff --git a/src/cudnn/safe/conv.rs b/src/cudnn/safe/conv.rs index a950b03..113785b 100644 --- a/src/cudnn/safe/conv.rs +++ b/src/cudnn/safe/conv.rs @@ -95,6 +95,12 @@ impl Conv2dDescriptor { pub fn set_math_type(&mut self, math_type: sys::cudnnMathType_t) -> Result<(), CudnnError> { unsafe { result::set_convolution_math_type(self.desc, math_type) } } + + /// Set's the group count for this convolution. Refer to [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetConvolutionGroupCount) + /// for more information. + pub fn set_group_count(&mut self, group_count: i32) -> Result<(), CudnnError> { + unsafe { result::set_convolution_group_count(self.desc, group_count) } + } } impl Drop for Conv2dDescriptor {