Skip to content

Commit

Permalink
[test]: update to Pipeline contain method to test callable presence d…
Browse files Browse the repository at this point in the history
…irectly without string naming and tests of Pipeline methods against scipy results
  • Loading branch information
mscaudill committed Aug 30, 2023
1 parent faebb2f commit 9233bcc
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 12 deletions.
18 changes: 10 additions & 8 deletions src/openseize/tools/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class Pipeline:
True
>>> result.shape
(4, 61500)
>>> # assert the pipeline contains the notch function
>>> notch in transformer
True
"""

def __init__(self) -> None:
Expand All @@ -65,8 +68,8 @@ def validate(self, caller: Callable, **kwargs) -> None:
unbound_cnt = len(sig.parameters) - len(bound.arguments)
if unbound_cnt > 1:
msg = ('Pipeline callers must have exactly one unbound argument.'
f'{caller.__name__} has {unbound_cnt} unbound arguments.')
raise ValueError(msg)
f' {caller.__name__} has {unbound_cnt} unbound arguments.')
raise TypeError(msg)

def append(self, caller: Callable, **kwargs) -> None:
"""Append a callable to this Pipeline.
Expand All @@ -88,19 +91,18 @@ def append(self, caller: Callable, **kwargs) -> None:
frozen = partial(caller, **kwargs)
self.callers.append(frozen)

def __contains__(self, name: str) -> bool:
def __contains__(self, caller: Callable) -> bool:
"""Returns True if func with name is in this Pipeline's callables.
Args:
name:
The name of a function to look up in this Pipeline.
caller:
A callable to find in this pipelines partial callers.
Returns:
True if named function is in callers and False otherwise.
True if caller is in callers and False otherwise.
"""

names = [caller.__name__ for caller in self.callers]
return name in names
return caller in [caller.func for caller in self.callers]

def __call__(self, data):
"""Apply this Pipeline's callables to an initial data argument.
Expand Down
149 changes: 145 additions & 4 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,155 @@
!pytest test_pipelines.py::<TEST NAME>
"""

import pickle
from itertools import permutations

import pytest
from pytest_lazyfixture import lazy_fixture
import numpy as np
import scipy.signal as sps

from openseize import producer
from openseize.filtering.iir import Notch
from openseize.filtering.fir import Kaiser
from openseize.resampling.resampling import downsample
from openseize.tools.pipeline import Pipeline

# the goal here is to build several pipelines and compare results with scipy
# then I need to test that pipelines support concurrency -- currently this will
# fail due to as_producer decorator in core that should be replaced with
# partials just like in spectra module.

@pytest.fixture(scope="module")
def rng():
"""Returns a numpy default_rng object for generating reproducible but
random ndarrays."""

seed = 0
return np.random.default_rng(seed)

@pytest.fixture(scope="module")
def random1D(rng):
"""Returns a random 1D array."""

return rng.random((230020,))

@pytest.fixture(scope='module', params=permutations(range(2)))
def random2D(rng, request):
"""Returns random 2D arrays with sample axis along each axis."""

axes = request.param
yield np.transpose(rng.random((100012, 6)), axes=axes)

@pytest.fixture(scope='module', params=permutations(range(3)))
def random3D(rng, request):
"""Returns random 3D arrays with samples along each axis."""

axes = request.param
yield np.transpose(rng.random((100012, 6, 3)), axes=axes)

@pytest.fixture(scope='module', params=permutations(range(4)))
def random4D(rng, request):
"""Returns random 4D arrays with samples along each axis."""

axes = request.param
yield np.transpose(rng.random((100012, 2, 3, 3)), axes=axes)

def test_validate():
"""Confirms TypeError when caller is appended with too few bound args."""

pipe = Pipeline()

def caller(a, b, c=0):
return a + b + c

with pytest.raises(TypeError) as exc:

pipe.append(caller, c=10)
assert exc.type is TypeError

def test_contains_functions():
"""Validates Pipeline's contain method for functions."""

pipe = Pipeline()

def f(a, b):
return a+b

def g(x, y=0):
return x**y

pipe.append(f, a=1, b=2)
pipe.append(g, y=10)

assert f in pipe
assert g in pipe

def test_contains_callables():
"""Validates Pipeline's contain method for callables."""

pipe = Pipeline()

notch = Notch(60, width=6, fs=5000)
pipe.append(notch, chunksize=10000, axis=-1)

assert notch in pipe

# use lazy fixtures to pass parameterized fixtures into test
@pytest.mark.parametrize('arr',
[
lazy_fixture('random1D'),
lazy_fixture('random2D'),
lazy_fixture('random3D'),
lazy_fixture('random4D'),
]
)
def test_call_method(arr):
"""Test that composed openseize callable return the same result as Scipy.
This test is superfulous because all of openseize's functions and callables
are tested in their respective testing modules (e.g. test_iir.py) but for
completeness we test again.
"""

axis = np.argmax(arr.shape)
pro = producer(arr, chunksize=1000, axis=axis)

# add notch & downsample
pipe = Pipeline()
notch = Notch(60, width=8, fs=1000)
pipe.append(notch, chunksize=1000, axis=axis, dephase=False)
pipe.append(downsample, M=10, fs=1000, chunksize=1000, axis=axis)

measured = np.concatenate([x for x in pipe(pro)], axis=axis)

# compare with scipy
b, a = sps.iirnotch(60, Q=60/8, fs=1000)
notched = sps.lfilter(b, a, arr, axis=axis)

# build a kaiser like the one openseize uses
cutoff = 1000 / (2*10) # fs / 2M
fstop = cutoff + cutoff / 10
fpass = cutoff - cutoff / 10
gpass, gstop = 0.1, 40
h = Kaiser(fpass, fstop, 1000, gpass, gstop).coeffs

downed = sps.resample_poly(notched, up=1, down=10, axis=axis, window=h)

assert np.allclose(measured, downed)

def test_pickleable():
"""Test that pipelines are picklable.
Note this test is only designed to ensure pipelines are pickleable NOT that
the contained callables are pickleable. For those test see test_concurrency.
"""

pipe = Pipeline()

# add notch & downsample
pipe = Pipeline()
notch = Notch(60, width=8, fs=5000)
pipe.append(notch, chunksize=10000, axis=-1)
pipe.append(downsample, M=10, fs=5000, chunksize=10000, axis=-1)

# test pickling of this pipeline
sbytes = pickle.dumps(pipe)
assert isinstance(sbytes, bytes)

0 comments on commit 9233bcc

Please sign in to comment.