Skip to content

Commit

Permalink
tests: Add tests for patch_sampler
Browse files Browse the repository at this point in the history
Additionally, add tests for a dynamically created PickableSampler object
and imbalanced-learn sampler shuffling.
  • Loading branch information
qtux committed Nov 7, 2019
1 parent 3fcb4fc commit a61f90f
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions seglearn/tests/test_transform.py
Expand Up @@ -3,13 +3,16 @@

import pytest
import warnings
import pickle

import numpy as np

import seglearn.transform as transform
from seglearn.base import TS_Data
from seglearn.feature_functions import all_features, mean
from seglearn.util import get_ts_data_parts
from sklearn.utils import shuffle
from sklearn.base import BaseEstimator


def test_sliding_window():
Expand Down Expand Up @@ -718,3 +721,84 @@ def resample(Xt):
illegal_resampler.fit(X, y)
with pytest.raises(ValueError):
Xtrans = illegal_resampler.transform(X)

# MUST be defined in the global scope for pickling to work correctly
def mock_resample(ndarray):
return ndarray[:len(ndarray) // 2]
class MockImblearnSampler(BaseEstimator):
def __init__(self, mocked_param="mock"):
pass
@staticmethod
def _check_X_y(X, y):
return X, y, True
def fit_resample(self, X, y):
X, y, _ = self._check_X_y(X, y)
return mock_resample(X), mock_resample(y)

def test_patch_sampler():
# test patch_sampler on a class without a fit_resample function
class EmptyClass(object):
pass
with pytest.raises(TypeError):
transform.patch_sampler(EmptyClass)

# test patch_sampler on a mocked imbalanced-learn Sampler class
unpatched_sampler = MockImblearnSampler()
patched_sampler = transform.patch_sampler(MockImblearnSampler)(shuffle=True, random_state=0)
assert str(patched_sampler.__class__) != str(unpatched_sampler.__class__)
pickled_sampler = pickle.dumps(patched_sampler)
unpickled_sampler = pickle.loads(pickled_sampler)
assert str(patched_sampler.__class__) == str(unpickled_sampler.__class__)

# test representation
assert "mocked_param" in repr(patched_sampler)
assert "random_state" in repr(patched_sampler)
assert "shuffle" in repr(patched_sampler)

# multivariate ts
X = np.random.rand(100, 10, 4)
y = np.ones(100)
Xt, yt, _ = patched_sampler.transform(X, y)
assert Xt is X
assert yt is y
Xt, yt, _ = patched_sampler.fit_transform(X, y)
X, y = shuffle(mock_resample(X), mock_resample(y), random_state=0)
assert np.array_equal(Xt, X)
assert np.array_equal(yt, y)

# ts with multivariate contextual data
X = TS_Data(np.random.rand(100, 10, 4), np.random.rand(100, 3))
Xt_orig, _ = get_ts_data_parts(X)
y = np.ones(100)
Xt, yt, _ = patched_sampler.transform(X, y)
assert Xt is X
assert yt is y
Xt, yt, _ = patched_sampler.fit_transform(X, y)
Xtt, Xtc = get_ts_data_parts(Xt)
Xt_orig, y = shuffle(mock_resample(Xt_orig), mock_resample(y), random_state=0)
assert np.array_equal(Xtt, Xt_orig)
assert np.array_equal(yt, y)

# ts with univariate contextual data
X = TS_Data(np.random.rand(100, 10, 4), np.random.rand(100))
Xt_orig, _ = get_ts_data_parts(X)
y = np.ones(100)
Xt, yt, _ = patched_sampler.transform(X, y)
assert Xt is X
assert yt is y
Xt, yt, _ = patched_sampler.fit_transform(X, y)
Xtt, Xtc = get_ts_data_parts(Xt)
Xt_orig, y = shuffle(mock_resample(Xt_orig), mock_resample(y), random_state=0)
assert np.array_equal(Xtt, Xt_orig)
assert np.array_equal(yt, y)

# univariate ts
X = np.random.rand(100, 10)
y = np.ones(100)
Xt, yt, _ = patched_sampler.transform(X, y)
assert Xt is X
assert yt is y
Xt, yt, _ = patched_sampler.fit_transform(X, y)
X, y = shuffle(mock_resample(X), mock_resample(y), random_state=0)
assert np.array_equal(Xt, X)
assert np.array_equal(yt, y)

0 comments on commit a61f90f

Please sign in to comment.