This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Merge pull request #141 from justusschock/code-cleanup
Remove mutable default arguments
justusschock committed Jun 14, 2019
2 parents c849c6a + 2bee4d4 commit 3c16d7e
Showing 17 changed files with 168 additions and 86 deletions.
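
Every edit in this commit applies the same fix: a Python default value is evaluated once, when the def statement runs, so a mutable default such as sampler_kwargs={} or extensions=[] becomes a single object shared by all calls that omit the argument. The idiom used throughout is a None sentinel plus an in-body check. A minimal sketch of the pitfall and the fix (illustrative names, not delira code):

    def append_bad(item, bucket=[]):
        # BUG: this list is created once, at definition time,
        # and is shared by every call that omits 'bucket'
        bucket.append(item)
        return bucket

    def append_good(item, bucket=None):
        # sentinel idiom: a fresh list is created per call
        if bucket is None:
            bucket = []
        bucket.append(item)
        return bucket

    print(append_bad(1))   # [1]
    print(append_bad(2))   # [1, 2]  <- state leaked across calls
    print(append_good(1))  # [1]
    print(append_good(2))  # [2]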
9 changes: 4 additions & 5 deletions delira/data_loading/data_manager.py
@@ -1,19 +1,16 @@
 import inspect
 import logging
 
 import numpy as np
 import typing
-import inspect
 from batchgenerators.dataloading import MultiThreadedAugmenter, \
     SingleThreadedAugmenter, SlimDataLoaderBase
 from batchgenerators.transforms import AbstractTransform
 
 from delira import get_current_debug_mode
 from .data_loader import BaseDataLoader
 from .dataset import AbstractDataset, BaseCacheDataset, BaseLazyDataset
 from .load_utils import default_load_fn_2d
 from .sampler import SequentialSampler, AbstractSampler
 from ..utils.decorators import make_deprecated
-from delira import get_current_debug_mode
-
 logger = logging.getLogger(__name__)

@@ -245,7 +242,7 @@ class BaseDataManager(object):
 
     def __init__(self, data, batch_size, n_process_augmentation,
                  transforms, sampler_cls=SequentialSampler,
-                 sampler_kwargs={},
+                 sampler_kwargs=None,
                  data_loader_cls=None, dataset_cls=None,
                  load_fn=default_load_fn_2d, from_disc=True, **kwargs):
         """
@@ -292,6 +289,8 @@ class defining the sampling strategy
         """
 
         # Instantiate Hidden variables for property access
+        if sampler_kwargs is None:
+            sampler_kwargs = {}
         self._batch_size = None
         self._n_process_augmentation = None
         self._transforms = None
11 changes: 8 additions & 3 deletions delira/data_loading/dataset.py
@@ -583,8 +583,13 @@ class Nii3DLazyDataset(BaseLazyDataset):
     """
 
     @make_deprecated('LoadSample')
-    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
-                 img_files, label_file, **load_kwargs):
+    def __init__(
+            self,
+            data_path,
+            load_fn,
+            img_files,
+            label_file,
+            **load_kwargs):
         """
         Parameters
         ----------
@@ -639,7 +644,7 @@ class Nii3DCacheDatset(BaseCacheDataset):
     """
 
     @make_deprecated('LoadSample')
-    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
+    def __init__(self, data_path, load_fn,
                  img_files, label_file, **load_kwargs):
         """
         Parameters
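
Both Nii3D datasets are decorated with @make_deprecated('LoadSample'), which lives in delira/utils/decorators.py and is not shown in this diff. As a rough sketch only (an assumption about its mechanics, not delira's actual implementation), such a decorator typically wraps the callable and emits a DeprecationWarning naming the replacement:

    import functools
    import warnings

    def make_deprecated(replacement):
        # sketch: warn on every call and point to the replacement class
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                warnings.warn(
                    "%s is deprecated; use %s instead"
                    % (func.__qualname__, replacement),
                    DeprecationWarning, stacklevel=2)
                return func(*args, **kwargs)
            return wrapper
        return decorator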
17 changes: 13 additions & 4 deletions delira/data_loading/load_utils.py
@@ -78,7 +78,11 @@ def is_valid_image_file(fname, img_extensions, gt_extensions):
     ----------
     fname : str
         filename of image path
-    Returns
+    img_extensions : list
+        list of valid image file extensions
+    gt_extensions : list
+        list of valid gt file extensions
+    Returns
     -------
     bool
         is valid data sample
@@ -148,7 +152,7 @@ class LoadSample:
     def __init__(self,
                  sample_ext: dict,
                  sample_fn: collections.abc.Callable,
-                 dtype={}, normalize=(), norm_fn=norm_range('-1,1'),
+                 dtype=None, normalize=(), norm_fn=norm_range('-1,1'),
                  **kwargs):
         """
@@ -185,6 +189,8 @@ def __init__(self,
         >>>            'seg': 'uint8'},
         >>>            normalize=('data',))
         """
+        if dtype is None:
+            dtype = {}
         self._sample_ext = sample_ext
         self._sample_fn = sample_fn
         self._dtype = dtype
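
Only dtype={} needed the sentinel here; the neighbouring default norm_fn=norm_range('-1,1') is likewise evaluated a single time at definition time, which is harmless as long as the returned callable carries no mutable state. A small stand-in (this norm_range is illustrative, not delira's implementation) showing that a once-evaluated stateless closure is safe to share:

    def norm_range(mode):
        # returns a stateless closure; sharing one instance across
        # calls is safe, unlike sharing a mutable dict or list
        def _norm(values, lo=-1.0, hi=1.0):
            v_min, v_max = min(values), max(values)
            return [lo + (v - v_min) * (hi - lo) / (v_max - v_min)
                    for v in values]
        return _norm

    normalize = norm_range('-1,1')      # evaluated once, like a default
    print(normalize([0.0, 5.0, 10.0]))  # [-1.0, 0.0, 1.0]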
@@ -235,7 +241,7 @@ def __init__(self,
                  sample_fn: collections.abc.Callable,
                  label_ext: collections.abc.Iterable,
                  label_fn: collections.abc.Callable,
-                 sample_kwargs={}, **kwargs):
+                 sample_kwargs=None, **kwargs):
         """
         Load sample and label from folder
@@ -264,6 +270,9 @@ def __init__(self,
         --------
         :class: `LoadSample`
         """
+        if sample_kwargs is None:
+            sample_kwargs = {}
+
         super().__init__(sample_ext, sample_fn, **sample_kwargs)
         self._label_ext = label_ext
         self._label_fn = label_fn
@@ -282,7 +291,7 @@ def __call__(self, path):
         dict
             dict with data and label
         """
-        sample_dict = super(LoadSampleLabel, self).__call__(path)
+        sample_dict = super().__call__(path)
         label_dict = self._label_fn(os.path.join(path, self._label_ext),
                                     **self._label_kwargs)
         sample_dict.update(label_dict)
4 changes: 3 additions & 1 deletion delira/io/torch.py
@@ -10,7 +10,7 @@
 import torch
 from ..models import AbstractPyTorchNetwork
 
-def save_checkpoint(file: str, model=None, optimizers={},
+def save_checkpoint(file: str, model=None, optimizers=None,
                     epoch=None, **kwargs):
     """
     Save model's parameters
@@ -28,6 +28,8 @@ def save_checkpoint(file: str, model=None, optimizers={},
         current epoch (will also be pickled)
     """
+    if optimizers is None:
+        optimizers = {}
     if isinstance(model, torch.nn.DataParallel):
         _model = model.module
     else:
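
Apart from the optimizers=None fix, the visible lines show the usual PyTorch unwrapping step: torch.nn.DataParallel holds the real network under model.module, and saving the unwrapped module keeps the checkpoint's state_dict keys free of a "module." prefix. The step in isolation (a sketch, not the full save_checkpoint body):

    import torch

    def unwrap_model(model):
        # DataParallel stores the wrapped network as model.module;
        # saving the inner module avoids "module."-prefixed keys
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    net = torch.nn.Linear(4, 2)
    wrapped = torch.nn.DataParallel(net)
    assert unwrap_model(wrapped) is net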
8 changes: 6 additions & 2 deletions delira/models/abstract_network.py
@@ -53,8 +53,8 @@ def __call__(self, *args, **kwargs):
 
     @staticmethod
     @abc.abstractmethod
-    def closure(model, data_dict: dict, optimizers: dict, losses={},
-                metrics={}, fold=0, **kwargs):
+    def closure(model, data_dict: dict, optimizers: dict, losses=None,
+                metrics=None, fold=0, **kwargs):
         """
         Function which handles prediction from batch, logging, loss calculation
         and optimizer step
@@ -90,6 +90,10 @@ def closure(model, data_dict: dict, optimizers: dict, losses={},
             If not overwritten by subclass
         """
+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         raise NotImplementedError()
 
     @staticmethod
6 changes: 5 additions & 1 deletion delira/models/classification/classification_network.py
@@ -67,7 +67,7 @@ def forward(self, input_batch: torch.Tensor):
 
     @staticmethod
     def closure(model: AbstractPyTorchNetwork, data_dict: dict,
-                optimizers: dict, losses={}, metrics={},
+                optimizers: dict, losses=None, metrics=None,
                 fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -108,6 +108,10 @@ def closure(model: AbstractPyTorchNetwork, data_dict: dict,
         """
 
+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         assert (optimizers and losses) or not optimizers, \
             "Criterion dict cannot be emtpy, if optimizers are passed"

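The assert directly after the sentinel checks leans on dict truthiness: since losses now defaults to an empty dict, "optimizers and losses" is falsy whenever losses is empty, so passing optimizers without losses fails fast, while a pure-inference call with neither still passes. The guard in isolation (illustrative wrapper, not delira code):

    def check_closure_args(optimizers, losses=None):
        if losses is None:
            losses = {}
        # empty dicts are falsy: reject "optimizers but no losses",
        # but allow an inference-only call with neither
        assert (optimizers and losses) or not optimizers, \
            "Criterion dict cannot be empty, if optimizers are passed"

    check_closure_args({})                          # ok: inference only
    check_closure_args({"adam": ...}, {"ce": ...})  # ok: training
    # check_closure_args({"adam": ...})             # AssertionError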
5 changes: 3 additions & 2 deletions delira/models/classification/classification_network_tf.py
@@ -131,7 +131,7 @@ def _build_model(n_outputs: int, **kwargs):
 
     @staticmethod
     def closure(model: typing.Type[AbstractTfNetwork], data_dict: dict,
-                metrics={}, fold=0, **kwargs):
+                metrics=None, fold=0, **kwargs):
         """
         closure method to do a single prediction.
         This is followed by backpropagation or not based state of
@@ -163,9 +163,10 @@ def closure(model: typing.Type[AbstractTfNetwork], data_dict: dict,
         """
 
+        if metrics is None:
+            metrics = {}
         loss_vals = {}
         metric_vals = {}
         image_names = "input_images"
 
         inputs = data_dict.pop('data')
 
6 changes: 5 additions & 1 deletion delira/models/gan/generative_adversarial_network.py
@@ -94,7 +94,7 @@ def forward(self, real_image_batch):
 
     @staticmethod
     def closure(model, data_dict: dict,
-                optimizers: dict, losses={}, metrics={},
+                optimizers: dict, losses=None, metrics=None,
                 fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -134,6 +134,10 @@ def closure(model, data_dict: dict,
         """
 
+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         loss_vals = {}
         metric_vals = {}
         total_loss_discr_real = 0
17 changes: 12 additions & 5 deletions delira/models/segmentation/unet.py
@@ -7,7 +7,6 @@
 import torch
 import torch.nn.functional as F
 from torch.nn import init
-import logging
 from ..abstract_network import AbstractPyTorchNetwork
 
 class UNet2dPyTorch(AbstractPyTorchNetwork):
@@ -175,8 +174,8 @@ def forward(self, x):
         return {"pred": x}
 
     @staticmethod
-    def closure(model, data_dict: dict, optimizers: dict, losses={},
-                metrics={}, fold=0, **kwargs):
+    def closure(model, data_dict: dict, optimizers: dict, losses=None,
+                metrics=None, fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -216,6 +215,10 @@ def closure(model, data_dict: dict, optimizers: dict, losses={},
         """
 
+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         assert (optimizers and losses) or not optimizers, \
             "Loss dict cannot be emtpy, if optimizers are passed"
 
@@ -618,8 +621,8 @@ def forward(self, x):
         return {"pred": x}
 
     @staticmethod
-    def closure(model, data_dict: dict, optimizers: dict, losses={},
-                metrics={}, fold=0, **kwargs):
+    def closure(model, data_dict: dict, optimizers: dict, losses=None,
+                metrics=None, fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -659,6 +662,10 @@ def closure(model, data_dict: dict, optimizers: dict, losses={},
         """
 
+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         assert (optimizers and losses) or not optimizers, \
             "Loss dict cannot be emtpy, if optimizers are passed"
 
20 changes: 10 additions & 10 deletions delira/training/base_trainer.py
@@ -1,18 +1,16 @@
 from abc import abstractmethod
 import logging
+import os
 import pickle
 import typing
 
-from ..data_loading.data_manager import Augmenter
-
-from .predictor import Predictor
-from .callbacks import AbstractCallback
-from ..models import AbstractNetwork
-
 import numpy as np
-import os
 from tqdm import tqdm
 
 from delira.logging import TrixiHandler
+from .callbacks import AbstractCallback
+from .predictor import Predictor
+from ..data_loading.data_manager import Augmenter
+from ..models import AbstractNetwork
 
 logger = logging.getLogger(__name__)

@@ -614,7 +612,7 @@ def _update_state(self, new_state):
         """
         for key, val in new_state.items():
-            if (key.startswith("__") and key.endswith("__")):
+            if key.startswith("__") and key.endswith("__"):
                 continue
 
             try:
@@ -729,7 +727,7 @@ def _reinitialize_logging(self, logging_type, logging_kwargs: dict):
                              handlers=new_handlers)
 
     @staticmethod
-    def _search_for_prev_state(path, extensions=[]):
+    def _search_for_prev_state(path, extensions=None):
         """
         Helper function to search in a given path for previous epoch states
         (indicated by extensions)
@@ -752,6 +750,8 @@ def _search_for_prev_state(path, extensions=[]):
             the latest epoch (1 if no checkpoint was found)
         """
+        if extensions is None:
+            extensions = []
         files = []
         for file in os.listdir(path):
             for ext in extensions:
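
The hunk is cut off inside the scan loop. For orientation only, a self-contained sketch of what such a previous-state search can look like (file naming and epoch parsing are assumptions, not delira's exact logic):

    import os
    import re

    def search_for_prev_state(path, extensions=None):
        # fresh list per call instead of a shared [] default
        if extensions is None:
            extensions = []
        files, latest_epoch = [], 1
        for fname in os.listdir(path):
            for ext in extensions:
                if fname.endswith(ext):
                    files.append(os.path.join(path, fname))
                    match = re.search(r"(\d+)", fname)  # assumed naming
                    if match:
                        latest_epoch = max(latest_epoch,
                                           int(match.group(1)))
        return files, latest_epoch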
17 changes: 11 additions & 6 deletions delira/training/parameters.py
@@ -17,10 +17,8 @@ class Parameters(LookupConfig):
     """
 
-    def __init__(self, fixed_params={"model": {},
-                                     "training": {}},
-                 variable_params={"model": {},
-                                  "training": {}}):
+    def __init__(self, fixed_params=None,
+                 variable_params=None):
         """
         Parameters
@@ -31,6 +29,13 @@ def __init__(self, fixed_params={"model": {},
             variable parameters (can be variated by a hyperparameter search)
         """
 
+        if variable_params is None:
+            variable_params = {"model": {},
+                               "training": {}}
+        if fixed_params is None:
+            fixed_params = {"model": {},
+                            "training": {}}
+
         super().__init__(fixed=fixed_params,
                          variable=variable_params)
 
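The old signature was the riskiest instance in this commit: with fixed_params={"model": {}, "training": {}} as a default, every Parameters() built with the default shares the same nested dicts, so configuring one instance silently reconfigures all of them. A minimal stand-in (hypothetical Config class, not delira's LookupConfig) demonstrating the leak:

    class Config:
        def __init__(self, params={"model": {}, "training": {}}):  # BUG
            self.params = params

    a = Config()
    b = Config()
    a.params["model"]["lr"] = 1e-3
    print(b.params["model"])  # {'lr': 0.001} -- leaked into b
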
@@ -238,8 +243,8 @@ def update(self, dict_like, deep=False, ignore=None,
             overwrite a 'regular' value with a dict/Config at lower levels.
             See examples for an illustration of the difference
 
-        Examples:
-        ---------
+        Examples
+        --------
         The following illustrates the update behaviour if
         :obj:`allow_dict_overwrite` is active. If it isn't, an AttributeError
         would be raised, originating from trying to update "string"::
