Skip to content

Commit

Permalink
[ENH] Register model (#159)
Browse files Browse the repository at this point in the history
* ADD model registration

* codespell

* ADD test

* ADD latest for #105

* Bug fix scoring using dict with registered scorer

* Add overwrite to register_models

* Adjust transformer to make overwrite consistent

* Make registering more consistent

* flake8

* latest

* codespell

* Correct Docs
  • Loading branch information
samihamdan committed Jun 17, 2022
1 parent 3a3fd9a commit 11911cc
Show file tree
Hide file tree
Showing 13 changed files with 245 additions and 54 deletions.
7 changes: 6 additions & 1 deletion docs/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Enhancements

- Add `CBPM` transformer (by `Sami Hamdan`_).

- Add `register_model` (:gh:`105` by `Sami Hamdan`_).

- Add documentation/example for parallelization (by `Sami Hamdan`_).

Bugs
Expand All @@ -58,10 +60,13 @@ Bugs

- Fix Bug in the transformer wrapper implementation (:gh:`122` by `Sami Hamdan`_).

- Fix Bug Target Transformer missing BaseEstimator (:gh:`151` by `Sami Hamdan`_).
- Fix Bug registered scorer not working in dictionary for scoring (by `Sami Hamdan`_).

API changes
~~~~~~~~~~~
- Make api surrounding registering consistently use overwrite (by `Sami Hamdan`_).

- Fix Bug Target Transformer missing BaseEstimator (:gh:`151` by `Sami Hamdan`_).


- Inner `cv` needs to be provided using `search_params`. Deprecating `cv` in `model_params` (:gh:`146` by `Sami Hamdan`_).
Expand Down
2 changes: 1 addition & 1 deletion examples/advanced/run_custom_scorers_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def pearson_scorer(y_true, y_pred):
###############################################################################
# Before using it, we need to convert it to a sklearn scorer and register it
# with julearn.
register_scorer(name='pearsonr', scorer=make_scorer(pearson_scorer))
register_scorer(scorer_name='pearsonr', scorer=make_scorer(pearson_scorer))

###############################################################################
# Now we can use it as another scoring metric.
Expand Down
3 changes: 0 additions & 3 deletions julearn/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def run_cross_validation(
logger.info(f'Using default CV')
cv = 'repeat:5_nfolds:5'

# if scoring is None:
# scoring = 'accuracy'

# Interpret the input data and prepare it to be used with the library
df_X_conf, y, df_groups, _ = prepare_input_data(
X=X, y=y, confounds=confounds, df=data, pos_labels=pos_labels,
Expand Down
3 changes: 2 additions & 1 deletion julearn/estimators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# Sami Hamdan <s.hamdan@fz-juelich.de>
# License: AGPL
from . available_models import list_models, get_model
from . available_models import (list_models, get_model,
register_model, reset_model_register)
92 changes: 91 additions & 1 deletion julearn/estimators/available_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# Sami Hamdan <s.hamdan@fz-juelich.de>
# License: AGPL
from copy import deepcopy
from sklearn.svm import SVC, SVR
from sklearn.ensemble import (RandomForestClassifier, RandomForestRegressor,
ExtraTreesClassifier, ExtraTreesRegressor,
Expand All @@ -18,7 +19,7 @@
GaussianNB, MultinomialNB)
from sklearn.dummy import DummyClassifier, DummyRegressor

from .. utils import raise_error
from .. utils import raise_error, warn, logger
from . dynamic import DynamicSelection

_available_models = {
Expand Down Expand Up @@ -114,6 +115,8 @@
}
}

_available_models_reset = deepcopy(_available_models)


def list_models():
"""List all the available model names
Expand Down Expand Up @@ -156,3 +159,90 @@ def get_model(name, problem_type, **kwargs):

out = _available_models[name][problem_type](**kwargs)
return out


def register_model(model_name,
                   binary_cls=None, multiclass_cls=None, regression_cls=None,
                   overwrite=None
                   ):
    """Register a model to julearn.

    This function allows you to add a model or models for different
    problem_types to julearn.
    Afterwards, it behaves like every other julearn model and can
    be referred to by name. E.g. you can use it inside of
    `run_cross_validation` using `model=model_name`.

    Parameters
    ----------
    model_name : str
        Name by which the model will be referenced.
    binary_cls : object | None
        The class which will be used for
        binary_classification problem_type. Skipped when None.
    multiclass_cls : object | None
        The class which will be used for
        multiclass_classification problem_type. Skipped when None.
    regression_cls : object | None
        The class which will be used for
        regression problem_type. Skipped when None.
    overwrite : bool | None, optional
        Decides whether overwrite should be allowed, by default None.
        Options are:

        * None : overwrite is possible, but warns the user
        * True : overwrite is possible without any warning
        * False : overwrite is not possible, error is raised instead

    Notes
    -----
    The problem types are processed in the order binary, multiclass,
    regression. With ``overwrite=False`` the error is reported through
    ``raise_error`` as soon as the first conflicting problem type is
    reached; problem types handled before that point stay registered.
    """
    candidates = (
        ("binary_classification", binary_cls),
        ("multiclass_classification", multiclass_cls),
        ("regression", regression_cls),
    )
    for problem_type, cls in candidates:
        if cls is None:
            # Nothing was supplied for this problem type.
            continue

        registered = _available_models.get(model_name)
        if registered is None:
            # First registration under this name: create its entry.
            logger.info(f'registering model named {model_name} '
                        f'with problem_type {problem_type}'
                        )
            _available_models[model_name] = {problem_type: cls}
            continue

        if registered.get(problem_type):
            # The name/problem_type pair is already taken; `overwrite`
            # decides between warning, failing or silently replacing.
            if overwrite is None:
                warn(
                    f'Model named {model_name} with'
                    f' problem type {problem_type}'
                    ' already exists. '
                    f'Therefore, {model_name} will be overwritten. '
                    'To remove this warning set overwrite=True. '
                    'If you want to reset this use '
                    '`julearn.estimators.reset_model_register`.'
                )
            elif overwrite is False:
                raise_error(
                    f'Model named {model_name} with '
                    f'problem type {problem_type}'
                    ' already exists. '
                    'overwrite is set to False, '
                    'therefore you cannot overwrite '
                    'existing models. Set overwrite=True'
                    ' in case you want to '
                    'overwrite existing models'
                )

        logger.info(f'registering model named {model_name} '
                    f'with problem_type {problem_type}'
                    )
        registered[problem_type] = cls


def reset_model_register():
    """Restore the model register to its initial state.

    The module-level ``_available_models`` is rebound to a fresh deep copy
    of ``_available_models_reset``, the snapshot of the register taken
    when the module was loaded.

    Returns
    -------
    dict
        The newly installed register.
    """
    global _available_models
    pristine = deepcopy(_available_models_reset)
    _available_models = pristine
    return pristine
45 changes: 45 additions & 0 deletions julearn/estimators/tests/test_available_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# Sami Hamdan <s.hamdan@fz-juelich.de>
# Shammi More <s.more@fz-juelich.de>
# License: AGPL

import pytest
from julearn.estimators import register_model, reset_model_register, get_model
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor


def test_register_model():
    """Registered models are retrievable by name and vanish after a reset."""
    expected_types = {
        "binary_classification": DecisionTreeClassifier,
        "multiclass_classification": DecisionTreeClassifier,
        "regression": DecisionTreeRegressor,
    }
    register_model(
        "dt",
        binary_cls=DecisionTreeClassifier,
        multiclass_cls=DecisionTreeClassifier,
        regression_cls=DecisionTreeRegressor,
    )

    # Instantiate all three models first, then check their types.
    models = {
        problem_type: get_model("dt", problem_type)
        for problem_type in expected_types
    }
    for problem_type, expected_cls in expected_types.items():
        assert isinstance(models[problem_type], expected_cls)

    # After resetting, the custom name must be unknown again.
    reset_model_register()
    with pytest.raises(ValueError, match="The specified model "):
        get_model("dt", "binary_classification")


def test_register_warning():
    """Check how `overwrite` controls re-registering an existing model.

    Re-registering "rf" for regression is expected to warn with
    ``overwrite=None`` (the default), to raise with ``overwrite=False``
    and to stay silent with ``overwrite=True``.
    """
    # Local import keeps this test self-contained.
    import warnings

    with pytest.warns(RuntimeWarning, match="Model name"):
        register_model("rf", regression_cls=RandomForestRegressor)
    reset_model_register()

    with pytest.raises(ValueError, match="Model name"):
        register_model(
            "rf", regression_cls=RandomForestRegressor, overwrite=False)
    reset_model_register()

    # `pytest.warns(None)` is deprecated in pytest 7 and rejected in pytest 8,
    # so record warnings with the standard library instead.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        register_model(
            "rf", regression_cls=RandomForestRegressor, overwrite=True)
    reset_model_register()
    assert len(record) == 0
16 changes: 8 additions & 8 deletions julearn/model_selection/available_searchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def get_searcher(name):
return out


def register_searcher(name, searcher, overwrite=None):
def register_searcher(searcher_name, searcher, overwrite=None):
"""Register searcher to julearn.
This function allows you to add a scikit-learn compatible searching
algorithm to julearn. Afterwards, you can call it as all other searchers in
julearn.
Parameters
----------
name : str
searcher_name : str
Name by which the searcher will be referenced by.
searcher : obj
The searcher class by which the searcher can be initialized.
Expand All @@ -67,23 +67,23 @@ def register_searcher(name, searcher, overwrite=None):
* False : overwrite is not possible, error is raised instead
"""
if name in list_searchers():
if searcher_name in list_searchers():
if overwrite is None:
warn(
f'searcher named {name} already exists. '
f'Therefore, {name} will be overwritten. '
f'searcher named {searcher_name} already exists. '
f'Therefore, {searcher_name} will be overwritten. '
'To remove this warning set `overwrite=True`. '
)
elif overwrite is False:
raise_error(
f'searcher named {name} already exists and '
f'searcher named {searcher_name} already exists and '
'overwrite is set to False, therefore you cannot overwrite '
'existing searchers. '
'Set `overwrite=True` in case you want to '
'overwrite existing searchers.'
)
logger.info(f'Registering new searcher: {name}')
_available_searchers[name] = searcher
logger.info(f'Registering new searcher: {searcher_name}')
_available_searchers[searcher_name] = searcher


def reset_searcher_register():
Expand Down
5 changes: 4 additions & 1 deletion julearn/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,10 @@ def prepare_scoring(estimator, scorers):
if isinstance(scorers, list):
scoring = {k: get_extended_scorer(estimator, k) for k in scorers}
elif isinstance(scorers, dict):
scoring = scorers
scoring = {
name: get_extended_scorer(estimator, scorer) if isinstance(
scorer, str) else scorer
for name, scorer in scorers.items()}
else:
scoring = get_extended_scorer(estimator, scorers)
return scoring
Expand Down
18 changes: 9 additions & 9 deletions julearn/scoring/available_scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ def list_scorers():
return {**SCORERS, **_extra_available_scorers}.keys()


def register_scorer(name, scorer, overwrite=None):
def register_scorer(scorer_name, scorer, overwrite=None):
"""register a scorer, so that you can access it in scoring with its name.
Parameters
----------
name : str
scorer_name : str
name of the scorer you want to register
scorer : callable
function of signature (estimator, X, y) see:
Expand All @@ -65,22 +65,22 @@ def register_scorer(name, scorer, overwrite=None):
* True : overwrite is possible without any warning
* False : overwrite is not possible, error is raised instead
"""
if name in list_scorers():
if scorer_name in list_scorers():
if overwrite is None:
warn(
f'scorer named {name} already exists. '
f'Therefore, {name} will be overwritten. '
f'scorer named {scorer_name} already exists. '
f'Therefore, {scorer_name} will be overwritten. '
'To remove this warning set overwrite=True '
)
logger.info(f'registering scorer named {name}')
logger.info(f'registering scorer named {scorer_name}')
elif overwrite is False:
raise_error(
f'scorer named {name} already exists and '
f'scorer named {scorer_name} already exists and '
'overwrite is set to False, therefore you cannot overwrite '
'existing scorers. Set overwrite=True in case you want to '
'overwrite existing scorers')
logger.info(f'registering scorer named {name}')
_extra_available_scorers[name] = scorer
logger.info(f'registering scorer named {scorer_name}')
_extra_available_scorers[scorer_name] = scorer


def reset_scorer_register():
Expand Down
2 changes: 1 addition & 1 deletion julearn/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from . import confounds
from . import target
from . available_transformers import (
list_transformers, get_transformer, register_transformer, reset_register)
list_transformers, get_transformer, register_transformer, reset_transformer_register)

from . confounds import DataFrameConfoundRemover, TargetConfoundRemover
from . meta import DataFrameWrapTransformer
Expand Down

0 comments on commit 11911cc

Please sign in to comment.