Skip to content

Commit

Permalink
Merge 0eac74a into 5dab9d1
Browse files Browse the repository at this point in the history
  • Loading branch information
qtux committed Dec 10, 2018
2 parents 5dab9d1 + 0eac74a commit 4030fa2
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 2 deletions.
40 changes: 40 additions & 0 deletions examples/plot_function_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
'''
==================================
Simple FunctionTransformer Example
==================================
This example demonstrates how to execute arbitrary functions on time series data using the
FunctionTransformer.
'''

# Author: Matthias Gazzari
# License: BSD

from seglearn.transform import FunctionTransformer, SegmentXY
from seglearn.base import TS_Data

import numpy as np

def choose_cols(Xt, cols):
    '''
    Keep only the selected columns (variables) of every time series.

    Parameters
    ----------
    Xt : iterable of arrays
        time series, each indexed as [sample, variable]
    cols : index
        selection applied to the second axis of each series (the example passes a list of
        column numbers)

    Returns
    -------
    list of arrays
        one array per input series, restricted to ``cols``
    '''
    return [time_series[:, cols] for time_series in Xt]

# Two multivariate time series with 4 and 3 samples of 3 variables each
X = [
    np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),
    np.array([[30, 40, 50], [60, 70, 80], [90, 100, 110]]),
]
# Time series target: one label per sample of the corresponding series
y = [
    np.array([True, False, False, True]),
    np.array([False, True, False]),
]

# Keep only the first two variables of every series
trans = FunctionTransformer(choose_cols, func_kwargs={"cols": [0, 1]})
X = trans.fit_transform(X, y)

# SegmentXY segments the target together with the data. Per the seglearn API documentation
# its fit_transform returns (data, target, sample weights), so the result is unpacked instead
# of binding the whole tuple to X and printing the unsegmented y.
segment = SegmentXY(width=3, overlap=1)
X, y, _ = segment.fit_transform(X, y)

print("X:", X)
print("y: ", y)
3 changes: 2 additions & 1 deletion seglearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
__all__ = ['TS_Data', 'FeatureRep', 'FeatureRepMix', 'PadTrunc', 'Interp', 'Pype', 'SegmentX',
'SegmentXY', 'SegmentXYForecast', 'TemporalKFold', 'temporal_split', 'check_ts_data',
'check_ts_data_with_ts_target', 'ts_stats', 'get_ts_data_parts', 'all_features',
'base_features', 'load_watch', 'TargetRunLengthEncoder', '__version__']
'base_features', 'load_watch', 'TargetRunLengthEncoder', 'FunctionTransformer',
'__version__']

__author__ = 'David Burns david.mo.burns@gmail.com'
77 changes: 77 additions & 0 deletions seglearn/tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Author: David Burns
# License: BSD

import pytest

import numpy as np

import seglearn.transform as transform
Expand Down Expand Up @@ -447,3 +449,78 @@ def test_feature_rep_mix():
Xt = uni_union.transform(X)
assert Xt.shape[0] == len(X)
assert len(uni_union.f_labels) == Xt.shape[1]


def test_function_transform():
    '''
    Exercise FunctionTransformer: identity pass-through, a custom function applied to the
    time series part, contextual data handed back untouched, and rejection of a function
    that changes the number of samples.
    '''
    fill = 10

    def replace(Xt, value):
        return np.ones(Xt.shape) * value

    identity = transform.FunctionTransformer()
    custom = transform.FunctionTransformer(replace, func_kwargs={"value": fill})

    # plain arrays: univariate ts first, then multivariate ts
    for shape in [(100, 10), (100, 10, 4)]:
        X = np.random.rand(*shape)
        y = np.ones(100)

        identity.fit(X, y)
        assert identity.transform(X) is X

        custom.fit(X, y)
        expected = np.ones(X.shape) * fill
        assert np.array_equal(custom.transform(X), expected)

    # ts with contextual data: univariate context first, then multivariate context
    for context_shape in [(100,), (100, 3)]:
        Xt = np.random.rand(100, 10, 4)
        Xc = np.random.rand(*context_shape)
        X = TS_Data(Xt, Xc)
        y = np.ones(100)

        identity.fit(X, y)
        assert identity.transform(X) is X

        custom.fit(X, y)
        Xtt, Xtc = get_ts_data_parts(custom.transform(X))
        assert np.array_equal(Xtt, np.ones(Xt.shape) * fill)
        # the contextual part must be the very same object, not a copy
        assert Xtc is Xc

    # test resampling: collapsing all samples into one row must be refused
    def resample(Xt):
        return Xt.reshape(1, -1)

    illegal_resampler = transform.FunctionTransformer(resample)
    X = np.random.rand(100, 10)
    y = np.ones(100)
    illegal_resampler.fit(X, y)
    with pytest.raises(ValueError):
        illegal_resampler.transform(X)
72 changes: 71 additions & 1 deletion seglearn/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .util import get_ts_data_parts, check_ts_data

__all__ = ['SegmentX', 'SegmentXY', 'SegmentXYForecast', 'PadTrunc', 'Interp', 'FeatureRep',
'FeatureRepMix']
'FeatureRepMix', 'FunctionTransformer']


class XyTransformerMixin(object):
Expand Down Expand Up @@ -1141,3 +1141,73 @@ def transform(self, X):
fts = np.column_stack([fts, Xc])

return fts

class FunctionTransformer(BaseEstimator, TransformerMixin):
    '''
    Transformer for applying a custom function to time series data.

    Parameters
    ----------
    func : function, optional (default=None)
        the function to be applied to Xt, the time series part of X (contextual variables Xc are
        passed through unaltered) - X remains unchanged if no function is supplied
    func_kwargs : dictionary, optional (default=None)
        keyword arguments to be passed to the function call - None means no keyword arguments
    '''

    def __init__(self, func=None, func_kwargs=None):
        # Parameters are stored exactly as given. The default is None rather than a dict
        # literal, because a dict default would be one object shared by every instance.
        self.func = func
        self.func_kwargs = func_kwargs

    def fit(self, X, y=None):
        '''
        Fit the transform (only runs check_ts_data on the input, no state is stored)

        Parameters
        ----------
        X : array-like, shape [n_samples, ...]
            time series data and (optionally) contextual data
        y : None
            there is no need of a target in a transformer, yet the pipeline API requires this

        Returns
        -------
        self : object
            returns self
        '''
        check_ts_data(X, y)
        return self

    def transform(self, X):
        '''
        Transforms the time series data based on the provided function. Note this transformation
        must not change the number of samples in the data.

        Parameters
        ----------
        X : array-like, shape [n_samples, ...]
            time series data and (optionally) contextual data

        Returns
        -------
        Xt : array-like, shape [n_samples, ...]
            X itself if no function was supplied, otherwise the transformed time series data,
            wrapped together with the unaltered contextual data in a TS_Data object if X
            carried contextual data

        Raises
        ------
        ValueError
            if the function returns a different number of samples than it was given
        '''
        if self.func is None:
            return X

        # Resolve the None default here so __init__ can keep the parameter untouched.
        kwargs = {} if self.func_kwargs is None else self.func_kwargs

        Xt, Xc = get_ts_data_parts(X)
        n_samples = len(Xt)
        Xt = self.func(Xt, **kwargs)
        if len(Xt) != n_samples:
            raise ValueError("Changing the number of samples inside a FunctionTransformer is "
                             "disabled.")
        if Xc is not None:
            # Re-attach the contextual data that was split off before calling func.
            Xt = TS_Data(Xt, Xc)
        return Xt

0 comments on commit 4030fa2

Please sign in to comment.