Skip to content

Commit

Permalink
Adds Conv2dDescriptor::set_group_count (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed May 6, 2023
1 parent bb2aa65 commit 4f079e2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/cudnn/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ pub unsafe fn set_convolution2d_descriptor(
.result()
}

/// Set See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetConvolutionMathType).
/// See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetConvolutionMathType).
/// # Safety
/// `desc` must NOT have been freed already
pub unsafe fn set_convolution_math_type(
Expand All @@ -248,6 +248,16 @@ pub unsafe fn set_convolution_math_type(
sys::cudnnSetConvolutionMathType(desc, math_type).result()
}

/// See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetConvolutionGroupCount)
/// # Safety
/// `desc` must NOT have been freed already
pub unsafe fn set_convolution_group_count(
desc: sys::cudnnConvolutionDescriptor_t,
group_count: i32,
) -> Result<(), CudnnError> {
sys::cudnnSetConvolutionGroupCount(desc, group_count).result()
}

/// Destroys a descriptor. See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnDestroyConvolutionDescriptor).
/// # Safety
/// `desc` must NOT have been already freed.
Expand Down
6 changes: 6 additions & 0 deletions src/cudnn/safe/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ impl<T> Conv2dDescriptor<T> {
pub fn set_math_type(&mut self, math_type: sys::cudnnMathType_t) -> Result<(), CudnnError> {
unsafe { result::set_convolution_math_type(self.desc, math_type) }
}

/// Set's the group count for this convolution. Refer to [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetConvolutionGroupCount)
/// for more information.
pub fn set_group_count(&mut self, group_count: i32) -> Result<(), CudnnError> {
unsafe { result::set_convolution_group_count(self.desc, group_count) }
}
}

impl<T> Drop for Conv2dDescriptor<T> {
Expand Down

0 comments on commit 4f079e2

Please sign in to comment.