Skip to content

Commit

Permalink
Merge 2cf99d8 into c623de0
Browse files Browse the repository at this point in the history
  • Loading branch information
the-lay committed Jul 5, 2021
2 parents c623de0 + 2cf99d8 commit 7d71cea
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 1 deletion.
197 changes: 197 additions & 0 deletions cupyx/_texture.py
@@ -0,0 +1,197 @@
import cupy

from cupy import _core
from cupy.cuda import texture
from cupy.cuda import runtime


_affine_transform_2d_array_kernel = _core.ElementwiseKernel(
'U texObj, raw float32 m, uint64 width', 'T transformed_image',
'''
float3 pixel = make_float3(
(float)(i / width),
(float)(i % width),
1.0f
);
float x = dot(pixel, make_float3(m[0], m[1], m[2])) + .5f;
float y = dot(pixel, make_float3(m[3], m[4], m[5])) + .5f;
transformed_image = tex2D<T>(texObj, y, x);
''',
'affine_transformation_2d_array',
preamble='''
inline __host__ __device__ float dot(float3 a, float3 b)
{
return a.x * b.x + a.y * b.y + a.z * b.z;
}
''')


_affine_transform_3d_array_kernel = _core.ElementwiseKernel(
'U texObj, raw float32 m, uint64 height, uint64 width',
'T transformed_volume',
'''
float4 voxel = make_float4(
(float)(i / (width * height)),
(float)((i % (width * height)) / width),
(float)((i % (width * height)) % width),
1.0f
);
float x = dot(voxel, make_float4(m[0], m[1], m[2], m[3])) + .5f;
float y = dot(voxel, make_float4(m[4], m[5], m[6], m[7])) + .5f;
float z = dot(voxel, make_float4(m[8], m[9], m[10], m[11])) + .5f;
transformed_volume = tex3D<T>(texObj, z, y, x);
''',
'affine_transformation_3d_array',
preamble='''
inline __host__ __device__ float dot(float4 a, float4 b)
{
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}
''')


def _create_texture_object(data,
address_mode: str,
filter_mode: str,
read_mode: str,
border_color=0):

if cupy.issubdtype(data.dtype, cupy.unsignedinteger):
fmt_kind = runtime.cudaChannelFormatKindUnsigned
elif cupy.issubdtype(data.dtype, cupy.integer):
fmt_kind = runtime.cudaChannelFormatKindSigned
elif cupy.issubdtype(data.dtype, cupy.floating):
fmt_kind = runtime.cudaChannelFormatKindFloat
else:
raise ValueError(f'Unsupported data type {data.dtype}')

if address_mode == 'nearest':
address_mode = runtime.cudaAddressModeClamp
elif address_mode == 'constant':
address_mode = runtime.cudaAddressModeBorder
else:
raise ValueError(
f'Unsupported address mode {address_mode} '
'(supported: constant, nearest)')

if filter_mode == 'nearest':
filter_mode = runtime.cudaFilterModePoint
elif filter_mode == 'linear':
filter_mode = runtime.cudaFilterModeLinear
else:
raise ValueError(
f'Unsupported filter mode {filter_mode} '
f'(supported: nearest, linear)')

if read_mode == 'element_type':
read_mode = runtime.cudaReadModeElementType
elif read_mode == 'normalized_float':
read_mode = runtime.cudaReadModeNormalizedFloat
else:
raise ValueError(
f'Unsupported read mode {read_mode} '
'(supported: element_type, normalized_float)')

texture_fmt = texture.ChannelFormatDescriptor(
data.itemsize * 8, 0, 0, 0, fmt_kind)
# CUDAArray: last dimension is the fastest changing dimension
array = texture.CUDAarray(texture_fmt, *data.shape[::-1])
res_desc = texture.ResourceDescriptor(
runtime.cudaResourceTypeArray, cuArr=array)
# TODO(the-lay): each dimension can have a different addressing mode
# TODO(the-lay): border color/value can be defined for up to 4 channels
tex_desc = texture.TextureDescriptor(
(address_mode, ) * data.ndim, filter_mode, read_mode,
borderColors=(border_color, ))
tex_obj = texture.TextureObject(res_desc, tex_desc)
array.copy_from(data)

return tex_obj


def affine_transformation(data,
transformation_matrix,
output_shape=None,
output=None,
interpolation: str = 'linear',
mode: str = 'constant',
border_value=0):
"""
Apply an affine transformation.
The method uses texture memory and supports only 2D and 3D float32 arrays
without channel dimension.
Args:
data (cupy.ndarray): The input array or texture object.
transformation_matrix (cupy.ndarray): Affine transformation matrix.
Must be a homogeneous and have shape ``(ndim + 1, ndim + 1)``.
output_shape (tuple of ints): Shape of output. If not specified,
the input array shape is used. Default is None.
output (cupy.ndarray or ~cupy.dtype): The array in which to place the
output, or the dtype of the returned array. If not specified,
creates the output array with shape of ``output_shape``. Default is
None.
interpolation (str): Specifies interpolation mode: ``'linear'`` or
``'nearest'``. Default is ``'linear'``.
mode (str): Specifies addressing mode for points outside of the array:
(`'constant'``, ``'nearest'``). Default is ``'constant'``.
border_value: Specifies value to be used for coordinates outside
of the array for ``'constant'`` mode. Default is 0.
Returns:
cupy.ndarray:
The transformed input.
.. seealso:: :func:`cupyx.scipy.ndimage.affine_transform`
"""

