Skip to content

Commit

Permalink
transform: Fix representation of a patched sampler
Browse files Browse the repository at this point in the history
Show the parameters of the imbalanced-learn base class and the
additional parameters introduced in the PickableSampler class (shuffle
and random_state). Ensure that these are forwarded to the base
imbalanced-learn Sampler class if the constructor requires them.

Adapt the tests and example accordingly.
  • Loading branch information
qtux committed May 14, 2019
1 parent 536c0d0 commit f8d2fb8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
2 changes: 2 additions & 0 deletions examples/plot_imblearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
('segment', SegmentXY(width=1, overlap=0)),
('resample', patch_sampler(RandomUnderSampler)()),
])
print("Pipeline:", pipe)

print("Calling a transform on the data does not change it ...")
Xf, yf = pipe.transform(X, y)
Expand Down Expand Up @@ -64,6 +65,7 @@ def score(self, X, y, sample_weight=None):
('feature', FeatureRep(features={"min":minimum})),
('estimator', VerboseDummyClassifier(strategy="constant", constant=True)),
])
print("Pipeline:", pipe)

print("Split the data into half training and half test data:")
X_train, X_test, y_train, y_test = temporal_split(X, y, 0.5)
Expand Down
10 changes: 9 additions & 1 deletion seglearn/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from seglearn.feature_functions import all_features, mean
from seglearn.util import get_ts_data_parts
from sklearn.utils import shuffle
from sklearn.base import BaseEstimator


def test_sliding_window():
Expand Down Expand Up @@ -677,7 +678,9 @@ def resample(Xt):
# MUST be defined in the global scope for pickling to work correctly
def mock_resample(ndarray):
    '''
    Return the leading half of the given sequence, i.e. its first ``len(ndarray) // 2`` items.
    '''
    return ndarray[:len(ndarray) // 2]
class MockImblearnSampler(object):
class MockImblearnSampler(BaseEstimator):
def __init__(self, mocked_param="mock"):
pass
@staticmethod
def _check_X_y(X, y):
return X, y, True
Expand All @@ -700,6 +703,11 @@ class EmptyClass(object):
unpickled_sampler = pickle.loads(pickled_sampler)
assert str(patched_sampler.__class__) == str(unpickled_sampler.__class__)

# test representation
assert "mocked_param" in repr(patched_sampler)
assert "random_state" in repr(patched_sampler)
assert "shuffle" in repr(patched_sampler)

# multivariate ts
X = np.random.rand(100, 10, 4)
y = np.ones(100)
Expand Down
22 changes: 19 additions & 3 deletions seglearn/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,9 +1511,14 @@ def patch_sampler(sampler_class):
'''
Return a dynamically patched imbalanced-learn Sampler class compatible with Pype.
'''
if not hasattr(sampler_class, 'fit_resample') or not hasattr(sampler_class, '_check_X_y'):
raise TypeError('The sampler class to be patched must have a "fit_resample" and a'
' "_check_X_y" method')
conditions = [
hasattr(sampler_class, 'fit_resample'),
hasattr(sampler_class, '_check_X_y'),
hasattr(sampler_class, '_get_param_names'),
]
if not all(conditions):
raise TypeError('The sampler class to be patched must have a "fit_resample", a "_check_X_y"'
' method and a "_get_param_names" class method.')

class PickableSampler(sampler_class, XyTransformerMixin):
'''
Expand Down Expand Up @@ -1548,6 +1553,17 @@ def __init__(self, shuffle=False, random_state=None, **kwargs):
kwargs["random_state"] = random_state
super(PickableSampler, self).__init__(**kwargs)

@classmethod
def _get_param_names(cls):
    '''
    Get parameters of the imbalanced-learn Sampler base class and the additional arguments
    of the dynamically derived class.

    Returns
    -------
    list of str
        Sorted, de-duplicated parameter names: those reported by
        ``sampler_class._get_param_names()`` plus every explicit parameter of this class'
        ``__init__`` (``self`` and the ``**kwargs`` catch-all are left out).
    '''
    own_parameters = {
        p.name for p in signature(cls.__init__).parameters.values()
        if p.name != 'self' and p.kind != p.VAR_KEYWORD
    }
    # Merge as sets so the result does not rely on the base class returning a list.
    return sorted(own_parameters.union(sampler_class._get_param_names()))

@staticmethod
def _check_X_y(Xt, yt):
'''
Expand Down

0 comments on commit f8d2fb8

Please sign in to comment.