# utils.py (forked from NeuroTechX/moabb)

import importlib.util
import logging
import os
from collections import OrderedDict
from copy import deepcopy
from glob import glob

import numpy as np
import scipy.signal as scp
import yaml
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import make_pipeline

from moabb.analysis.results import get_string_rep

log = logging.getLogger(__name__)


def create_pipeline_from_config(config):
    """Create a pipeline from a config file.

    Takes a config dict as input and returns the corresponding pipeline.

    Parameters
    ----------
    config : Dict.
        Dict containing the config parameters.

    Returns
    -------
    pipeline : Pipeline
sklearn Pipeline
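
    Examples
    --------
    A minimal sketch of a config (hypothetical step names; the ``from``,
    ``name`` and ``parameters`` keys are the ones this function reads)::

        config = [
            {"from": "sklearn.preprocessing", "name": "StandardScaler"},
            {
                "from": "sklearn.svm",
                "name": "SVC",
                "parameters": {"kernel": "linear", "C": 1.0},
            },
        ]
        pipe = create_pipeline_from_config(config)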
"""
components = []
for component in config:
# load the package
mod = __import__(component["from"], fromlist=[component["name"]])
# create the instance
if "parameters" in component.keys():
params = component["parameters"]
else:
params = {}
instance = getattr(mod, component["name"])(**params)
components.append(instance)
pipeline = make_pipeline(*components)
return pipeline


def parse_pipelines_from_directory(dir_path):
    """
    Takes in the path to a directory with pipeline configuration files and
    returns a list of pipeline configurations.

    Parameters
    ----------
    dir_path: str
        Path to directory containing pipeline config .yml or .py files

    Returns
    -------
    pipeline_configs: list of dict
        Generated pipeline config dictionaries. Each entry has structure:
        'name': string
        'pipeline': sklearn.BaseEstimator
'paradigms': list of class names that are compatible with said pipeline
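
    Examples
    --------
    A hypothetical ``pipelines/csp_lda.yml`` that this function would pick up
    (step names and parameters are illustrative; the ``name``, ``paradigms``
    and ``pipeline`` keys are the ones read here)::

        name: CSP + LDA
        paradigms:
          - LeftRightImagery
        pipeline:
          - name: CSP
            from: mne.decoding
            parameters:
              n_components: 8
          - name: LinearDiscriminantAnalysis
            from: sklearn.discriminant_analysis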
"""
assert os.path.isdir(
os.path.abspath(dir_path)
), "Given pipeline path {} is not valid".format(dir_path)
# get list of config files
yaml_files = glob(os.path.join(dir_path, "*.yml"))
pipeline_configs = []
for yaml_file in yaml_files:
with open(yaml_file, "r") as _file:
content = _file.read()
# load config
config_dict = yaml.load(content, Loader=yaml.FullLoader)
ppl = create_pipeline_from_config(config_dict["pipeline"])
if "param_grid" in config_dict:
pipeline_configs.append(
{
"paradigms": config_dict["paradigms"],
"pipeline": ppl,
"name": config_dict["name"],
"param_grid": config_dict["param_grid"],
}
)
else:
pipeline_configs.append(
{
"paradigms": config_dict["paradigms"],
"pipeline": ppl,
"name": config_dict["name"],
}
)
    # We can do the same for Python-defined pipelines
    # TODO for python pipelines
    python_files = glob(os.path.join(dir_path, "*.py"))
    for python_file in python_files:
        spec = importlib.util.spec_from_file_location("custom", python_file)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        pipeline_configs.append(module.PIPELINE)
return pipeline_configs


def generate_paradigms(pipeline_configs, context=None, logger=log):
    """
    Takes in a list of pipeline configurations as returned by
    parse_pipelines_from_directory and returns a dictionary of unique paradigms
    with all pipeline configurations compatible with each paradigm.

    Parameters
    ----------
    pipeline_configs:
        list of pipeline configurations
    context:
        dict of per-paradigm context parameters; only used here to warn when a
        pipeline references a paradigm with no entry in the context file
    logger:
        logger

    Returns
    -------
    paradigms: dict
        Dictionary of dictionaries with the unique paradigms and the configuration of the
pipelines compatible with the paradigm
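
    Examples
    --------
    A sketch of the returned structure (hypothetical paradigm and pipeline
    names)::

        paradigms = generate_paradigms(pipeline_configs)
        # OrderedDict([
        #     ("LeftRightImagery", {"CSP + LDA": Pipeline(...)}),
        #     ("SSVEP", {"TRCA": Pipeline(...)}),
        # ])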
"""
context = context or {}
paradigms = OrderedDict()
for config in pipeline_configs:
if "paradigms" not in config.keys():
logger.error("{} must have a 'paradigms' key.".format(config))
continue
# iterate over paradigms
for paradigm in config["paradigms"]:
# check if it is in the context parameters file
if len(context) > 0:
if paradigm not in context.keys():
logger.debug(context)
logger.warning(
"Paradigm {} not in context file {}".format(
paradigm, context.keys()
)
)
            if isinstance(config["pipeline"], BaseEstimator):
                pipeline = deepcopy(config["pipeline"])
            else:
                logger.error(config["pipeline"])
                raise ValueError("pipeline must be a sklearn estimator")
# append the pipeline in the paradigm list
if paradigm not in paradigms.keys():
paradigms[paradigm] = {}
            # FIXME: names are not unique
logger.debug("Pipeline: \n\n {} \n".format(get_string_rep(pipeline)))
paradigms[paradigm][config["name"]] = pipeline
return paradigms


def generate_param_grid(pipeline_configs, context=None, logger=log):
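    """
    Takes in a list of pipeline configurations as returned by
    parse_pipelines_from_directory and returns a dict mapping each pipeline
    name to its 'param_grid' entry, for the configs that define one.

    Parameters
    ----------
    pipeline_configs:
        list of pipeline configurations
    context:
        unused here; kept for symmetry with generate_paradigms
    logger:
        logger

    Returns
    -------
    param_grid: dict
        Dictionary mapping pipeline names to their parameter grids

    Examples
    --------
    A sketch with a hypothetical pipeline named "CSP + SVM" whose config
    defines a grid over the SVC regularization parameter::

        param_grid = generate_param_grid(pipeline_configs)
        # {"CSP + SVM": {"svc__C": [0.1, 1.0, 10.0]}}
    """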
context = context or {}
param_grid = {}
for config in pipeline_configs:
if "paradigms" not in config:
logger.error("{} must have a 'paradigms' key.".format(config))
continue
        # collect the param_grid entry, if the config defines one
        if "param_grid" in config:
            param_grid[config["name"]] = config["param_grid"]
return param_grid


class FilterBank(BaseEstimator, TransformerMixin):
    """Apply a given identical pipeline over a bank of filters.

    The pipeline provided with the constructor will be applied on the 4th
    axis of the input data. This pipeline should be used with a FilterBank
    paradigm.

    This can be used to build a filterbank CSP, for example::

        pipeline = make_pipeline(FilterBank(estimator=CSP()), LDA())

    Parameters
    ----------
    estimator: sklearn Estimator
        the sklearn pipeline to apply on each band of the filter bank.
    flatten: bool (True)
        If True, the outputs of each band are concatenated together on the
        feature axis. If False, the outputs are stacked.
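
    Notes
    -----
    The input X to ``fit`` and ``transform`` must be a 4D array whose last
    axis indexes the filter bands; the estimator is fitted and applied
    independently on each band slice ``X[..., i]``.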
"""

    def __init__(self, estimator, flatten=True):
        self.estimator = estimator
        self.flatten = flatten

    def fit(self, X, y=None):
        assert X.ndim == 4
        # fit an independent copy of the estimator on each band (last axis)
        self.models = [
            deepcopy(self.estimator).fit(X[..., i], y) for i in range(X.shape[-1])
        ]
        return self

    def transform(self, X):
        assert X.ndim == 4
        out = [self.models[i].transform(X[..., i]) for i in range(X.shape[-1])]
        assert out[0].ndim == 2, (
            "Each band must return a two dimensional "
            f"matrix, currently have {out[0].ndim}"
        )
        if self.flatten:
            return np.concatenate(out, axis=1)
        else:
            return np.stack(out, axis=2)

    def __repr__(self):
        estimator_name = type(self).__name__
        estimator_prms = self.estimator.get_params()
        return "{}(estimator={}, flatten={})".format(
            estimator_name, estimator_prms, self.flatten
        )


def filterbank(X, sfreq, idx_fb, peaks):
    """
    Filter bank design for decomposing EEG data into sub-band components [1]_

    Parameters
    ----------
    X: ndarray of shape (n_trials, n_channels, n_samples) or (n_channels, n_samples)
        EEG data to be processed
    sfreq: int
        Sampling frequency of the data.
    idx_fb: int
        Index of the filter to apply within the filter bank
    peaks : list of len (n_classes)
        Frequencies corresponding to the SSVEP components.

    Returns
    -------
    y: ndarray of shape (n_trials, n_channels, n_samples)
        Sub-band components decomposed by a filter bank

    References
    ----------
    .. [1] M. Nakanishi, Y. Wang, X. Chen, Y.-T. Wang, X. Gao, and T.-P. Jung,
       "Enhancing detection of SSVEPs for a high-speed brain speller using
       task-related component analysis",
       IEEE Trans. Biomed. Eng, 65(1):104-112, 2018.

    Code based on the Matlab implementation from the authors of [1]_
(https://github.com/mnakanishi/TRCA-SSVEP).
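
    Examples
    --------
    A minimal sketch on random data (all values are illustrative)::

        import numpy as np
        X = np.random.randn(10, 8, 250)  # 10 trials, 8 channels, 1 s at 250 Hz
        y = filterbank(X, sfreq=250, idx_fb=0, peaks=[8.0, 10.0, 12.0])
        # y.shape == (10, 8, 250)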
"""
    # Calibration data comes in batches of trials
    if X.ndim == 3:
        num_chans = X.shape[1]
        num_trials = X.shape[0]
    # Test data come one trial at a time
    elif X.ndim == 2:
        num_chans = X.shape[0]
        num_trials = 1
    else:
        raise ValueError("X must be 2- or 3-dimensional, got ndim={}".format(X.ndim))
    # Work with the Nyquist frequency; the band edges below are normalized to it
    sfreq = sfreq / 2
min_freq = np.min(peaks)
max_freq = np.max(peaks)
if max_freq < 40:
top = 100
else:
top = 115
# Check for Nyquist
if top >= sfreq:
top = sfreq - 10
diff = max_freq - min_freq
    # Low-cut frequencies for the pass band (depend on the SSVEP frequencies)
    # No more than 3 dB loss in the passband
passband = [min_freq - 2 + x * diff for x in range(7)]
    # At least 40 dB attenuation in the stopband
if min_freq - 4 > 0:
stopband = [
min_freq - 4 + x * (diff - 2) if x < 3 else min_freq - 4 + x * diff
for x in range(7)
]
else:
stopband = [2 + x * (diff - 2) if x < 3 else 2 + x * diff for x in range(7)]
Wp = [passband[idx_fb] / sfreq, top / sfreq]
Ws = [stopband[idx_fb] / sfreq, (top + 7) / sfreq]
N, Wn = scp.cheb1ord(Wp, Ws, 3, 40) # Chebyshev type I filter order selection.
B, A = scp.cheby1(N, 0.5, Wn, btype="bandpass") # Chebyshev type I filter design
y = np.zeros(X.shape)
    if num_trials == 1:  # For test data
for ch_i in range(num_chans):
try:
# The arguments 'axis=0, padtype='odd', padlen=3*(max(len(B),len(A))-1)' correspond
# to Matlab filtfilt (https://dsp.stackexchange.com/a/47945)
y[ch_i, :] = scp.filtfilt(
B,
A,
X[ch_i, :],
axis=0,
padtype="odd",
padlen=3 * (max(len(B), len(A)) - 1),
)
            except Exception as e:
                log.error(e)
                log.error("num_chans={}".format(num_chans))
else:
for trial_i in range(num_trials): # Filter each trial sequentially
for ch_i in range(num_chans): # Filter each channel sequentially
y[trial_i, ch_i, :] = scp.filtfilt(
B,
A,
X[trial_i, ch_i, :],
axis=0,
padtype="odd",
padlen=3 * (max(len(B), len(A)) - 1),
)
return y