Skip to content

Commit

Permalink
Changed the way additional penalties are added to residual
Browse files Browse the repository at this point in the history
  • Loading branch information
Joern Weissenborn committed Nov 11, 2019
1 parent bed5c10 commit fc4b97d
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 30 deletions.
44 changes: 22 additions & 22 deletions glotaran/analysis/residual_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def create_index_independend_ungrouped_residual(
residuals[label].append(residual)
penalties.append(residual)

if callable(scheme.model.additional_penalty_function):
additional_penalty = dask.delayed(scheme.model.additional_penalty_function)(
parameter, reduced_clp_labels[label], reduced_clps[label], i
)
penalties.append(additional_penalty)
if callable(scheme.model.has_additional_penalty_function):
if scheme.model.has_additional_penalty_function():
additional_penalty = dask.delayed(scheme.model.additional_penalty_function)(
parameter, reduced_clp_labels[label], reduced_clps[label], i
)
penalties.append(additional_penalty)

penalty = dask.delayed(np.concatenate)(penalties)
return reduced_clp_labels, reduced_clps, residuals, penalty
Expand Down Expand Up @@ -63,13 +64,12 @@ def create_index_dependend_ungrouped_residual(
residuals[label].append(residual)
penalties.append(residual)

# call removed because of performance reasons (issue #230)
# if callable(scheme.model.additional_penalty_function):
# additional_penalty = dask.delayed(scheme.model.additional_penalty_function)(
# parameter, clp_label, clp, i
# )
# penalties.append(additional_penalty)
# TODO: re-implement removed functionality (issue: #237)
if callable(scheme.model.has_additional_penalty_function):
if scheme.model.has_additional_penalty_function():
additional_penalty = dask.delayed(scheme.model.additional_penalty_function)(
parameter, clp_label, clp, i
)
penalties.append(additional_penalty)

penalty = dask.delayed(np.concatenate)(penalties)
return reduced_clp_labels, reduced_clps, residuals, penalty
Expand All @@ -86,11 +86,11 @@ def penalty_function(matrix_label, problem, labels_and_matrices):
clp, residual = residual_function(labels_and_matrices[matrix_label].matrix, problem.data)

penalty = residual
if callable(scheme.model.additional_penalty_function):
additional_penalty = scheme.model.additional_penalty_function(
parameter, labels_and_matrices[matrix_label].clp_label, clp, problem.index
)
if additional_penalty:
if callable(scheme.model.has_additional_penalty_function):
if scheme.model.has_additional_penalty_function():
additional_penalty = scheme.model.additional_penalty_function(
parameter, labels_and_matrices[matrix_label].clp_label, clp, problem.index
)
penalty = np.concatenate([penalty, additional_penalty])
return clp, residual, penalty
penalty_bag = \
Expand All @@ -113,11 +113,11 @@ def penalty_function(problem, labels_and_matrices):
clp, residual = residual_function(labels_and_matrices.matrix, problem.data)

penalty = residual
if callable(scheme.model.additional_penalty_function):
additional_penalty = scheme.model.additional_penalty_function(
parameter, labels_and_matrices.clp_label, clp, problem.index
)
if additional_penalty:
if callable(scheme.model.has_additional_penalty_function):
if scheme.model.has_additional_penalty_function():
additional_penalty = scheme.model.additional_penalty_function(
parameter, labels_and_matrices.clp_label, clp, problem.index
)
penalty = np.concatenate([penalty, additional_penalty])
return clp, residual, penalty
penalty_bag = \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .spectral_constraints import SpectralConstraint, apply_spectral_constraints
from .spectral_irf import IrfSpectralMultiGaussian
from .spectral_matrix import spectral_matrix
from .spectral_penalties import EqualAreaPenalty, apply_spectral_penalties
from .spectral_penalties import EqualAreaPenalty, has_spectral_penalties, apply_spectral_penalties
from .spectral_relations import SpectralRelation, apply_spectral_relations
from .spectral_shape import SpectralShape

Expand Down Expand Up @@ -61,6 +61,7 @@ def grouped(model: typing.Type['KineticModel']):
global_matrix=spectral_matrix,
global_dimension='spectral',
constrain_matrix_function=apply_kinetic_model_constraints,
has_additional_penalty_function=has_spectral_penalties,
additional_penalty_function=apply_spectral_penalties,
grouped=grouped,
index_dependend=index_dependend,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,17 @@ def applies(interval):
return any([applies(i) for i in self.interval])


def has_spectral_penalties(model: typing.Type['KineticModel']) -> bool:
return len(model.equal_area_penalties) != 0


def apply_spectral_penalties(
model: typing.Type['KineticModel'],
parameter: ParameterGroup,
clp_labels: typing.List[str],
clps: np.ndarray,
index: float) -> np.ndarray:

if not model.equal_area_penalties:
return []

clp_labels, clps = retrieve_clps(model, parameter, clp_labels, clps, index)

penalties = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_spectral_penalties():
},
})

weight = 0.0 # TODO: workaround for #230 should be fixed with #237
weight = 0.1
model_with_penalty = KineticSpectrumModel.from_dict({
'initial_concentration': {
'j1': {
Expand Down
15 changes: 12 additions & 3 deletions glotaran/model/model_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ def model(model_type: str,
matrix_dimension: str = None,
global_dimension: str = None,
constrain_matrix_function: ConstrainMatrixFunction = None,
has_additional_penalty_function: typing.Callable[[], bool] = None,
additional_penalty_function: PenaltyFunction = None,
finalize_data_function: FinalizeFunction = None,
grouped: typing.Union[bool, typing.Callable[[None], bool]] = False,
index_dependend: typing.Union[bool, typing.Callable[[None], bool]] = False,
grouped: typing.Union[bool, typing.Callable[[], bool]] = False,
index_dependend: typing.Union[bool, typing.Callable[[], bool]] = False,
) -> typing.Callable:
"""The `@model` decorator is intended to be used on subclasses of :class:`glotaran.model.Model`.
It creates properties for the given attributes as well as functions to add access them. Also it
Expand Down Expand Up @@ -106,12 +107,20 @@ def decorator(cls):
else:
setattr(cls, 'constrain_matrix_function', None)

if additional_penalty_function:
if has_additional_penalty_function:
if not additional_penalty_function:
raise Exception('Model implements `has_additional_penalty_function` '
'but not `additional_penalty_function`')
has_pen = wrap_func_as_method(
cls, name='has_additional_penalty_function'
)(has_additional_penalty_function)
pen = wrap_func_as_method(
cls, name='additional_penalty_function'
)(additional_penalty_function)
setattr(cls, 'additional_penalty_function', pen)
setattr(cls, 'has_additional_penalty_function', has_pen)
else:
setattr(cls, 'has_additional_penalty_function', None)
setattr(cls, 'additional_penalty_function', None)

if not callable(grouped):
Expand Down

0 comments on commit fc4b97d

Please sign in to comment.