Skip to content
This repository has been archived by the owner on Aug 31, 2018. It is now read-only.

Commit

Permalink
Merge pull request #21 from jakirkham/add_generic
Browse files Browse the repository at this point in the history
Add generic filter
  • Loading branch information
jakirkham committed May 4, 2017
2 parents ca67647 + 01b50a5 commit 6252e6e
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 0 deletions.
4 changes: 4 additions & 0 deletions dask_ndfilters/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
gaussian_laplace,
)

from dask_ndfilters.generic import (
generic_filter,
)

from dask_ndfilters.order import (
minimum_filter,
median_filter,
Expand Down
42 changes: 42 additions & 0 deletions dask_ndfilters/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-


import numbers

import numpy
import scipy.ndimage.filters

import dask_ndfilters._utils as _utils


@_utils._update_wrapper(scipy.ndimage.filters.generic_filter)
def generic_filter(input,
function,
size=None,
footprint=None,
mode='reflect',
cval=0.0,
origin=0,
extra_arguments=tuple(),
extra_keywords=dict()):
footprint = _utils._get_footprint(input.ndim, size, footprint)
origin = _utils._get_origin(footprint.shape, origin)
depth = _utils._get_depth(footprint.shape, origin)
depth, boundary = _utils._get_depth_boundary(footprint.ndim, depth, "none")

result = input.map_overlap(
scipy.ndimage.filters.generic_filter,
depth=depth,
boundary=boundary,
dtype=input.dtype,
name=scipy.ndimage.filters.generic_filter.__name__,
function=function,
footprint=footprint,
mode=mode,
cval=cval,
origin=origin,
extra_arguments=extra_arguments,
extra_keywords=extra_keywords
)

return result
196 changes: 196 additions & 0 deletions tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import

import pytest

import numpy as np
import scipy.ndimage.filters as sp_ndf

import dask.array as da
import dask.array.utils as dau

import dask_ndfilters as da_ndf


@pytest.mark.parametrize(
"da_func",
[
da_ndf.generic_filter,
]
)
@pytest.mark.parametrize(
"err_type, function, size, footprint, origin",
[
(RuntimeError, lambda x: x, None, None, 0),
(TypeError, lambda x: x, 1.0, None, 0),
(RuntimeError, lambda x: x, (1,), None, 0),
(RuntimeError, lambda x: x, [(1,)], None, 0),
(RuntimeError, lambda x: x, 1, np.ones((1,)), 0),
(RuntimeError, lambda x: x, None, np.ones((1,)), 0),
(RuntimeError, lambda x: x, None, np.ones((1, 0)), 0),
(RuntimeError, lambda x: x, 1, None, (0,)),
(RuntimeError, lambda x: x, 1, None, [(0,)]),
(ValueError, lambda x: x, 1, None, 1),
(TypeError, lambda x: x, 1, None, 0.0),
(TypeError, lambda x: x, 1, None, (0.0, 0.0)),
(TypeError, lambda x: x, 1, None, 1+0j),
(TypeError, lambda x: x, 1, None, (0+0j, 1+0j)),
]
)
def test_generic_filters_params(da_func,
err_type,
function,
size,
footprint,
origin):
a = np.arange(140.0).reshape(10, 14)
d = da.from_array(a, chunks=(5, 7))

with pytest.raises(err_type):
da_func(d,
function,
size=size,
footprint=footprint,
origin=origin)


@pytest.mark.parametrize(
"sp_func, da_func",
[
(sp_ndf.generic_filter, da_ndf.generic_filter),
]
)
@pytest.mark.parametrize(
"function, size, footprint",
[
(lambda x: x, 1, None),
(lambda x: x, (1, 1), None),
(lambda x: x, None, np.ones((1, 1))),
]
)
def test_generic_filter_identity(sp_func,
da_func,
function,
size,
footprint):
a = np.arange(140.0).reshape(10, 14)
d = da.from_array(a, chunks=(5, 7))

dau.assert_eq(
d, da_func(d, function, size=size, footprint=footprint)
)

dau.assert_eq(
sp_func(a, function, size=size, footprint=footprint),
da_func(d, function, size=size, footprint=footprint),
)


@pytest.mark.parametrize(
"sp_func, da_func",
[
(sp_ndf.generic_filter, da_ndf.generic_filter),
]
)
@pytest.mark.parametrize(
"function, size, footprint, origin",
[
(
lambda x: (np.array(x)**2).sum(),
2,
None,
0
),
(
lambda x: (np.array(x)**2).sum(),
None,
np.ones((2, 3)),
0
),
(
lambda x: (np.array(x)**2).sum(),
None,
np.ones((2, 3)),
(0, 1)
),
(
lambda x: (np.array(x)**2).sum(),
None,
np.ones((2, 3)),
(0, -1)
),
(
lambda x: (np.array(x)**2).sum(),
None,
(np.mgrid[-2: 2+1, -2: 2+1]**2).sum(axis=0) < 2.5**2,
0
),
(
lambda x: (np.array(x)**2).sum(),
None,
(np.mgrid[-2: 2+1, -2: 2+1]**2).sum(axis=0) < 2.5**2,
(1, 2)
),
(
lambda x: (np.array(x)**2).sum(),
None,
(np.mgrid[-2: 2+1, -2: 2+1]**2).sum(axis=0) < 2.5**2,
(-1, -2)
),
(
lambda x: (np.array(x)**2).sum(),
5,
None,
0
),
(
lambda x: (np.array(x)**2).sum(),
7,
None,
0
),
(
lambda x: (np.array(x)**2).sum(),
8,
None,
0
),
(
lambda x: (np.array(x)**2).sum(),
10,
None,
0
),
(
lambda x: (np.array(x)**2).sum(),
5,
None,
2
),
(
lambda x: (np.array(x)**2).sum(),
5,
None,
-2
),
]
)
def test_generic_filter_compare(sp_func,
da_func,
function,
size,
footprint,
origin):
a = np.arange(140.0).reshape(10, 14)
d = da.from_array(a, chunks=(5, 7))

dau.assert_eq(
sp_func(
a, function, size=size, footprint=footprint, origin=origin
),
da_func(
d, function, size=size, footprint=footprint, origin=origin
)
)

0 comments on commit 6252e6e

Please sign in to comment.