Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
Merge pull request #201 from delira-dev/callbacks
Browse files Browse the repository at this point in the history
Add training callbacks
  • Loading branch information
mibaumgartner committed Sep 11, 2019
2 parents 4801709 + 4f69887 commit b491b62
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 131 deletions.
7 changes: 5 additions & 2 deletions delira/training/backends/chainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,14 @@ def _at_training_begin(self, *args, **kwargs):
keyword arguments
"""
for cbck in self._callbacks:
self._update_state(cbck.at_training_begin(self, *args, **kwargs))

self.save_state(os.path.join(
self.save_path, "checkpoint_epoch_%d" % self.start_epoch),
self.start_epoch)

def _at_training_end(self):
def _at_training_end(self, *args, **kwargs):
"""
Defines Behaviour at end of training: Loads best model if
available
Expand All @@ -316,7 +319,7 @@ def _at_training_end(self):
self.update_state(os.path.join(self.save_path,
'checkpoint_best.chain'))

return self.module
return super()._at_training_end(*args, **kwargs)

def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
**kwargs):
Expand Down
144 changes: 63 additions & 81 deletions delira/training/backends/sklearn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,73 @@ def _at_training_begin(self, *args, **kwargs):
keyword arguments
"""
for cbck in self._callbacks:
self._update_state(cbck.at_training_begin(self, *args, **kwargs))

self.save_state(os.path.join(
self.save_path, "checkpoint_epoch_%d" % self.start_epoch),
self.start_epoch)

def _at_training_end(self, *args, **kwargs):
    """
    Defines behaviour at end of training: restores the best model if one
    was checkpointed, then delegates to the parent implementation (which
    runs the ``at_training_end`` callbacks)

    Returns
    -------
    :class:`SkLearnEstimator`
        best network
    """
    best_ckpt = os.path.join(self.save_path, 'checkpoint_best.pkl')

    # only reload when a "best" checkpoint was actually written
    if os.path.isfile(best_ckpt):
        self.update_state(best_ckpt)

    return super()._at_training_end(*args, **kwargs)

def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
**kwargs):
"""
Defines behaviour at beginning of each epoch: Executes all callbacks's
`at_epoch_end` method and saves current state if necessary
Parameters
----------
metrics_val : dict
validation metrics
val_score_key : str
validation score key
epoch : int
current epoch
num_epochs : int
total number of epochs
is_best : bool
whether current model is best one so far
**kwargs :
keyword arguments
"""

for cb in self._callbacks:
self._update_state(cb.at_epoch_end(self,
val_metrics=metrics_val,
val_score_key=val_score_key,
curr_epoch=epoch))

if epoch % self.save_freq == 0:
self.save_state(os.path.join(self.save_path,
"checkpoint_epoch_%d.pkl"
% epoch),
epoch)

if is_best:
self.save_state(os.path.join(self.save_path,
"checkpoint_best.pkl"),
epoch)

def _get_classes_if_necessary(self, dmgr: BaseDataManager, verbose,
label_key=None):
"""
Expand Down Expand Up @@ -291,66 +354,6 @@ class collection if necessary
val_score_key, val_score_mode, reduce_mode,
verbose)

def _at_training_end(self):
    """
    Defines behaviour at end of training: restores the best model if one
    was checkpointed during training

    Returns
    -------
    :class:`SkLearnEstimator`
        best network
    """
    best_ckpt = os.path.join(self.save_path, 'checkpoint_best.pkl')

    # only reload when a "best" checkpoint was actually written
    if os.path.isfile(best_ckpt):
        self.update_state(best_ckpt)

    return self.module

def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
**kwargs):
"""
Defines behaviour at beginning of each epoch: Executes all callbacks's
`at_epoch_end` method and saves current state if necessary
Parameters
----------
metrics_val : dict
validation metrics
val_score_key : str
validation score key
epoch : int
current epoch
num_epochs : int
total number of epochs
is_best : bool
whether current model is best one so far
**kwargs :
keyword arguments
"""

