Skip to content

Commit

Permalink
Pure-Python implementations of patch extraction (#818)
Browse files Browse the repository at this point in the history
* Add "clean" target to makefile

Removes the .so files from python C-extensions

* Delete the Cython patch extraction code

This is in preparation for adding a Pure Python implementation

* Tiny speedup in interpolation code

Pre-allocate the output array to avoid allocating the output
memory twice - once during sampling and once during concatenation

* Change to pure Python patch extraction

This is slower than the Cython implementation but reduces
the maintenance of the package as a whole by removing the need
for C-extension compilation. It also adds the ability to
interpolate patches with strategies other than nearest neighbour
aka slicing. For example, bilinear interpolation can now be
used. Note this also changes how boundaries are handled as
patches are now always returned and filled with a constant
value rather than being truncated.

For reference, on the masked Takeo image - extracting patches
around the landmarks previously took ~300us on my 2013 Macbook Pro
and now takes ~550us - so close to twice as slow.

Co-authored-by: Patrick Snape <708474+patricksnape@users.noreply.github.com>
  • Loading branch information
jabooth and patricksnape committed Dec 31, 2019
1 parent 6006c32 commit 48237d9
Show file tree
Hide file tree
Showing 9 changed files with 404 additions and 354 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
execute:
python setup.py build_ext --inplace

clean:
find . -name "*.so" -delete
47 changes: 35 additions & 12 deletions menpo/image/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
from menpo.visualize.base import ImageViewer, LandmarkableViewable, Viewable

from .interpolation import scipy_interpolation

try:
from .interpolation import cython_interpolation
except ImportError:
warn('Falling back to scipy interpolation for affine warps')
cython_interpolation = None
from .patches import extract_patches, set_patches
from .patches import extract_patches_with_slice, set_patches, extract_patches_by_sampling

# Cache the greyscale luminosity coefficients as they are invariant.
_greyscale_luminosity_coef = None
Expand Down Expand Up @@ -1389,7 +1390,8 @@ def constrain_points_to_bounds(self, points):
return bounded_points

def extract_patches(self, patch_centers, patch_shape=(16, 16),
sample_offsets=None, as_single_array=True):
sample_offsets=None, as_single_array=True,
order=0, mode='constant', cval=0.0):
r"""
Extract a set of patches from an image. Given a set of patch centers
and a patch size, patches are extracted from within the image, centred
Expand All @@ -1404,6 +1406,13 @@ def extract_patches(self, patch_centers, patch_shape=(16, 16),
Currently only 2D images are supported.
Note that the default is nearest neighbour sampling for the patches
which is achieved via slicing and is much more efficient than using
sampling/interpolation. Note that a significant performance decrease
will be measured if the ``order`` or ``mode`` parameters are modified
from ``order = 0`` and ``mode = 'constant'`` as internally sampling
will be used rather than slicing.
Parameters
----------
patch_centers : :map:`PointCloud`
Expand All @@ -1420,6 +1429,15 @@ def extract_patches(self, patch_centers, patch_shape=(16, 16),
`ndarray`, thus a single numpy array is returned containing each
patch. If ``False``, a `list` of ``n_center * n_offset``
:map:`Image` objects is returned representing each patch.
order : `int`, optional
The order of interpolation. The order has to be in the range [0,5].
See warp_to_shape for more information.
mode : ``{constant, nearest, reflect, wrap}``, optional
Points outside the boundaries of the input are filled according to
the given mode.
cval : `float`, optional
Used in conjunction with mode ``constant``, the value outside the
image boundaries.
Returns
-------
Expand All @@ -1437,16 +1455,21 @@ def extract_patches(self, patch_centers, patch_shape=(16, 16),
raise ValueError('Only two dimensional patch extraction is '
'currently supported.')

if sample_offsets is None:
sample_offsets = np.zeros([1, 2], dtype=np.intp)
if order == 0 and mode == 'constant':
# Fast path using slicing
single_array = extract_patches_with_slice(self.pixels,
patch_centers.points,
patch_shape,
offsets=sample_offsets,
cval=cval)
else:
sample_offsets = np.require(sample_offsets, dtype=np.intp)

patch_centers = np.require(patch_centers.points, dtype=np.float,
requirements=['C'])
single_array = extract_patches(self.pixels, patch_centers,
np.asarray(patch_shape, dtype=np.intp),
sample_offsets)
single_array = extract_patches_by_sampling(self.pixels,
patch_centers.points,
patch_shape,
offsets=sample_offsets,
order=order,
mode=mode,
cval=cval)

if as_single_array:
return single_array
Expand Down Expand Up @@ -1775,7 +1798,7 @@ def warp_to_shape(self, template_shape, transform, warp_landmarks=True,
as self, but with each landmark updated to the warped position.
order : `int`, optional
The order of interpolation. The order has to be in the range [0,5]
========= ====================
Order Interpolation
========= ====================
Expand Down
79 changes: 0 additions & 79 deletions menpo/image/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,82 +710,3 @@ def constrain_to_pointcloud(self, pointcloud, batch_size=None,
for k in range(self.n_dims)])
copy.pixels[slices].flat = point_in_pointcloud(pointcloud, indices)
return copy

def set_patches(self, patches, patch_centers, offset=None,
offset_index=None):
r"""
Set the values of a group of patches into the correct regions in a copy
of this image. Given an array of patches and a set of patch centers,
the patches' values are copied in the regions of the image that are
centred on the coordinates of the given centers.
The patches argument can have any of the two formats that are returned
from the `extract_patches()` and `extract_patches_around_landmarks()`
methods. Specifically it can be:
1. ``(n_center, n_offset, self.n_channels, patch_shape)`` `ndarray`
2. `list` of ``n_center * n_offset`` :map:`Image` objects
Currently only 2D images are supported.
Parameters
----------
patches : `ndarray` or `list`
The values of the patches. It can have any of the two formats that
are returned from the `extract_patches()` and
`extract_patches_around_landmarks()` methods. Specifically, it can
either be an ``(n_center, n_offset, self.n_channels, patch_shape)``
`ndarray` or a `list` of ``n_center * n_offset`` :map:`Image`
objects.
patch_centers : :map:`PointCloud`
The centers to set the patches around.
offset : `list` or `tuple` or ``(1, 2)`` `ndarray` or ``None``, optional
The offset to apply on the patch centers within the image.
If ``None``, then ``(0, 0)`` is used.
offset_index : `int` or ``None``, optional
The offset index within the provided `patches` argument, thus the
index of the second dimension from which to sample. If ``None``,
then ``0`` is used.
Raises
------
ValueError
If image is not 2D
ValueError
If offset does not have shape (1, 2)
Returns
-------
new_image : :map:`BooleanImage`
A new boolean image where the provided patch locations have been
set to the provided values.
"""
# parse arguments
if self.n_dims != 2:
raise ValueError('Only two dimensional patch insertion is '
'currently supported.')
if offset is None:
offset = np.zeros([1, 2], dtype=np.intp)
elif isinstance(offset, tuple) or isinstance(offset, list):
offset = np.asarray([offset])
offset = np.require(offset, dtype=np.intp)
if not offset.shape == (1, 2):
raise ValueError('The offset must be a tuple, a list or a '
'numpy.array with shape (1, 2).')
if offset_index is None:
offset_index = 0

# if patches is a list, convert it to array
if isinstance(patches, list):
patches = _convert_patches_list_to_single_array(
patches, patch_centers.n_points)

copy = self.copy()
# convert pixels to uint8 so that they get recognized by cython
tmp_pixels = copy.pixels.astype(np.uint8)
# convert patches to uint8 as well and set them to pixels
set_patches(patches.astype(np.uint8), tmp_pixels, patch_centers.points,
offset, offset_index)
# convert pixels back to bool
copy.pixels = tmp_pixels.astype(np.bool)
return copy
20 changes: 12 additions & 8 deletions menpo/image/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np

map_coordinates = None # expensive, from scipy.ndimage
from menpo.transform import Homogeneous

Expand Down Expand Up @@ -38,24 +39,27 @@ def scipy_interpolation(pixels, points_to_sample, mode='constant', order=1,
global map_coordinates
if map_coordinates is None:
from scipy.ndimage import map_coordinates # expensive
sampled_pixel_values = []
sampled_pixel_values = np.empty((pixels.shape[0], points_to_sample.shape[0]),
dtype=pixels.dtype)

# Loop over every channel in image - we know first axis is always channels
# Note that map_coordinates uses the opposite (dims, points) convention
# to us so we transpose
points_to_sample_t = points_to_sample.T
for i in range(pixels.shape[0]):
sampled_pixel_values.append(map_coordinates(pixels[i],
points_to_sample_t,
mode=mode,
order=order,
cval=cval))
sampled_pixel_values = [v.reshape([1, -1]) for v in sampled_pixel_values]
return np.concatenate(sampled_pixel_values, axis=0)
map_coordinates(pixels[i],
points_to_sample_t,
mode=mode,
order=order,
cval=cval,
output=sampled_pixel_values[i])
return sampled_pixel_values


try:
from menpo.external.skimage._warps_cy import _warp_fast


def cython_interpolation(pixels, template_shape, h_transform, mode='constant',
order=1, cval=0.):
r"""
Expand Down

0 comments on commit 48237d9

Please sign in to comment.