ndim = data.ndim
if (ndim < 2) or (ndim > 3):
raise ValueError(
'Texture memory affine transformation is defined only for '
'2D and 3D arrays without channel dimension.')

dtype = data.dtype
if dtype != cupy.float32:
raise ValueError(f'Texture memory affine transformation is available '
f'only for float32 data type (not {dtype})')

if interpolation not in ['linear', 'nearest']:
raise ValueError(
f'Unsupported interpolation {interpolation} '
f'(supported: linear, nearest)')

if transformation_matrix.shape != (ndim + 1, ndim + 1):
raise ValueError('Matrix must be have shape (ndim + 1, ndim + 1)')

texture_object = _create_texture_object(data,
address_mode=mode,
filter_mode=interpolation,
read_mode='element_type',
border_color=border_value)

if ndim == 2:
kernel = _affine_transform_2d_array_kernel
else:
kernel = _affine_transform_3d_array_kernel

if output_shape is None:
output_shape = data.shape

if output is None:
output = cupy.zeros(output_shape, dtype=dtype)
elif isinstance(output, (type, cupy.dtype)):
if output != cupy.float32:
raise ValueError(f'Texture memory affine transformation is '
f'available only for float32 data type (not '
f'{output})')
output = cupy.zeros(output_shape, dtype=output)
elif isinstance(output, cupy.ndarray):
if output.shape != output_shape:
raise ValueError('Output shapes do not match')
else:
raise ValueError('Output must be None, cupy.ndarray or cupy.dtype')

kernel(texture_object, transformation_matrix, *output_shape[1:], output)
return output
22 changes: 21 additions & 1 deletion cupyx/scipy/ndimage/interpolation.py
Expand Up @@ -5,6 +5,7 @@
import numpy

from cupy._core import internal
from cupyx import _texture
from cupyx.scipy.ndimage import _util
from cupyx.scipy.ndimage import _interp_kernels
from cupyx.scipy.ndimage import _spline_prefilter_core
Expand Down Expand Up @@ -309,7 +310,8 @@ def map_coordinates(input, coordinates, output=None, order=3,


def affine_transform(input, matrix, offset=0.0, output_shape=None, output=None,
order=3, mode='constant', cval=0.0, prefilter=True):
order=3, mode='constant', cval=0.0, prefilter=True, *,
texture_memory=False):
"""Apply an affine transformation.
Given an output image pixel index vector ``o``, the pixel value is
Expand Down Expand Up @@ -352,6 +354,14 @@ def affine_transform(input, matrix, offset=0.0, output_shape=None, output=None,
0.0
prefilter (bool): It is not used yet. It just exists for compatibility
with :mod:`scipy.ndimage`.
texture_memory (bool): If True, uses GPU texture memory. Supports only:
- 2D and 3D float32 arrays as input
- ``(ndim + 1, ndim + 1)`` homogeneous float32 transformation
matrix
- ``mode='constant'`` and ``mode='nearest'``
- ``order=0`` (nearest neighbor) and ``order=1`` (linear
interpolation)
Returns:
cupy.ndarray or None:
Expand All @@ -361,6 +371,16 @@ def affine_transform(input, matrix, offset=0.0, output_shape=None, output=None,
.. seealso:: :func:`scipy.ndimage.affine_transform`
"""

if texture_memory:
tm_interp = 'linear' if order > 0 else 'nearest'
return _texture.affine_transformation(data=input,
transformation_matrix=matrix,
output_shape=output_shape,
output=output,
interpolation=tm_interp,
mode=mode,
border_value=cval)

_check_parameter('affine_transform', order, mode)

offset = _util._fix_sequence_arg(offset, input.ndim, 'offset', float)
Expand Down
132 changes: 132 additions & 0 deletions tests/cupyx_tests/scipy_tests/ndimage_tests/test_interpolation.py
Expand Up @@ -272,6 +272,138 @@ def test_invalid_output_dtype(self):
with pytest.raises(RuntimeError):
ndi.affine_transform(x, xp.ones((0, 3)), output=output)

def test_invalid_texture_arguments(self):
aft = cupyx.scipy.ndimage.affine_transform
x = [cupy.ones((8, ) * n, dtype=cupy.float32) for n in range(1, 5)]

