Skip to content

Commit

Permalink
Merge pull request #8078 from asi1024/refactor-convolve1d3o
Browse files Browse the repository at this point in the history
Refactor convolve1d3o
  • Loading branch information
takagi committed Jan 10, 2024
2 parents aceac2c + 30ac16d commit 2912651
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 136 deletions.
39 changes: 0 additions & 39 deletions cupyx/signal/_convolution/_convolution_utils.py

This file was deleted.

106 changes: 21 additions & 85 deletions cupyx/signal/_convolution/_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,97 +23,33 @@
DEALINGS IN THE SOFTWARE.
"""

import cupy as cp
from cupy._core._scalar import get_typename
from cupyx.signal._convolution import _convolution_utils


# CUDA C++ source of the third-order 1-D convolution kernel, kept as a string
# so it can be compiled at runtime (see CONVOLVE1D3O_MODULE).
#
# The kernel `_cupy_convolve1D3O<T>` is templated on the element type.  Each
# thread starts at its global index and advances by the total number of
# launched threads (a grid-stride loop), so any launch size covers all `outW`
# outputs.  For every output index `tid` it accumulates
#
#   inp[tid + kerW - i - 1] * inp[tid + kerH - j - 1] * inp[tid + kerD - k - 1]
#       * kernel[(kerH * i + j) * kerD + k]
#
# over i < kerW, j < kerH, k < kerD.  Only `mode == 0` (labelled "Valid" in the
# source) performs the accumulation; for any other mode the value-initialized
# `temp` is stored unchanged.
CONVOLVE1D3O_KERNEL = """
#include <cupy/complex.cuh>
///////////////////////////////////////////////////////////////////////////////
// CONVOLVE 1D3O //
///////////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void _cupy_convolve1D3O( const T *__restrict__ inp,
const int inpW,
const T *__restrict__ kernel,
const int kerW,
const int kerH,
const int kerD,
const int mode,
T *__restrict__ out,
const int outW ) {
const int tx { static_cast<int>( blockIdx.x * blockDim.x + threadIdx.x ) };
const int stride { static_cast<int>( blockDim.x * gridDim.x ) };
for ( int tid = tx; tid < outW; tid += stride ) {
T temp {};
if ( mode == 0 ) { // Valid
if ( tid >= 0 && tid < inpW ) {
for ( int i = 0; i < kerW; i++ ) {
for ( int j = 0; j < kerH; j++ ) {
for ( int k = 0; k < kerD; k++ ) {
temp += inp[tid + kerW - i - 1] * inp[tid + kerH - j - 1] * inp[tid + kerD - k - 1] * kernel[ (kerH * i + j) * kerD + k ];
}
}
}
}
}
out[tid] = temp;
}
}
""" # NOQA

# Module built from CONVOLVE1D3O_KERNEL with the C++11 flag.  The listed name
# expressions request one template instantiation per supported element type:
# float, double, complex<float> and complex<double>.
CONVOLVE1D3O_MODULE = cp.RawModule(
code=CONVOLVE1D3O_KERNEL, options=('-std=c++11',),
name_expressions=[
'_cupy_convolve1D3O<float>',
'_cupy_convolve1D3O<double>',
'_cupy_convolve1D3O<complex<float>>',
'_cupy_convolve1D3O<complex<double>>',
])
import cupy


def _convolve1d3o_gpu(inp, out, ker, mode):
    """Launch the raw ``_cupy_convolve1D3O`` kernel, writing into ``out``.

    Args:
        inp: Input signal; its length ``inp.shape[0]`` is passed as ``inpW``.
        out: Preallocated output array.  Its dtype selects the template
            instantiation and its length is passed as ``outW``.
        ker: Kernel array whose shape is unpacked into the three kernel
            extent arguments.
        mode: Integer mode flag forwarded to the kernel unchanged.
    """
    # Pick the instantiation that matches the output element type.
    launch = CONVOLVE1D3O_MODULE.get_function(
        f'_cupy_convolve1D3O<{get_typename(out.dtype)}>')

    # CuPy raw kernels are invoked as kernel(grid, block, args): enough
    # blocks of 128 threads to cover every output element.
    n_out = out.shape[0]
    grid = ((n_out + 128 - 1) // 128,)
    block = (128,)

    args = (inp, inp.shape[0], ker) + tuple(ker.shape) + (mode, out, n_out)
    launch(grid, block, args)
# Elementwise kernel for the third-order 1-D convolution.
#
# Parameters: `in1` and `in2` are `raw` arrays (indexed manually in the body)
# sharing the element type T with the output `out`; W, H and D are int32
# extents used to flatten the index into `in2` as (H * x + y) * D + z.
#
# For the output element at loop index `i` (the per-element index that
# ElementwiseKernel exposes to the body, per the CuPy documentation) it stores
#
#   sum over x < W, y < H, z < D of
#       in1[i + W - x - 1] * in1[i + H - y - 1] * in1[i + D - z - 1]
#           * in2[(H * x + y) * D + z]
#
# The body does no bounds checking; the caller must size `out` so that every
# index into `in1` is valid.
_convolve1d3o_kernel = cupy.ElementwiseKernel(
'raw T in1, raw T in2, int32 W, int32 H, int32 D', 'T out',
"""
T temp {};
for (int x = 0; x < W; x++) {
for (int y = 0; y < H; y++) {
for (int z = 0; z < D; z++) {
temp += in1[i + W - x - 1] * in1[i + H - y - 1] *
in1[i + D - z - 1] * in2[(H * x + y) * D + z];
}
}
}
out = temp;
""",
"cupy_convolved3o",
)


def _convolve1d3o(in1, in2, mode):
    """Compute the third-order 1-D convolution of ``in1`` with ``in2``.

    Args:
        in1: Input signal; only ``in1.shape[0]`` is used for sizing, so a
            1-D array is expected.
        in2: Kernel array whose shape is unpacked into the three extents
            ``(W, H, D)`` of the elementwise kernel, i.e. it must be 3-D.
        mode: Must be ``"valid"``; no other mode is implemented here.

    Returns:
        A new array of length ``in1.shape[0] - max(in2.shape) + 1`` whose
        dtype is the promoted type of the two inputs.
    """
    # Internal invariant: only the "valid" mode is supported by this helper.
    assert mode == "valid", "only mode='valid' is supported"

    # The kernel declares one element type T for in1, in2 and out, so bring
    # both inputs to the common result type before launching it.
    # copy=False avoids a copy when an input already has that dtype.
    dtype = cupy.result_type(in1, in2)
    in1 = in1.astype(dtype, copy=False)
    in2 = in2.astype(dtype, copy=False)

    # "valid" output length: every index the kernel reads stays inside in1.
    out_dim = in1.shape[0] - max(in2.shape) + 1
    out = cupy.empty(out_dim, dtype=dtype)
    _convolve1d3o_kernel(in1, in2, *in2.shape, out)
    return out


Expand Down
22 changes: 10 additions & 12 deletions tests/cupyx_tests/signal_tests/convolution_tests/test_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

def _convolve1d3o(in1, in2):
dtype = in1.dtype
ker_shape = in2.shape
out_dim = in1.shape[0] - max(ker_shape) + 1
W, H, D = in2.shape
size = in1.shape[0] - max(W, H, D) + 1
s = numpy.dtype(dtype).itemsize
from numpy.lib.stride_tricks import as_strided
X = numpy.flip(as_strided(in1, (out_dim, *ker_shape), (s, s, 0, 0)), 1)
Y = numpy.flip(as_strided(in1, (out_dim, *ker_shape), (s, 0, s, 0)), 2)
Z = numpy.flip(as_strided(in1, (out_dim, *ker_shape), (s, 0, 0, s)), 3)
return (X * Y * Z * in2).sum(axis=(1, 2, 3))
X = as_strided(in1, (size, W), (s, s))[:, ::-1]
Y = as_strided(in1, (size, H), (s, s))[:, ::-1]
Z = as_strided(in1, (size, D), (s, s))[:, ::-1]
return numpy.einsum('ix,iy,iz,xyz->i', X, Y, Z, in2)


class TestConvolve1d3o:
Expand All @@ -28,10 +28,9 @@ def test_convolve1d3o(self, dtype, xp, shape):
b = testing.shaped_random(shape, xp=xp, dtype=dtype, scale=2) - 1
if xp is cupy:
return signal.convolve1d3o(a, b)
elif xp is numpy:
return _convolve1d3o(a, b)
else:
assert False
assert xp is numpy
return _convolve1d3o(a, b)

@testing.for_complex_dtypes()
@testing.numpy_cupy_allclose(rtol=2e-3)
Expand All @@ -44,7 +43,6 @@ def test_convolve1d3o_complex(self, dtype, xp, shape):
shape, xp=xp, dtype=dtype, scale=2) - (1 + 1j)
if xp is cupy:
return signal.convolve1d3o(a, b)
elif xp is numpy:
return _convolve1d3o(a, b)
else:
assert False
assert xp is numpy
return _convolve1d3o(a, b)

0 comments on commit 2912651

Please sign in to comment.