for cb in self._callbacks:
self._update_state(cb.at_epoch_end(self,
val_metrics=metrics_val,
val_score_key=val_score_key,
curr_epoch=epoch))

if epoch % self.save_freq == 0:
self.save_state(os.path.join(self.save_path,
"checkpoint_epoch_%d.pkl"
% epoch),
epoch)

if is_best:
self.save_state(os.path.join(self.save_path,
"checkpoint_best.pkl"),
epoch)

def save_state(self, file_name, epoch, **kwargs):
"""
saves the current state via
Expand Down Expand Up @@ -396,27 +399,6 @@ def load_state(file_name, *args, **kwargs):

return load_checkpoint(file_name, **kwargs)

def update_state(self, file_name, *args, **kwargs):
    """
    Update the trainer's internal state from a state stored on disk

    Parameters
    ----------
    file_name : str
        file containing the new state to load
    *args :
        positional arguments, forwarded to :meth:`load_state`
    **kwargs :
        keyword arguments, forwarded to :meth:`load_state`
    """
    # load the serialized state first, then merge it into the trainer
    new_state = self.load_state(file_name, *args, **kwargs)
    self._update_state(new_state)

def _update_state(self, new_state):
"""
Update the state from a given new state
Expand Down
4 changes: 2 additions & 2 deletions delira/training/backends/tf_eager/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,
self.update_state(latest_state_path)
self.start_epoch = latest_epoch

def _at_training_end(self):
def _at_training_end(self, *args, **kwargs):
"""
Defines Behaviour at end of training: Loads best model if available
Expand All @@ -245,7 +245,7 @@ def _at_training_end(self):
'checkpoint_best')
)

return self.module
return super()._at_training_end(*args, **kwargs)

def _train_single_epoch(self, batchgen, epoch, verbose=False):
"""
Expand Down
4 changes: 2 additions & 2 deletions delira/training/backends/tf_graph/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,
self.update_state(latest_state_path)
self.start_epoch = latest_epoch

def _at_training_end(self):
def _at_training_end(self, *args, **kwargs):
"""
Defines Behaviour at end of training: Loads best model if available
Expand All @@ -251,7 +251,7 @@ def _at_training_end(self):
'checkpoint_best')
)

return self.module
return super()._at_training_end(*args, **kwargs)

def _train_single_epoch(self, batchgen, epoch, verbose=False):
"""
Expand Down
35 changes: 8 additions & 27 deletions delira/training/backends/torch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,

def _at_training_begin(self, *args, **kwargs):
"""
Defines behaviour at beginning of training
Defines the behaviour at beginning of the training
Parameters
----------
Expand All @@ -336,11 +336,13 @@ def _at_training_begin(self, *args, **kwargs):
keyword arguments
"""
self.save_state(os.path.join(
self.save_path, "checkpoint_epoch_%d" % self.start_epoch),
self.start_epoch)
for cbck in self._callbacks:
self._update_state(cbck.at_training_begin(self, *args, **kwargs))

def _at_training_end(self):
self.save_state(os.path.join(self.save_path, "checkpoint_epoch_%d"
% self.start_epoch), self.start_epoch)