# (ndim < 2) and (ndim > 3) must fail
for i in [0, 3]:
with pytest.raises(ValueError):
aft(x[i], cupy.eye(i + 1, dtype=cupy.float32),
texture_memory=True)
# wrong input dtype
for dt in [cupy.float16, cupy.float64, cupy.int32, cupy.int64]:
with pytest.raises(ValueError):
aft(cupy.ones((8, 8), dtype=dt),
cupy.eye(3, dtype=cupy.float32), texture_memory=True)
# wrong matrix shape
for i in range(len(x)):
with pytest.raises(ValueError):
aft(x[i], cupy.eye(i, dtype=cupy.float32),
texture_memory=True)
# wrong output
with pytest.raises(ValueError):
aft(x[2], cupy.eye(3, dtype=cupy.float32), output='wrong',
texture_memory=True)
# wrong mode
for m in ['mirror', 'reflect', 'wrap', 'grid-mirror',
'grid-wrap', 'grid-constant', 'opencv']:
with pytest.raises(ValueError):
aft(x[2], cupy.eye(3, dtype=cupy.float32), mode=m,
texture_memory=True)
# non matching output_shape and output's shape
with pytest.raises(ValueError):
output = cupy.empty((7, 7, 7), dtype=cupy.float32)
aft(x[2], cupy.eye(3, dtype=cupy.float32), output_shape=(8, 8, 8),
output=output, texture_memory=True)
# non matching output_shape and input shape
with pytest.raises(ValueError):
aft(x[2], cupy.eye(3, dtype=cupy.float32), output_shape=(7, 7, 7),
texture_memory=True)


@testing.parameterize(*testing.product({
'output': [None, numpy.float32, 'empty'],
'output_shape': [None, 10],
'order': [0, 1],
'mode': ['constant', 'nearest'],
'shape': [(100, 100), (10, 20), (10, 10, 10), (10, 20, 30)],
'theta': [0, 90, 180, 270]
}))
@testing.gpu
@testing.with_requires('scipy')
class TestAffineTransformTextureMemory:

_multiprocess_can_split = True

def _2d_rotation_matrix(self, theta, rotation_center):
c, s = scipy.special.cosdg(theta), scipy.special.sindg(theta)
m = numpy.array([
[1, 0, rotation_center[0]],
[0, 1, rotation_center[1]],
[0, 0, 1]
], numpy.float32)
m = numpy.dot(m, numpy.array([
[c, -s, 0],
[s, c, 0],
[0, 0, 1]
], numpy.float32))
m = numpy.dot(m, numpy.array([
[1, 0, -rotation_center[0]],
[0, 1, -rotation_center[1]],
[0, 0, 1]
], numpy.float32))
return m

def _3d_rotation_matrix(self, theta, rotation_center):
c, s = scipy.special.cosdg(theta), scipy.special.sindg(theta)
m = numpy.array([
[1, 0, 0, rotation_center[0]],
[0, 1, 0, rotation_center[1]],
[0, 0, 1, rotation_center[2]],
[0, 0, 0, 1]
], numpy.float32)
m = numpy.dot(m, numpy.array([
[1, 0, 0, 0],
[0, c, -s, 0],
[0, s, c, 0],
[0, 0, 0, 1]
], numpy.float32))
m = numpy.dot(m, numpy.array([
[1, 0, 0, -rotation_center[0]],
[0, 1, 0, -rotation_center[1]],
[0, 0, 1, -rotation_center[2]],
[0, 0, 0, 1]
], numpy.float32))
return m

@testing.numpy_cupy_allclose(atol=0.1, scipy_name='scp')
def test_affine_transform_texture_memory(self, xp, scp):
a = xp.ones(self.shape, dtype=xp.float32)
center = numpy.divide(numpy.subtract(self.shape, 1), 2)

if len(self.shape) == 2:
matrix = self._2d_rotation_matrix(self.theta, center)
elif len(self.shape) == 3:
matrix = self._3d_rotation_matrix(self.theta, center)
else:
return pytest.xfail('Unsupported shape')

if self.output == 'empty':
output = xp.empty(self.shape, dtype=xp.float32)
if self.output_shape:
return pytest.skip('This combination is tested in '
'test_invalid_texture_arguments')
else:
output = self.output

if self.output_shape:
output_shape = (self.output_shape, ) * len(self.shape)
else:
output_shape = self.output_shape

if xp == cupy:
m = cupyx.scipy.ndimage.affine_transform
matrix = cupy.array(matrix)
return m(a, matrix, output_shape=output_shape,
output=output, order=self.order,
mode=self.mode, texture_memory=True)
else:
m = scp.ndimage.affine_transform
return m(a, matrix, output_shape=output_shape,
output=output, order=self.order,
mode=self.mode)


@testing.gpu
@testing.with_requires('opencv-python')
Expand Down

0 comments on commit 7d71cea

Please sign in to comment.