CLN simplify model API+add more documentation (#12)
tomMoral committed Aug 7, 2023
1 parent c52df51 commit 339104a
Showing 13 changed files with 257 additions and 235 deletions.
66 changes: 39 additions & 27 deletions benchmark_utils/augmented_dataset.py
@@ -1,42 +1,53 @@
from benchopt import BaseSolver, safe_import_context
from benchopt import BaseSolver
from abc import abstractmethod, ABC

from benchmark_utils.transformation import (
channels_dropout,
smooth_timemask,
)

# Protect the import with `safe_import_context()`. This allows:
# - skipping import to speed up autocompletion in CLI.
# - getting requirements info when all dependencies are not installed.
with safe_import_context() as import_ctx:
from skorch.helper import to_numpy
from benchmark_utils.transformation import channels_dropout, smooth_timemask


class AugmentedBCISolver(BaseSolver, ABC):
"""Base class for solvers that use augmented data.
This class implements some basic methods that are inherited by other classes.
This class implements some basic methods to run methods with various
augmentation levels.
"""

parameters = {
"augmentation": [
"SmoothTimeMask",
"ChannelsDropout",
"IdentityTransform",
],
}

@abstractmethod
def set_objective(self, **objective_dict):
pass
"""Set the objective information from Objective.get_objective.
@property
def name(self):
Objective
---------
X: training data for the model
y: training labels to train the model.
sfreq: sampling frequency to allow filtering the data.
"""
pass

def run(self, n_iter):
"""Run the solver to evaluate it for a given number of iterations."""
def run(self, n_augmentation):
"""Run the solver to evaluate it for a given number of augmentation.
With this dataset, we consider that the performance curve is sampled
for various number of augmentation applied to the dataset.
"""
if self.augmentation == "ChannelsDropout":
X, y = channels_dropout(self.X, self.y, n_augmentation=n_iter)
X, y = channels_dropout(
self.X, self.y, n_augmentation=n_augmentation
)

elif self.augmentation == "SmoothTimeMask":
X, y = smooth_timemask(
self.X, self.y, n_augmentation=n_iter, sfreq=self.sfreq
self.X, self.y, n_augmentation=n_augmentation, sfreq=self.sfreq
)
else:
X = to_numpy(self.X)
X = self.X
y = self.y

self.clf.fit(X, y)
@@ -45,11 +56,12 @@ def get_next(self, n_iter):
return n_iter + 1

def get_result(self):
# Return the result from one optimization run.
# The outputs of this function are the arguments of `Objective.compute`
# This defines the benchmark's API for solvers' results.
# it is customizable for each benchmark.
return self.clf
"""Return the model to `Objective.evaluate_result`.
def warmup_solver(self):
pass
Result
------
model: an instance of a fitted model.
This model should have methods `score` and `predict`, that accept
braindecode.WindowsDataset as input.
"""
return dict(model=self.clf)
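
For readers new to the benchmark, here is a minimal sketch of what a concrete solver built on this base class could look like. The solver name and the use of DummyClassifier are hypothetical illustrations only, and the import path for AugmentedBCISolver assumes the benchmark_utils layout shown above.

from benchopt import safe_import_context
from benchmark_utils.augmented_dataset import AugmentedBCISolver

with safe_import_context() as import_ctx:
    from sklearn.dummy import DummyClassifier


class Solver(AugmentedBCISolver):
    # Hypothetical name, used to select the solver in the CLI.
    name = "DummySolver"

    def set_objective(self, X, y, sfreq):
        # Keyword arguments are the dict returned by Objective.get_objective.
        self.X, self.y, self.sfreq = X, y, sfreq
        # Any estimator exposing fit/predict/score would do here;
        # AugmentedBCISolver.run fits self.clf on the augmented data.
        self.clf = DummyClassifier()
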
3 changes: 1 addition & 2 deletions benchmark_utils/dataset.py
@@ -108,8 +108,7 @@ def windows_data(
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this.
# It needs parameters to define how
# trials should be used.
# It needs parameters to define how trials should be used.
windows_dataset = create_windows_from_events(
dataset,
trial_start_offset_samples=trial_start_offset_samples,
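
As context for the comment above, the windowing relies on braindecode's create_windows_from_events. Below is a rough, self-contained sketch of that call; the sampling frequency and the -0.5 s offset are illustrative values, not necessarily the ones used by windows_data.

from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import create_windows_from_events

# Load a single subject to keep the example light (illustrative choice).
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[1])

sfreq = 250                         # illustrative sampling frequency (Hz)
trial_start_offset_seconds = -0.5   # illustrative offset before each event
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Cut one window per annotated trial, starting slightly before the event.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)
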
24 changes: 12 additions & 12 deletions datasets/BNCI.py
@@ -13,22 +13,22 @@ class Dataset(BaseDataset):

# Name to select the dataset in the CLI and to display the results.
name = "BNCI"
parameters = {'paradigm_name': ('MotorImagery', 'LeftRightImagery')}
# List of parameters to generate the datasets. The benchmark will consider
# the cross product for each key in the dictionary.
# Any parameter 'param' defined here is available as `self.param`.
parameters = {
'paradigm_name': ('MotorImagery', 'LeftRightImagery')
}

def get_data(self):
"""Returns the data to be passed to Objective.set_data.
Data
----
Dataset: an instance of a braindecode.WindowsDataset
sfreq: the sampling frequency of the data.
"""

# The return arguments of this function are passed as keyword arguments
# to `Objective.set_data`. This defines the benchmark's
# API to pass data. It is customizable for each benchmark.
dataset_name = "BNCI2014001"
data = MOABBDataset(dataset_name=dataset_name,
subject_ids=None)
data = MOABBDataset(dataset_name=dataset_name, subject_ids=None)

dataset, sfreq = windows_data(data, self.paradigm_name)

return dict(dataset=dataset,
paradigm_name=self.paradigm_name,
sfreq=sfreq)
return dict(dataset=dataset, sfreq=sfreq)
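
To make the simplified API concrete: the dictionary returned by get_data is unpacked into the keyword arguments of Objective.set_data, so after this commit only dataset and sfreq cross that boundary. The same steps can be reproduced outside benchopt, assuming the benchmark's benchmark_utils package is importable; the sketch below downloads the BNCI data on first run.

from braindecode.datasets import MOABBDataset
from benchmark_utils import windows_data

# Same calls as Dataset.get_data above.
data = MOABBDataset(dataset_name="BNCI2014001", subject_ids=None)
dataset, sfreq = windows_data(data, "MotorImagery")

# Only these two objects are handed to the objective; they map onto the
# keyword arguments of Objective.set_data(dataset, sfreq).
print(len(dataset.datasets), sfreq)
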
30 changes: 11 additions & 19 deletions datasets/Zhou.py
@@ -1,36 +1,28 @@
from benchopt import BaseDataset, safe_import_context


# Protect the import with `safe_import_context()`. This allows:
# - skipping import to speed up autocompletion in CLI.
# - getting requirements info when all dependencies are not installed.
with safe_import_context() as import_ctx:
from braindecode.datasets import MOABBDataset
from benchmark_utils import windows_data


# All datasets must be named `Dataset` and inherit from `BaseDataset`

class Dataset(BaseDataset):

name = "Zhou"

parameters = {'paradigm_name': ('LeftRightImagery', 'MotorImagery')}
# List of parameters to generate the datasets. The benchmark will consider
# the cross product for each key in the dictionary.
# Any parameter 'param' defined here is available as `self.param`.

def get_data(self):
"""Returns the data to be passed to Objective.set_data.
Data
----
Dataset: an instance of a braindecode.WindowsDataset
sfreq: the sampling frequency of the data.
"""

# The return arguments of this function are passed as keyword arguments
# to `Objective.set_data`. This defines the benchmark's
# API to pass data. It is customizable for each benchmark.
dataset_name = "Zhou2016"
data = MOABBDataset(dataset_name=dataset_name,
subject_ids=None)
paradigm_name = "LeftRightImagery"
data = MOABBDataset(dataset_name=dataset_name, subject_ids=None)

dataset, sfreq = windows_data(data, "LeftRightImagery")
dataset, sfreq = windows_data(data, paradigm_name)

return dict(dataset=dataset,
paradigm_name="LeftRightImagery",
sfreq=sfreq)
return dict(dataset=dataset, sfreq=sfreq)
26 changes: 13 additions & 13 deletions datasets/Simulated.py → datasets/simulated.py
@@ -10,30 +10,30 @@

class Dataset(BaseDataset):
# Name to select the dataset in the CLI and to display the results.
name = "simulated"
# List of parameters to generate the datasets. The benchmark will consider
# the cross product for each key in the dictionary.
# Any parameter 'param' defined here is available as `self.param`.
name = "Simulated"

def get_data(self):
# The return arguments of this function are passed as keyword arguments
# to `Objective.set_data`. This defines the benchmark's
# API to pass data. It is customizable for each benchmark.
"""Returns the data to be passed to Objective.set_data.
Data
----
Dataset: an instance of a braindecode.WindowsDataset
sfreq: the sampling frequency of the data.
"""

dataset_name = "FakeDataset"
paradigm_name = "LeftRightImagery"
data = MOABBDataset(
dataset_name=dataset_name,
subject_ids=None,
dataset_kwargs={
"event_list": ["left_hand", "right_hand"],
"paradigm": "imagery",
"n_subjects": 2
},
)

dataset, sfreq = windows_data(data, "LeftRightImagery")

self.sfreq = sfreq
dataset, sfreq = windows_data(data, paradigm_name)
dataset = dataset.split([0])['0']

return dict(
dataset=dataset, paradigm_name="LeftRightImagery", sfreq=sfreq
)
return dict(dataset=dataset, sfreq=sfreq)
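
The split([0])['0'] call above keeps only the first recording of the fake dataset, which keeps the simulated configuration fast. braindecode's BaseConcatDataset.split accepts either a description column name (as used with 'subject' in Objective.set_data) or a list of indices, and returns a dict of datasets keyed by strings. A small hedged sketch on the same fake data:

from braindecode.datasets import MOABBDataset

# Same fake data as in get_data above, with two simulated subjects.
data = MOABBDataset(
    dataset_name="FakeDataset",
    subject_ids=None,
    dataset_kwargs={
        "event_list": ["left_hand", "right_hand"],
        "paradigm": "imagery",
        "n_subjects": 2,
    },
)

by_subject = data.split("subject")   # e.g. {"1": <dataset>, "2": <dataset>}
first_only = data.split([0])["0"]    # keep only the first recording
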
109 changes: 67 additions & 42 deletions objective.py
@@ -1,50 +1,53 @@
from benchopt import BaseObjective, safe_import_context

# Protect the import with `safe_import_context()`. This allows:
# - skipping import to speed up autocompletion in CLI.
# - getting requirements info when all dependencies are not installed.

with safe_import_context() as import_ctx:
from numpy import array

from sklearn.dummy import DummyClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score as BAS

from skorch.helper import SliceDataset, to_numpy
from benchmark_utils.dataset import split_windows_train_test
# The benchmark objective must be named `Objective` and
# inherit from `BaseObjective` for `benchopt` to work properly.


class Objective(BaseObjective):

# Name to select the objective in the CLI and to display the results.
name = "BCI"

# List of parameters for the objective. The benchmark will consider
# the cross product for each key in the dictionary.
# All parameters 'p' defined here are available as 'self.p'.
name = "Brain-Computer Interface"

link = 'pip: git+https://github.com/Neurotechx/moabb@develop#egg=moabb'
install_cmd = 'conda'
requirements = [link,
'scikit-learn']
requirements = [
'scikit-learn',
'pytorch:pytorch',
'pip:braindecode',
'pip:git+https://github.com/Neurotechx/moabb@develop#egg=moabb',
]

parameters = {
'evaluation_process, subject, subject_test, session_test': [
('intra_subject', 1, None, None),
],
}
# The solvers will train on all subjects except subject_test.
# The same applies to the sessions.

is_convex = False

# Minimal version of benchopt required to run this benchmark.
# Bump it up if the benchmark depends on a new feature of benchopt.
min_benchopt_version = "1.3.2"
min_benchopt_version = "1.4.1"

def set_data(self, dataset, sfreq):
"""Set the data retrieved from Dataset.get_data.
def set_data(self, dataset, paradigm_name, sfreq):
# The keyword arguments of this function are the keys of the dictionary
# returned by `Dataset.get_data`. This defines the benchmark's
# API to pass data. This is customizable for each benchmark.
Data
----
Dataset: an instance of a braindecode.WindowsDataset
sfreq: the sampling frequency of the data.
"""

data_split_subject = dataset.split('subject')

@@ -105,36 +108,58 @@ def set_data(self, dataset, paradigm_name, sfreq):
sfreq=sfreq,
)

def compute(self, model):
# The arguments of this function are the outputs of the
# `Solver.get_result`. This defines the benchmark's API to pass
# solvers' result. This is customizable for each benchmark.
if not type(model) == 'braindecode.classifier.EEGClassifier':
self.X_train = to_numpy(self.X_train)
self.X_test = to_numpy(self.X_test)
def evaluate_result(self, model):
"""Compute the evaluation metrics for the benchmark.
Result
------
model: an instance of a fitted model.
This model should have methods `score` and `predict`, that accept
braindecode.WindowsDataset as input.
Metrics
-------
score_test: accuracy on the testing set.
score_train: accuracy on the training set.
balanced_accuracy: balanced accuracy on the testing set
value: error on the testing set.
"""

score_train = model.score(self.X_train, self.y_train)
score_test = model.score(self.X_test, self.y_test)
bl_acc = BAS(self.y_test, model.predict(self.X_test))

# This method can return many metrics in a dictionary. One of these
# metrics needs to be `value` for convergence detection purposes.
return dict(score_test=score_test,
value=-score_test,
score_train=score_train,
balanced_accuracy=bl_acc)
return dict(
score_test=score_test,
score_train=score_train,
balanced_accuracy=bl_acc,
value=1-score_test,
)

def get_one_solution(self):
# Return one solution. The return value should be an object compatible
# with `self.compute`. This is mainly for testing purposes.
return DummyClassifier().fit(self.X_train, self.y_train)
def get_one_result(self):
"""Return one dummy result.
Result
------
model: an instance of a fitted model.
This model should have methods `score` and `predict`, that accept
braindecode.WindowsDataset as input.
"""
clf = make_pipeline(
FunctionTransformer(to_numpy),
DummyClassifier()
)
return dict(model=clf.fit(self.X_train, self.y_train))

def get_objective(self):
# Define the information to pass to each solver to run the benchmark.
# The output of this function are the keyword arguments
# for `Solver.set_objective`. This defines the
# benchmark's API for passing the objective to the solver.
# It is customizable for each benchmark.
"""Pass the objective information to Solvers.set_objective.
Objective
---------
X: training data for the model
y: training labels to train the model.
sfreq: sampling frequency to allow filtering the data.
"""

return dict(
X=self.X_train,
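
Putting the pieces of the new result API together: a solver returns dict(model=...), the objective calls score and predict on that model, and the mandatory value key (here 1 - score_test, so lower is better) is what benchopt tracks. The following is a hedged sketch reusing the dummy pipeline from get_one_result, not the exact benchopt internals; the fit and evaluation calls are commented out because they need the data held by the objective.

from sklearn.dummy import DummyClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from skorch.helper import to_numpy

# Same dummy pipeline as get_one_result: convert the braindecode windows to
# numpy arrays, then fit a trivial classifier that exposes score/predict.
model = make_pipeline(FunctionTransformer(to_numpy), DummyClassifier())

# model.fit(X_train, y_train)                      # done by the solver
# result = dict(model=model)                       # Solver.get_result output
# metrics = objective.evaluate_result(**result)    # includes the "value" key
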