def _at_training_end(self, *args, **kwargs):
"""
Defines Behaviour at end of training: Loads best model if
available
Expand All @@ -358,7 +360,7 @@ def _at_training_end(self):
self.update_state(os.path.join(self.save_path,
'checkpoint_best.pt'))

return self.module
return super()._at_training_end(*args, **kwargs)

def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
**kwargs):
Expand Down Expand Up @@ -502,27 +504,6 @@ def load_state(file_name, **kwargs):

return load_checkpoint_torch(file_name, **kwargs)

def update_state(self, file_name, *args, **kwargs):
    """
    Update the trainer's internal state from a checkpoint file

    Parameters
    ----------
    file_name : str
        file containing the new state to load
    *args :
        positional arguments, forwarded to :meth:`load_state`
    **kwargs :
        keyword arguments, forwarded to :meth:`load_state`
    """
    # load the checkpoint, then apply it to the trainer in place
    loaded = self.load_state(file_name, *args, **kwargs)
    self._update_state(loaded)

def _update_state(self, new_state):
"""
Update the state from a given new state
Expand Down
22 changes: 16 additions & 6 deletions delira/training/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,11 @@ def _at_training_begin(self, *args, **kwargs):
keyword arguments
"""
for cbck in self._callbacks:
self._update_state(cbck.at_training_begin(self, *args, **kwargs))

self.save_state(os.path.join(self.save_path, "checkpoint_epoch_%d"
% self.start_epoch), self.start_epoch)
% self.start_epoch))

def _at_training_end(self, *args, **kwargs):
"""
Expand All @@ -219,6 +222,9 @@ def _at_training_end(self, *args, **kwargs):
the network with the loaded state
"""
for cbck in self._callbacks:
self._update_state(cbck.at_training_end(self, *args, **kwargs))

return self.module

def _at_epoch_begin(self, metrics_val, val_score_key, epoch, num_epochs,
Expand Down Expand Up @@ -552,11 +558,15 @@ def register_callback(self, callback: AbstractCallback):
"AbstractCallback or provide functions " \
"'at_epoch_begin' and 'at_epoch_end'"
instance_check = isinstance(callback, AbstractCallback)
attr_check_begin = hasattr(callback, "at_epoch_begin")
attr_check_end = hasattr(callback, "at_epoch_end")
attr_check_both = attr_check_begin and attr_check_end

assert instance_check or attr_check_both, assertion_str
attr_check_begin_epoch = hasattr(callback, "at_epoch_begin")
attr_check_end_epoch = hasattr(callback, "at_epoch_end")
attr_check_both_epoch = attr_check_begin_epoch and attr_check_end_epoch
attr_check_begin_train = hasattr(callback, "at_training_begin")
attr_check_end_train = hasattr(callback, "at_training_end")
attr_check_both_train = attr_check_begin_train and attr_check_end_train
attr_check_all = attr_check_both_epoch and attr_check_both_train

assert instance_check or attr_check_all, assertion_str

self._callbacks.append(callback)

Expand Down
46 changes: 46 additions & 0 deletions delira/training/callbacks/abstract_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,59 @@ def at_epoch_begin(self, trainer, **kwargs):
modified trainer attributes, where the name must correspond to the
trainer's attribute name
Notes
-----
The basetrainer calls the callbacks with the following additional
arguments: `val_metrics`(dict), `val_score_key`(str), `curr_epoch`(int)
"""
return {}

def at_epoch_end(self, trainer, **kwargs):
    """
    Hook executed at the end of every epoch; no-op by default

    Parameters
    ----------
    trainer : :class:`AbstractNetworkTrainer`
        the trainer whose attributes may be modified
    **kwargs :
        additional keyword arguments

    Returns
    -------
    dict
        modified trainer attributes, where the name must correspond to the
        trainer's attribute name

    Notes
    -----
    The basetrainer calls the callbacks with the following additional
    arguments: `val_metrics`(dict), `val_score_key`(str), `curr_epoch`(int)
    """
    # subclasses override this to inspect/modify the trainer
    return dict()

def at_training_begin(self, trainer, **kwargs):
    """
    Hook executed once before training starts; no-op by default

    Parameters
    ----------
    trainer : :class:`AbstractNetworkTrainer`
        the trainer whose attributes may be modified
    **kwargs :
        additional keyword arguments

    Returns
    -------
    dict
        modified trainer attributes, where the name must correspond to the
        trainer's attribute name
    """
    # subclasses override this to inspect/modify the trainer
    return dict()

def at_training_end(self, trainer, **kwargs):
"""
Function which will be executed at end of training
Parameters
----------
trainer : :class:`AbstractNetworkTrainer`
Expand Down

0 comments on commit b491b62

Please sign in to comment.