Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract patches types #602

Merged
merged 4 commits into from Jul 7, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions menpo/cy_utils.pxd
@@ -0,0 +1,6 @@
cimport cython
cimport numpy as np


cdef inline np.dtype dtype_from_memoryview(cython.view.memoryview arr):
return np.dtype(arr.view.format)
105 changes: 44 additions & 61 deletions menpo/image/extract_patches.pyx
Expand Up @@ -3,11 +3,24 @@
import numpy as np
cimport numpy as np
cimport cython
from ..cy_utils cimport dtype_from_memoryview


ctypedef fused IMAGE_TYPES:
float
double
np.uint8_t


ctypedef fused CENTRE_TYPES:
float
double


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void calc_augmented_centers(double[:, :] centres, Py_ssize_t[:, :] offsets,
cdef void calc_augmented_centers(CENTRE_TYPES[:, :] centres,
Py_ssize_t[:, :] offsets,
Py_ssize_t[:, :] augmented_centers):
cdef Py_ssize_t total_index = 0, i = 0, j = 0

Expand Down Expand Up @@ -76,39 +89,15 @@ cdef void calc_slices(Py_ssize_t[:, :] centres,

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void slice_image(double[:, :, :] image,
Py_ssize_t n_channels,
Py_ssize_t n_centres,
Py_ssize_t n_offsets,
Py_ssize_t[:, :] ext_s_min,
Py_ssize_t[:, :] ext_s_max,
Py_ssize_t[:, :] ins_s_min,
Py_ssize_t[:, :] ins_s_max,
double[:, :, :, :, :] patches):
cdef Py_ssize_t total_index = 0, i = 0, j = 0

for i in range(n_centres):
for j in range(n_offsets):
patches[i,
j,
:,
ins_s_min[total_index, 0]:ins_s_max[total_index, 0],
ins_s_min[total_index, 1]:ins_s_max[total_index, 1]
] = \
image[:,
ext_s_min[total_index, 0]:ext_s_max[total_index, 0],
ext_s_min[total_index, 1]:ext_s_max[total_index, 1]]
total_index += 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef extract_patches(double[:, :, :] image, double[:, :] centres,
cpdef extract_patches(IMAGE_TYPES[:, :, :] image,
CENTRE_TYPES[:, :] centres,
Py_ssize_t[:] patch_shape, Py_ssize_t[:, :] offsets):
dtype = dtype_from_memoryview(image)
cdef:
Py_ssize_t n_centres = centres.shape[0]
Py_ssize_t n_offsets = offsets.shape[0]
Py_ssize_t n_augmented_centres = n_centres * n_offsets
object extents_size = [n_augmented_centres, 2]

Py_ssize_t half_patch_shape0 = patch_shape[0] / 2
Py_ssize_t half_patch_shape1 = patch_shape[1] / 2
Expand All @@ -120,45 +109,39 @@ cpdef extract_patches(double[:, :, :] image, double[:, :] centres,
Py_ssize_t image_shape1 = image.shape[2]
Py_ssize_t n_channels = image.shape[0]

Py_ssize_t total_index = 0, i = 0, j = 0

# Although it is faster to use malloc in this case, the change in syntax
# and the mental overhead of handling freeing memory is not considered
# worth it for these buffers. From simple tests it seems you only begin
# to see a performance difference when you have
# n_augmented_centres >~ 5000
Py_ssize_t[:, :] augmented_centers = np.empty([n_augmented_centres, 2], dtype=np.intp)
Py_ssize_t[:, :] ext_s_max = np.empty([n_augmented_centres, 2], dtype=np.intp)
Py_ssize_t[:, :] ext_s_min = np.empty([n_augmented_centres, 2], dtype=np.intp)
Py_ssize_t[:, :] ins_s_max = np.empty([n_augmented_centres, 2], dtype=np.intp)
Py_ssize_t[:, :] ins_s_min = np.empty([n_augmented_centres, 2], dtype=np.intp)

np.ndarray[double, ndim=5] patches = np.zeros([n_centres,
n_offsets,
n_channels,
patch_shape0,
patch_shape1])
Py_ssize_t[:, :] augmented_centers = np.empty(extents_size,
dtype=np.intp)
Py_ssize_t[:, :] ext_s_max = np.empty(extents_size, dtype=np.intp)
Py_ssize_t[:, :] ext_s_min = np.empty(extents_size, dtype=np.intp)
Py_ssize_t[:, :] ins_s_max = np.empty(extents_size, dtype=np.intp)
Py_ssize_t[:, :] ins_s_min = np.empty(extents_size, dtype=np.intp)

np.ndarray[IMAGE_TYPES, ndim=5] patches = np.zeros(
[n_centres, n_offsets, n_channels, patch_shape0, patch_shape1],
dtype=dtype)

calc_augmented_centers(centres, offsets, augmented_centers)
calc_slices(augmented_centers,
image_shape0,
image_shape1,
patch_shape0,
patch_shape1,
half_patch_shape0,
half_patch_shape1,
add_to_patch0,
add_to_patch1,
ext_s_min,
ext_s_max,
ins_s_min,
calc_slices(augmented_centers, image_shape0, image_shape1, patch_shape0,
patch_shape1, half_patch_shape0, half_patch_shape1,
add_to_patch0, add_to_patch1, ext_s_min, ext_s_max, ins_s_min,
ins_s_max)
slice_image(image,
n_channels,
n_centres,
n_offsets,
ext_s_min,
ext_s_max,
ins_s_min,
ins_s_max,
patches)

for i in range(n_centres):
for j in range(n_offsets):
patches[i, j, :,
ins_s_min[total_index, 0]:ins_s_max[total_index, 0],
ins_s_min[total_index, 1]:ins_s_max[total_index, 1]
] = \
image[:,
ext_s_min[total_index, 0]:ext_s_max[total_index, 0],
ext_s_min[total_index, 1]:ext_s_max[total_index, 1]]
total_index += 1

return patches
35 changes: 35 additions & 0 deletions menpo/image/test/image_extract_patches_test.py
@@ -1,10 +1,45 @@
import numpy as np
from nose.tools import assert_equals

import menpo.io as mio
from menpo.landmark import labeller, ibug_face_68
from menpo.shape import PointCloud


def test_double_type():
image = mio.import_builtin_asset('breakingbad.jpg')
patch_shape = (16, 16)
patches = image.extract_patches(image.landmarks['PTS'].lms,
patch_size=patch_shape)
assert(patches[0].pixels.dtype == np.float64)


def test_float_type():
image = mio.import_builtin_asset('breakingbad.jpg')
image.pixels = image.pixels.astype(np.float32)
patch_shape = (16, 16)
patches = image.extract_patches(image.landmarks['PTS'].lms,
patch_size=patch_shape)
assert(patches[0].pixels.dtype == np.float32)


def test_uint8_type():
image = mio.import_builtin_asset('breakingbad.jpg', normalise=False)
patch_shape = (16, 16)
patches = image.extract_patches(image.landmarks['PTS'].lms,
patch_size=patch_shape)
assert(patches[0].pixels.dtype == np.uint8)


def test_uint8_type_single_array():
image = mio.import_builtin_asset('breakingbad.jpg', normalise=False)
patch_shape = (16, 16)
patches = image.extract_patches(image.landmarks['PTS'].lms,
patch_size=patch_shape,
as_single_array=True)
assert(patches.dtype == np.uint8)


def test_squared_even_patches():
image = mio.import_builtin_asset('breakingbad.jpg')
patch_shape = (16, 16)
Expand Down
17 changes: 9 additions & 8 deletions menpo/io/input/base.py
Expand Up @@ -286,7 +286,7 @@ def import_pickles(pattern, max_pickles=None, verbose=False):
yield asset


def _import_builtin_asset(asset_name):
def _import_builtin_asset(asset_name, **kwargs):
r"""Single builtin asset (landmark or image) importer.

Imports the relevant builtin asset from the ``./data`` directory that
Expand All @@ -308,9 +308,11 @@ def _import_builtin_asset(asset_name):
# importing them both separately.
try:
return _import(asset_path, image_types,
landmark_ext_map=image_landmark_types)
landmark_ext_map=image_landmark_types,
importer_kwargs=kwargs)
except ValueError:
return _import(asset_path, image_landmark_types)
return _import(asset_path, image_landmark_types,
importer_kwargs=kwargs)


def ls_builtin_assets():
Expand All @@ -320,23 +322,22 @@ def ls_builtin_assets():
-------
list of strings
Filenames of all assets in the data directory shipped with Menpo

"""
return [p.name for p in data_dir_path().glob('*') if not p.is_dir()]


def import_builtin(x):

def execute():
return _import_builtin_asset(x)
def execute(**kwargs):
return _import_builtin_asset(x, **kwargs)

return execute


class BuiltinAssets(object):

def __call__(self, asset_name):
return _import_builtin_asset(asset_name)
def __call__(self, asset_name, **kwargs):
return _import_builtin_asset(asset_name, **kwargs)

import_builtin_asset = BuiltinAssets()

Expand Down
5 changes: 5 additions & 0 deletions menpo/io/test/io_import_test.py
Expand Up @@ -22,6 +22,11 @@ def test_breaking_bad_import():
assert(img.landmarks['PTS'].n_landmarks == 68)


def test_breaking_bad_import_kwargs():
img = mio.import_builtin_asset('breakingbad.jpg', normalise=False)
assert(img.pixels.dtype == np.uint8)


def test_takeo_import():
img = mio.import_builtin_asset('takeo.ppm')
assert(img.shape == (225, 150))
Expand Down