"""Base class for a dataset."""
import abc
import logging
import re
import traceback
from dataclasses import dataclass
from inspect import signature
from pathlib import Path
from typing import Dict, Union
from sklearn.pipeline import Pipeline
from moabb.datasets.bids_interface import StepType, _interface_map
from moabb.datasets.preprocessing import SetRawAnnotations
log = logging.getLogger(__name__)


@dataclass
class CacheConfig:
    """
    Configuration for caching of datasets.

    Parameters
    ----------
    save_*: bool
        This flag specifies whether to save the output of the corresponding
        step to disk.
    use: bool
        This flag specifies whether to use the disk cache in case it exists.
        If True, the Raw or Epochs objects returned will not be preloaded
        (this saves some time). Otherwise, they will be preloaded.
        If use is False, the save_* and overwrite_* keys will be ignored.
    overwrite_*: bool
        This flag specifies whether to overwrite the disk cache in
        case it exists.
    path : None | str
        Location where to look for the cached data.
        If None, the environment variable or config parameter
        ``MNE_DATASETS_(signifier)_PATH`` is used. If it doesn't exist, the
        "~/mne_data" directory is used. If the dataset
        is not found under the given path, the data
        will be automatically downloaded to the specified folder.
    verbose:
        Verbosity level. See mne.verbose.
    """

    save_raw: bool = False
    save_epochs: bool = False
    save_array: bool = False
    use: bool = False
    overwrite_raw: bool = False
    overwrite_epochs: bool = False
    overwrite_array: bool = False
    path: Union[str, Path, None] = None
    verbose: Union[str, None] = None

    @classmethod
    def make(cls, dic: Union[None, Dict, "CacheConfig"] = None) -> "CacheConfig":
        """
        Create a CacheConfig object from a dict or another CacheConfig object.

        Examples
        --------
        Using default parameters:

        >>> CacheConfig.make()
        CacheConfig(save_raw=False, save_epochs=False, save_array=False, use=False, overwrite_raw=False, overwrite_epochs=False, overwrite_array=False, path=None, verbose=None)

        From a dict:

        >>> dic = {'use': True}
        >>> CacheConfig.make(dic)
        CacheConfig(save_raw=False, save_epochs=False, save_array=False, use=True, overwrite_raw=False, overwrite_epochs=False, overwrite_array=False, path=None, verbose=None)
        """
        if dic is None:
            return cls()
        elif isinstance(dic, dict):
            return cls(**dic)
        elif isinstance(dic, cls):
            return dic
        else:
            raise ValueError(f"Expected dict or CacheConfig, got {type(dic)}")


def apply_step(pipeline, obj):
    """Apply a pipeline to an object."""
    if obj is None:
        return None
    try:
        return pipeline.transform(obj)
    except ValueError as error:
        # no events received by RawToEpochs:
        if str(error) == "No events found":
            return None
        raise error


def is_camel_kebab_case(name: str):
    """Check if a string is in CamelCase but can also contain dashes."""
    return re.fullmatch(r"[a-zA-Z0-9\-]+", name) is not None


def is_abbrev(abbrev_name: str, full_name: str):
    """Check if abbrev_name is an abbreviation of full_name,
    i.e. if the characters in abbrev_name are all in full_name
    and in the same order. They must share the same capital letters."""
    pattern = re.sub(r"([A-Za-z])", r"\1[a-z0-9\-]*", re.escape(abbrev_name))
    return re.fullmatch(pattern, full_name) is not None
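
# Doctest-style examples of the two helpers above (illustrative):
#
#     >>> is_camel_kebab_case("BNCI2014-001")
#     True
#     >>> is_camel_kebab_case("Zhou_2016")  # underscores are not allowed
#     False
#     >>> is_abbrev("BI2013a", "BrainInvaders2013a")
#     True
#     >>> is_abbrev("BI2013a", "BrainInvaders2012")
#     False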


class BaseDataset(metaclass=abc.ABCMeta):
    """Abstract Moabb BaseDataset.

    Parameters required for all datasets.

    Parameters
    ----------
    subjects: List of int
        List of subject numbers (or tuple or numpy array).
    sessions_per_subject: int
        Number of sessions per subject (if varying, take the minimum).
    events: dict of strings
        String codes for events matched with labels in the stim channel.
        Currently imagery codes can include:
        - left_hand
        - right_hand
        - hands
        - feet
        - rest
        - left_hand_right_foot
        - right_hand_left_foot
        - tongue
        - navigation
        - subtraction
        - word_ass (for word association)
    code: string
        Unique identifier for the dataset, used in all plots.
        The code should be in CamelCase.
    interval: list with 2 entries
        Imagery interval as defined in the dataset description.
    paradigm: ['p300', 'imagery', 'ssvep']
        Defines what sort of dataset this is.
    doi: DOI for dataset, optional (for now)
    """

    def __init__(
        self,
        subjects,
        sessions_per_subject,
        events,
        code,
        interval,
        paradigm,
        doi=None,
        unit_factor=1e6,
    ):
        """Initialize function for the BaseDataset."""
        try:
            _ = iter(subjects)
        except TypeError:
            raise ValueError("subjects must be an iterable, like a list") from None

        if not is_camel_kebab_case(code):
            raise ValueError(
                f"code {code!r} must be in Camel-KebabCase; "
                "i.e. use CamelCase, and add dashes where absolutely necessary. "
                "See moabb.datasets.base.is_camel_kebab_case for more information."
            )
        class_name = self.__class__.__name__.replace("_", "-")
        if not is_abbrev(class_name, code):
            log.warning(
                f"The dataset class name {class_name!r} must be an abbreviation "
                f"of its code {code!r}. "
                "See moabb.datasets.base.is_abbrev for more information."
            )

        self.subject_list = subjects
        self.n_sessions = sessions_per_subject
        self.event_id = events
        self.code = code
        self.interval = interval
        self.paradigm = paradigm
        self.doi = doi
        self.unit_factor = unit_factor

    def _create_process_pipeline(self):
        return Pipeline(
            [
                (StepType.RAW, SetRawAnnotations(self.event_id)),
            ]
        )
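
    # A hedged sketch of a custom processing pipeline (``dataset`` is any
    # concrete subclass; the event dict below is an assumption), mirroring
    # the default pipeline built above:
    #
    #     custom = Pipeline([(StepType.RAW, SetRawAnnotations({"left_hand": 1}))])
    #     data = dataset.get_data(subjects=[1], process_pipeline=custom)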

    def get_data(
        self,
        subjects=None,
        cache_config=None,
        process_pipeline=None,
    ):
        """
        Return the data corresponding to a list of subjects.

        The returned data is a dictionary with the following structure::

            data = {'subject_id' :
                        {'session_id':
                            {'run_id': run}
                        }
                    }

        Subjects are on top, then we have sessions, then runs.
        A session is a recording done in a single day, without removing the
        EEG cap. A session consists of at least one run. A run is a single
        contiguous recording. Some datasets break a session into multiple runs.

        Processing steps can optionally be applied to the data using the
        ``process_pipeline`` argument. Its steps are applied in order (raw
        steps, then epochs steps, then array steps), so a step may receive a
        :class:`mne.io.Raw`, a :class:`mne.Epochs`, or a :class:`numpy.ndarray`
        object as input, depending on the steps that precede it.

        Parameters
        ----------
        subjects: List of int
            List of subject numbers.
        cache_config: dict | CacheConfig
            Configuration for caching of datasets. See ``CacheConfig``
            for details.
        process_pipeline: Pipeline | None
            Optional processing pipeline to apply to the data.
            To generate an adequate pipeline, we recommend using
            :func:`moabb.utils.make_process_pipelines`.
            This pipeline will receive :class:`mne.io.BaseRaw` objects.
            The step names of this pipeline should be elements of :class:`StepType`.
            According to their name, the steps should either return a
            :class:`mne.io.BaseRaw`, a :class:`mne.Epochs`, or a :class:`numpy.ndarray`.
            This pipeline must be "fixed" because it will not be trained,
            i.e. no call to ``fit`` will be made.

        Returns
        -------
        data: Dict
            dict containing the raw data
        """
        if subjects is None:
            subjects = self.subject_list

        if not isinstance(subjects, list):
            raise ValueError("subjects must be a list")

        cache_config = CacheConfig.make(cache_config)
        if process_pipeline is None:
            process_pipeline = self._create_process_pipeline()

        data = dict()
        for subject in subjects:
            if subject not in self.subject_list:
                raise ValueError("Invalid subject {:d} given".format(subject))
            data[subject] = self._get_single_subject_data_using_cache(
                subject,
                cache_config,
                process_pipeline,
            )

        return data
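
    # Illustrative traversal of the nested dict documented above
    # (``dataset`` is any concrete subclass):
    #
    #     data = dataset.get_data(subjects=[1])
    #     for subject, sessions in data.items():
    #         for session, runs in sessions.items():
    #             for run, raw in runs.items():
    #                 print(subject, session, run, type(raw))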

    def download(
        self,
        subject_list=None,
        path=None,
        force_update=False,
        update_path=None,
        accept=False,
        verbose=None,
    ):
        """Download all data from the dataset.

        This function is only useful for downloading the whole dataset at once.

        Parameters
        ----------
        subject_list : list of int | None
            List of subject ids to download; if None, all subjects
            are downloaded.
        path : None | str
            Location where to look for the data storing location.
            If None, the environment variable or config parameter
            ``MNE_DATASETS_(dataset)_PATH`` is used. If it doesn't exist, the
            "~/mne_data" directory is used. If the dataset
            is not found under the given path, the data
            will be automatically downloaded to the specified folder.
        force_update : bool
            Force update of the dataset even if a local copy exists.
        update_path : bool | None
            If True, set the MNE_DATASETS_(dataset)_PATH in mne-python
            config to the given path. If None, the user is prompted.
        accept: bool
            Accept license terms to download the data, if any. Default: False.
        verbose : bool, str, int, or None
            If not None, override default verbose level
            (see :func:`mne.verbose`).
        """
        if subject_list is None:
            subject_list = self.subject_list
        # check once whether the concrete data_path accepts an `accept` kwarg
        sig = signature(self.data_path)
        accepts_accept = "accept" in sig.parameters
        for subject in subject_list:
            if accepts_accept:
                self.data_path(
                    subject=subject,
                    path=path,
                    force_update=force_update,
                    update_path=update_path,
                    verbose=verbose,
                    accept=accept,
                )
            else:
                self.data_path(
                    subject=subject,
                    path=path,
                    force_update=force_update,
                    update_path=update_path,
                    verbose=verbose,
                )
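
    # Hedged example: pre-download two subjects, accepting any license terms:
    #
    #     dataset.download(subject_list=[1, 2], accept=True)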

    def _get_single_subject_data_using_cache(
        self, subject, cache_config, process_pipeline
    ):
        """Load a single subject's data using cache.

        Either load the data of a single subject from the disk cache or from
        the dataset object, then possibly save or overwrite the cached
        version, depending on the parameters.
        """
        steps = list(process_pipeline.steps)
        splitted_steps = []  # list of (cached_steps, remaining_steps) options
        if cache_config.use:
            # Try the longest cached prefix first; e.g. for 3 steps:
            # (steps[:3], []), (steps[:2], steps[2:]), (steps[:1], steps[1:])
            splitted_steps += [
                (steps[:i], steps[i:]) for i in range(len(steps), 0, -1)
            ]  # [len(steps)...1]
        splitted_steps.append(
            ([], steps)
        )  # last option: if cached_steps is [], we don't use cache, i.e. i=0

        for cached_steps, remaining_steps in splitted_steps:
            sessions_data = None

            # Load and possibly overwrite:
            if len(cached_steps) == 0:  # last option: we don't use cache
                sessions_data = self._get_single_subject_data(subject)
                assert sessions_data is not None  # should not happen
            else:
                cache_type = cached_steps[-1][0]
                interface = _interface_map[cache_type](
                    self,
                    subject,
                    path=cache_config.path,
                    process_pipeline=Pipeline(cached_steps),
                    verbose=cache_config.verbose,
                )
                if (
                    (cache_config.overwrite_raw and cache_type is StepType.RAW)
                    or (cache_config.overwrite_epochs and cache_type is StepType.EPOCHS)
                    or (cache_config.overwrite_array and cache_type is StepType.ARRAY)
                ):
                    interface.erase()
                elif cache_config.use:  # can't load if it was just erased
                    sessions_data = interface.load(
                        preload=False
                    )  # None if the cache does not exist

            # If no cache was found or if it was erased, try the next option:
            if sessions_data is None:
                continue

            # Apply remaining steps and save:
            for step_idx, (step_type, step_pipeline) in enumerate(remaining_steps):
                # apply one step:
                sessions_data = {
                    session: {
                        run: apply_step(step_pipeline, raw)
                        for run, raw in runs.items()
                    }
                    for session, runs in sessions_data.items()
                }

                # save:
                if (
                    (
                        cache_config.save_raw
                        and step_type is StepType.RAW
                        and (
                            (step_idx == len(remaining_steps) - 1)
                            or (remaining_steps[step_idx + 1][0] is not StepType.RAW)
                        )
                    )  # we only save the last raw step
                    or (cache_config.save_epochs and step_type is StepType.EPOCHS)
                    or (cache_config.save_array and step_type is StepType.ARRAY)
                ):
                    interface = _interface_map[step_type](
                        self,
                        subject,
                        path=cache_config.path,
                        process_pipeline=Pipeline(
                            cached_steps + remaining_steps[: step_idx + 1]
                        ),
                        verbose=cache_config.verbose,
                    )
                    try:
                        interface.save(sessions_data)
                    except Exception:
                        log.warning(
                            f"Failed to save {interface!r} "
                            f"to BIDS format:\n"
                            f"{' Pipeline: '.center(50, '#')}\n"
                            f"{interface.process_pipeline!r}\n"
                            f"{' Exception: '.center(50, '#')}\n"
                            f"{traceback.format_exc()}{'#' * 50}"
                        )
                        interface.erase()  # remove partial cache
            return sessions_data
        raise ValueError("should not happen")

    @abc.abstractmethod
    def _get_single_subject_data(self, subject):
        """Return the data of a single subject.

        The returned data is a dictionary with the following structure::

            data = {'session_id':
                        {'run_id': raw}
                    }

        Parameters
        ----------
        subject: int
            subject number

        Returns
        -------
        data: Dict
            dict containing the raw data
        """
        pass

    @abc.abstractmethod
    def data_path(
        self, subject, path=None, force_update=False, update_path=None, verbose=None
    ):
        """Get path to local copy of a subject's data.

        Parameters
        ----------
        subject : int
            Number of the subject to use.
        path : None | str
            Location where to look for the data storing location.
            If None, the environment variable or config parameter
            ``MNE_DATASETS_(dataset)_PATH`` is used. If it doesn't exist, the
            "~/mne_data" directory is used. If the dataset
            is not found under the given path, the data
            will be automatically downloaded to the specified folder.
        force_update : bool
            Force update of the dataset even if a local copy exists.
        update_path : bool | None
            **Deprecated.** If True, set the MNE_DATASETS_(dataset)_PATH in
            mne-python config to the given path. If None, the user is prompted.
        verbose : bool, str, int, or None
            If not None, override default verbose level
            (see :func:`mne.verbose`).

        Returns
        -------
        path : list of str
            Local path to the given data file. This path is contained inside a
            list of length one, for compatibility.
        """  # noqa: E501
        pass
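

# A minimal sketch of a concrete dataset (illustrative: the class, channel
# names, and event codes below are assumptions for demonstration, not part
# of moabb). It implements the two abstract methods above:
#
#     import mne
#     import numpy as np
#
#     class FakeTwoClass(BaseDataset):
#         def __init__(self):
#             super().__init__(
#                 subjects=[1],
#                 sessions_per_subject=1,
#                 events={"left_hand": 1, "right_hand": 2},
#                 code="FakeTwoClass",
#                 interval=[0, 3],
#                 paradigm="imagery",
#             )
#
#         def _get_single_subject_data(self, subject):
#             info = mne.create_info(["C3", "C4", "stim"], 128.0,
#                                    ["eeg", "eeg", "stim"])
#             raw = mne.io.RawArray(np.zeros((3, 1280)), info)
#             return {"0": {"0": raw}}
#
#         def data_path(self, subject, path=None, force_update=False,
#                       update_path=None, verbose=None):
#             return []  # synthetic data: nothing to download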