This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Merge pull request #137 from justusschock/small_fixes
Minor changes
justusschock committed Jun 12, 2019
2 parents c4cc60c + 19c9bda commit 2c4f2dc
Showing 2 changed files with 14 additions and 8 deletions.
delira/training/base_trainer.py (6 additions, 5 deletions)
@@ -107,9 +107,10 @@ def __init__(self,
         start_epoch : int
             epoch to start training at
         metric_keys : dict
-            dict specifying which batch_dict entry to use for which metric as
-            target; default: None, which will result in key "label" for all
-            metrics
+            the batch_dict keys to use for each metric to calculate.
+            Should contain a value for each key in ``metrics``.
+            If no values are given for a key, per default ``pred`` and
+            ``label`` will be used for metric calculation
         convert_batch_to_npy_fn : type, optional
             function converting a batch-tensor to numpy, per default this is
             the identity function
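The updated docstring describes ``metric_keys`` as a mapping from each entry in ``metrics`` to the ``batch_dict`` keys used to compute it, with ``pred`` and ``label`` as the fallback. A minimal sketch of such a mapping is shown below; the tuple value format and the metric names are assumptions for illustration, not taken from this commit.

```python
# Hypothetical sketch only: each metric name in ``metrics`` maps to the
# batch_dict keys it should be computed from. Metrics omitted from this
# dict fall back to the defaults ("pred" and "label") named in the docstring.
metric_keys = {
    "accuracy": ("pred", "label"),       # the default keys, spelled out
    "dice": ("seg_pred", "seg_target"),  # hypothetical custom batch_dict keys
}
```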
@@ -470,8 +471,8 @@ def reduce_fn(batch):
                     best_val_score, new_val_score, val_score_mode)
 
                 # set best_val_score to new_val_score if is_best
-                best_val_score = int(is_best) * new_val_score + \
-                    (1 - int(is_best)) * best_val_score
+                if is_best:
+                    best_val_score = new_val_score
 
                 if is_best and verbose:
                     logging.info("New Best Value at Epoch %03d : %03.3f" %
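The second hunk replaces the arithmetic update of ``best_val_score`` with a plain conditional. A small self-contained check (not part of the repository) showing that both forms agree for a boolean ``is_best``:

```python
# Standalone comparison of the two update rules from this hunk.
def old_update(best, new, is_best):
    # pre-commit form: arithmetic blend driven by int(is_best)
    return int(is_best) * new + (1 - int(is_best)) * best

def new_update(best, new, is_best):
    # post-commit form: explicit conditional
    return new if is_best else best

for is_best in (True, False):
    assert old_update(0.7, 0.9, is_best) == new_update(0.7, 0.9, is_best)
```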
delira/training/experiment.py (8 additions, 3 deletions)
@@ -48,6 +48,7 @@ def __init__(self,
                  optim_builder=None,
                  checkpoint_freq=1,
                  trainer_cls=BaseNetworkTrainer,
+                 predictor_cls=Predictor,
                  **kwargs):
         """
@@ -83,6 +84,8 @@ def __init__(self,
             2 denotes saving every second epoch etc.); default: 1
         trainer_cls : subclass of :class:`BaseNetworkTrainer`
             the trainer class to use for training the model
+        predictor_cls : subclass of :class:`Predictor`
+            the predictor class to use for testing the model
         **kwargs :
             additional keyword arguments
@@ -117,6 +120,7 @@ def __init__(self,
         os.makedirs(self.save_path, exist_ok=True)
 
         self.trainer_cls = trainer_cls
+        self.predictor_cls = predictor_cls
 
         if val_score_key is None:
             if params.nested_get("val_metrics", False):
@@ -252,9 +256,10 @@ def _setup_test(self, params, model, convert_batch_to_npy_fn,
             the created predictor
         """
-        predictor = Predictor(model=model, key_mapping=self.key_mapping,
-                              convert_batch_to_npy_fn=convert_batch_to_npy_fn,
-                              prepare_batch_fn=prepare_batch_fn, **kwargs)
+        predictor = self.predictor_cls(
+            model=model, key_mapping=self.key_mapping,
+            convert_batch_to_npy_fn=convert_batch_to_npy_fn,
+            prepare_batch_fn=prepare_batch_fn, **kwargs)
         return predictor
 
     def run(self, train_data: BaseDataManager,
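The experiment.py changes make the predictor class used for testing injectable via ``predictor_cls`` instead of hard-coding ``Predictor``. The self-contained sketch below illustrates that injection pattern with stand-in classes; the names and constructor arguments are placeholders, not delira's actual API.

```python
# Illustrative stand-ins for delira's classes; only the injection pattern
# (store predictor_cls, instantiate it in _setup_test) mirrors the diff.
class Predictor:
    def __init__(self, model, **kwargs):
        self.model = model

class QuietPredictor(Predictor):
    """Hypothetical subclass, e.g. one that suppresses progress output."""

class Experiment:
    def __init__(self, predictor_cls=Predictor):
        self.predictor_cls = predictor_cls      # stored, as in the diff

    def _setup_test(self, model, **kwargs):
        # build whichever predictor class was injected
        return self.predictor_cls(model=model, **kwargs)

exp = Experiment(predictor_cls=QuietPredictor)
assert isinstance(exp._setup_test(model=None), QuietPredictor)
```

With this in place, a ``Predictor`` subclass passed through ``predictor_cls`` changes what ``_setup_test`` builds without overriding the method itself.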
