Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented fast median filter for CUDA using Wavelet Matrix, a constant-time, HDR-compatible method #3627

Merged
merged 6 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 79 additions & 0 deletions modules/cudafilters/src/cuda/median_filter.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@
#include "opencv2/core/cuda/saturate_cast.hpp"
#include "opencv2/core/cuda/border_interpolate.hpp"


// The CUB library is used for the Median Filter with Wavelet Matrix,
// which has become a standard library since CUDA 11.
#include "wavelet_matrix_feature_support_checks.h"
#ifdef __OPENCV_USE_WAVELET_MATRIX_FOR_MEDIAN_FILTER_CUDA__
#include "wavelet_matrix_multi.cuh"
#include "wavelet_matrix_2d.cuh"
#include "wavelet_matrix_float_supporter.cuh"
#endif


namespace cv { namespace cuda { namespace device
{
__device__ void histogramAddAndSub8(int* H, const int * hist_colAdd,const int * hist_colSub){
Expand Down Expand Up @@ -334,4 +345,72 @@ namespace cv { namespace cuda { namespace device

}}}


#ifdef __OPENCV_USE_WAVELET_MATRIX_FOR_MEDIAN_FILTER_CUDA__
namespace cv { namespace cuda { namespace device
{
using namespace wavelet_matrix_median;

template<int CH_NUM, typename T>
void medianFiltering_wavelet_matrix_gpu(const PtrStepSz<T> src, PtrStepSz<T> dst, int radius,cudaStream_t stream){

constexpr bool is_float = std::is_same<T, float>::value;
constexpr static int WORD_SIZE = 32;
constexpr static int ThW = (std::is_same<T, uint8_t>::value ? 8 : 4);
constexpr static int ThH = (std::is_same<T, uint8_t>::value ? 64 : 256);
using XYIdxT = uint32_t;
using XIdxT = uint16_t;
using WM_T = typename std::conditional<is_float, uint32_t, T>::type;
using MedianResT = typename std::conditional<is_float, T, std::nullptr_t>::type;
using WM2D_IMPL = WaveletMatrix2dCu5C<WM_T, CH_NUM, WaveletMatrixMultiCu4G<XIdxT, 512>, 512, WORD_SIZE>;

CV_Assert(src.cols == dst.cols);
CV_Assert(dst.step % sizeof(T) == 0);

WM2D_IMPL WM_cuda(src.rows, src.cols, is_float, false);
WM_cuda.res_cu = reinterpret_cast<WM_T*>(dst.ptr());

const size_t line_num = src.cols * CH_NUM;
if (is_float) {
WMMedianFloatSupporter::WMMedianFloatSupporter<float, CH_NUM, XYIdxT> float_supporter(src.rows, src.cols);
float_supporter.alloc();
for (int y = 0; y < src.rows; ++y) {
cudaMemcpy(float_supporter.val_in_cu + y * line_num, src.ptr(y), line_num * sizeof(T), cudaMemcpyDeviceToDevice);
}
const auto p = WM_cuda.get_nowcu_and_buf_byte_div32();
float_supporter.sort_and_set((XYIdxT*)p.first, p.second);
WM_cuda.construct(nullptr, stream, true);
WM_cuda.template median2d<ThW, ThH, MedianResT, false>(radius, dst.step / sizeof(T), (MedianResT*)float_supporter.get_res_table(), stream);
} else {
for (int y = 0; y < src.rows; ++y) {
cudaMemcpy(WM_cuda.src_cu + y * line_num, src.ptr(y), line_num * sizeof(T), cudaMemcpyDeviceToDevice);
}
WM_cuda.construct(nullptr, stream);
WM_cuda.template median2d<ThW, ThH, MedianResT, false>(radius, dst.step / sizeof(T), nullptr, stream);
}
WM_cuda.res_cu = nullptr;
if (!stream) {
cudaSafeCall( cudaDeviceSynchronize() );
}
}

template<typename T>
void medianFiltering_wavelet_matrix_gpu(const PtrStepSz<T> src, PtrStepSz<T> dst, int radius, const int num_channels, cudaStream_t stream){
if (num_channels == 1) {
medianFiltering_wavelet_matrix_gpu<1>(src, dst, radius, stream);
} else if (num_channels == 3) {
medianFiltering_wavelet_matrix_gpu<3>(src, dst, radius, stream);
} else if (num_channels == 4) {
medianFiltering_wavelet_matrix_gpu<4>(src, dst, radius, stream);
} else {
CV_Assert(num_channels == 1 || num_channels == 3 || num_channels == 4);
}
}

template void medianFiltering_wavelet_matrix_gpu(const PtrStepSz<uint8_t> src, PtrStepSz<uint8_t> dst, int radius, const int num_channels, cudaStream_t stream);
template void medianFiltering_wavelet_matrix_gpu(const PtrStepSz<uint16_t> src, PtrStepSz<uint16_t> dst, int radius, const int num_channels, cudaStream_t stream);
template void medianFiltering_wavelet_matrix_gpu(const PtrStepSz<float> src, PtrStepSz<float> dst, int radius, const int num_channels, cudaStream_t stream);
}}}
#endif // __OPENCV_USE_WAVELET_MATRIX_FOR_MEDIAN_FILTER_CUDA__

#endif