Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
Merge pull request #206 from delira-dev/remove_trixi2
Browse files Browse the repository at this point in the history
Remove trixi
  • Loading branch information
justusschock committed Sep 11, 2019
2 parents 488bc88 + 1d3d272 commit 4801709
Show file tree
Hide file tree
Showing 27 changed files with 1,930 additions and 717 deletions.
1 change: 0 additions & 1 deletion delira/training/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@

from delira.training.parameters import Parameters
from delira.training.base_experiment import BaseExperiment
from delira.training.base_trainer import BaseNetworkTrainer
from delira.training.predictor import Predictor
Expand Down
14 changes: 7 additions & 7 deletions delira/training/backends/chainer/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from delira.models.backends.chainer import AbstractChainerNetwork
from delira.data_loading import BaseDataManager
from delira.training.base_experiment import BaseExperiment
from delira.training.parameters import Parameters
from delira.utils import DeliraConfig

from delira.training.backends.chainer.utils import create_optims_default
from delira.training.backends.chainer.utils import convert_to_numpy
Expand All @@ -13,7 +13,7 @@

class ChainerExperiment(BaseExperiment):
def __init__(self,
params: typing.Union[str, Parameters],
config: typing.Union[str, DeliraConfig],
model_cls: AbstractChainerNetwork,
n_epochs=None,
name=None,
Expand All @@ -28,10 +28,10 @@ def __init__(self,
Parameters
----------
params : :class:`Parameters` or str
the training parameters, if string is passed,
it is treated as a path to a pickle file, where the
parameters are loaded from
config : :class:`DeliraConfig` or str
the training config, if string is passed,
it is treated as a path to a file, where the
config is loaded from
model_cls : Subclass of :class:`AbstractChainerNetwork`
the class implementing the model to train
n_epochs : int or None
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(self,

if key_mapping is None:
key_mapping = {"x": "data"}
super().__init__(params=params, model_cls=model_cls,
super().__init__(config=config, model_cls=model_cls,
n_epochs=n_epochs, name=name, save_path=save_path,
key_mapping=key_mapping,
val_score_key=val_score_key,
Expand Down
35 changes: 17 additions & 18 deletions delira/training/backends/sklearn/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from delira.models.backends.sklearn import SklearnEstimator

from delira.training.base_experiment import BaseExperiment
from delira.training.parameters import Parameters
from delira.utils import DeliraConfig

from delira.training.backends.sklearn.trainer import SklearnEstimatorTrainer


class SklearnExperiment(BaseExperiment):
def __init__(self,
params: typing.Union[str, Parameters],
config: typing.Union[str, DeliraConfig],
model_cls: BaseEstimator,
n_epochs=None,
name=None,
Expand All @@ -29,10 +29,10 @@ def __init__(self,
Parameters
----------
params : :class:`Parameters` or str
the training parameters, if string is passed,
it is treated as a path to a pickle file, where the
parameters are loaded from
config : :class:`DeliraConfig` or str
the training config, if string is passed,
it is treated as a path to a file, where the
config is loaded from
model_cls : Subclass of :class:`sklearn.base.BaseEstimator`
the class implementing the model to train (will be wrapped by
:class:`SkLearnEstimator`)
Expand Down Expand Up @@ -65,7 +65,7 @@ class wrapping the actual sklearn model to provide delira
if key_mapping is None:
key_mapping = {"X": "X"}

super().__init__(params=params,
super().__init__(config=config,
model_cls=model_cls,
n_epochs=n_epochs,
name=name,
Expand All @@ -77,14 +77,14 @@ class wrapping the actual sklearn model to provide delira
**kwargs)
self._model_wrapper_cls = model_wrapper_cls

def _setup_training(self, params, **kwargs):
def _setup_training(self, config, **kwargs):
"""
Handles the setup for training case
Parameters
----------
params : :class:`Parameters`
the parameters containing the model and training kwargs
config : :class:`DeliraConfig`
the config containing the model and training kwargs
**kwargs :
additional keyword arguments
Expand All @@ -93,14 +93,13 @@ def _setup_training(self, params, **kwargs):
:class:`BaseNetworkTrainer`
the created trainer
"""
model_params = params.permute_training_on_top().model

model_kwargs = {**model_params.fixed, **model_params.variable}
model_kwargs = config.model_params
model_kwargs = {**model_kwargs["variable"], **model_kwargs["fixed"]}

_model = self.model_cls(**model_kwargs)
model = self._model_wrapper_cls(_model)

training_params = params.permute_training_on_top().training
training_params = config.training_params
train_metrics = training_params.nested_get("train_metrics", {})
val_metrics = training_params.nested_get("val_metrics", {})

Expand All @@ -120,14 +119,14 @@ def _setup_training(self, params, **kwargs):
**kwargs
)

def _setup_test(self, params, model, convert_batch_to_npy_fn,
def _setup_test(self, config, model, convert_batch_to_npy_fn,
prepare_batch_fn, **kwargs):
"""
Parameters
----------
params : :class:`Parameters`
the parameters containing the model and training kwargs
config : :class:`DeliraConfig`
the config containing the model and training kwargs
(ignored here, just passed for subclassing and unified API)
model : :class:`sklearn.base.BaseEstimator`
the model to test
Expand Down Expand Up @@ -155,5 +154,5 @@ def _setup_test(self, params, model, convert_batch_to_npy_fn,
input_device="cpu",
output_device="cpu")

return super()._setup_test(params, model, convert_batch_to_npy_fn,
return super()._setup_test(config, model, convert_batch_to_npy_fn,
prepare_batch_fn, **kwargs)
30 changes: 15 additions & 15 deletions delira/training/backends/tf_eager/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from delira.models.backends.tf_eager import AbstractTfEagerNetwork

from delira.training.base_experiment import BaseExperiment
from delira.training.parameters import Parameters
from delira.utils import DeliraConfig

from delira.training.backends.tf_eager.trainer import TfEagerNetworkTrainer
from delira.training.backends.tf_eager.utils import create_optims_default
Expand All @@ -16,7 +16,7 @@

class TfEagerExperiment(BaseExperiment):
def __init__(self,
params: typing.Union[str, Parameters],
config: typing.Union[str, DeliraConfig],
model_cls: AbstractTfEagerNetwork,
n_epochs=None,
name=None,
Expand All @@ -31,10 +31,10 @@ def __init__(self,
Parameters
----------
params : :class:`Parameters` or str
the training parameters, if string is passed,
it is treated as a path to a pickle file, where the
parameters are loaded from
config : :class:`DeliraConfig` or str
the training config, if string is passed,
it is treated as a path to a file, where the
config is loaded from
model_cls : Subclass of :class:`AbstractTfEagerNetwork`
the class implementing the model to train
n_epochs : int or None
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self,

if key_mapping is None:
key_mapping = {"x": "data"}
super().__init__(params=params, model_cls=model_cls,
super().__init__(config=config, model_cls=model_cls,
n_epochs=n_epochs, name=name, save_path=save_path,
key_mapping=key_mapping,
val_score_key=val_score_key,
Expand All @@ -84,7 +84,7 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
num_splits=None, shuffle=False, random_seed=None,
split_type="random", val_split=0.2, label_key="label",
train_kwargs: dict = None, test_kwargs: dict = None,
metric_keys: dict = None, params=None, verbose=False,
metric_keys: dict = None, config=None, verbose=False,
**kwargs):
"""
Performs a k-Fold cross-validation
Expand All @@ -98,7 +98,7 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
dictionary containing the metrics to evaluate during k-fold
num_epochs : int or None
number of epochs to train (if not given, will either be
extracted from ``params``, ``self.parms`` or ``self.n_epochs``)
extracted from ``config``, ``self.config`` or ``self.n_epochs``)
num_splits : int or None
the number of splits to extract from ``data``.
If None: uses a default of 10
Expand Down Expand Up @@ -133,9 +133,9 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
kwargs to update the behavior of the :class:`BaseDataManager`
containing the test and validation data.
If None: empty dict will be passed
params : :class:`Parameters`or None
config : :class:`DeliraConfig` or None
the training and model parameters
(will be merged with ``self.params``)
(will be merged with ``self.config``)
verbose : bool
verbosity
**kwargs :
Expand Down Expand Up @@ -196,7 +196,7 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
train_kwargs=train_kwargs,
test_kwargs=test_kwargs,
metric_keys=metric_keys,
params=params,
config=config,
verbose=verbose,
**kwargs)

Expand Down Expand Up @@ -261,14 +261,14 @@ def test(self, network, test_data: BaseDataManager,
verbose=verbose, prepare_batch=prepare_batch,
convert_fn=convert_fn, **kwargs)

def setup(self, params, training=True, **kwargs):
def setup(self, config, training=True, **kwargs):
"""
Defines the setup behavior (model, trainer etc.) for training and
testing case
Parameters
----------
params : :class:`Parameters`
config : :class:`DeliraConfig`
the parameters to use for setup
training : bool
whether to setup for training case or for testing case
Expand All @@ -291,5 +291,5 @@ def setup(self, params, training=True, **kwargs):
"""
tf.reset_default_graph()
return super().setup(params=params, training=training,
return super().setup(config=config, training=training,
**kwargs)
14 changes: 7 additions & 7 deletions delira/training/backends/tf_graph/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from delira.models.backends.tf_graph import AbstractTfGraphNetwork
from delira.data_loading import BaseDataManager

from delira.training.parameters import Parameters
from delira.utils import DeliraConfig
from delira.training.backends.tf_eager.experiment import TfEagerExperiment
from delira.training.backends.tf_eager.utils import create_optims_default

Expand All @@ -16,7 +16,7 @@

class TfGraphExperiment(TfEagerExperiment):
def __init__(self,
params: typing.Union[str, Parameters],
config: typing.Union[str, DeliraConfig],
model_cls: AbstractTfGraphNetwork,
n_epochs=None,
name=None,
Expand All @@ -31,10 +31,10 @@ def __init__(self,
Parameters
----------
params : :class:`Parameters` or str
the training parameters, if string is passed,
it is treated as a path to a pickle file, where the
parameters are loaded from
config : :class:`DeliraConfig` or str
the training config, if string is passed,
it is treated as a path to a file, where the
config is loaded from
model_cls : Subclass of :class:`AbstractTfEagerNetwork`
the class implementing the model to train
n_epochs : int or None
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self,
key_mapping = {"data": "data"}

super().__init__(
params=params,
config=config,
model_cls=model_cls,
n_epochs=n_epochs,
name=name,
Expand Down
24 changes: 12 additions & 12 deletions delira/training/backends/torch/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from delira.data_loading import BaseDataManager

from delira.training.base_experiment import BaseExperiment
from delira.training.parameters import Parameters
from delira.utils import DeliraConfig

from delira.training.backends.torch.trainer import PyTorchNetworkTrainer
from delira.training.backends.torch.utils import create_optims_default
Expand All @@ -16,7 +16,7 @@

class PyTorchExperiment(BaseExperiment):
def __init__(self,
params: typing.Union[str, Parameters],
config: typing.Union[str, DeliraConfig],
model_cls: AbstractPyTorchNetwork,
n_epochs=None,
name=None,
Expand All @@ -31,10 +31,10 @@ def __init__(self,
Parameters
----------
params : :class:`Parameters` or str
the training parameters, if string is passed,
it is treated as a path to a pickle file, where the
parameters are loaded from
config : :class:`DeliraConfig` or str
the training config, if string is passed,
it is treated as a path to a file, where the
config is loaded from
model_cls : Subclass of :class:`AbstractPyTorchNetwork`
the class implementing the model to train
n_epochs : int or None
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self,

if key_mapping is None:
key_mapping = {"x": "data"}
super().__init__(params=params, model_cls=model_cls,
super().__init__(config=config, model_cls=model_cls,
n_epochs=n_epochs, name=name, save_path=save_path,
key_mapping=key_mapping,
val_score_key=val_score_key,
Expand All @@ -83,7 +83,7 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
num_splits=None, shuffle=False, random_seed=None,
split_type="random", val_split=0.2, label_key="label",
train_kwargs: dict = None, test_kwargs: dict = None,
metric_keys: dict = None, params=None, verbose=False,
metric_keys: dict = None, config=None, verbose=False,
**kwargs):
"""
Performs a k-Fold cross-validation
Expand All @@ -97,7 +97,7 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
dictionary containing the metrics to evaluate during k-fold
num_epochs : int or None
number of epochs to train (if not given, will either be
extracted from ``params``, ``self.parms`` or ``self.n_epochs``)
extracted from ``config``, ``self.config`` or ``self.n_epochs``)
num_splits : int or None
the number of splits to extract from ``data``.
If None: uses a default of 10
Expand Down Expand Up @@ -132,9 +132,9 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
kwargs to update the behavior of the :class:`BaseDataManager`
containing the test and validation data.
If None: empty dict will be passed
params : :class:`Parameters`or None
config : :class:`DeliraConfig` or None
the training and model parameters
(will be merged with ``self.params``)
(will be merged with ``self.config``)
verbose : bool
verbosity
**kwargs :
Expand Down Expand Up @@ -195,7 +195,7 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
train_kwargs=train_kwargs,
test_kwargs=test_kwargs,
metric_keys=metric_keys,
params=params,
config=config,
verbose=verbose,
**kwargs)

Expand Down

0 comments on commit 4801709

Please sign in to comment.