Skip to content

Commit

Permalink
initial changes for astropy#1330
Browse files Browse the repository at this point in the history
  • Loading branch information
nden authored and embray committed Nov 1, 2013
1 parent e27b06a commit ade64da
Show file tree
Hide file tree
Showing 9 changed files with 291 additions and 405 deletions.
77 changes: 74 additions & 3 deletions astropy/modeling/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ class ParametricModel(Model):
"""

linear = False
deriv = None
# Flag that indicates if the model derivatives are given in columns
# or rows
col_deriv = True
Expand Down Expand Up @@ -407,6 +408,7 @@ def __init__(self, param_dim=1, **params):
self._constraints['bounds'] = bounds

self._initialize_parameters(params)


@property
def fixed(self):
Expand Down Expand Up @@ -549,6 +551,43 @@ def set_joint_parameters(self, jparams):
"""
self.joint = jparams

def _model_to_fit_params(self):
"""
Create a set of parameters to be fitted.
These may be a subset of the model parameters, if some of them are held
constant or tied.
"""
fitparam_indices = range(len(self.param_names))
if any(self.fixed.values()) or any(self.tied.values()):
params = list(self.parameters)
for idx, name in list(enumerate(self.param_names))[::-1]:
if self.fixed[name] or self.tied[name]:
sl = self._param_metrics[name][0]
del params[sl]
del fitparam_indices[idx]
return (np.array(params), fitparam_indices)
else:
return (self.parameters, fitparam_indices)

def fit_parameters(self, fps):
_fit_parameters, _fit_param_indices = self._model_to_fit_params()
if any(self.fixed.values()) or any(self.tied.values()):
self.parameters[_fit_param_indices] = fps
for idx, name in enumerate(self.param_names):
if self.tied[name] != False:
value = self.tied[name](self)
slice_ = self._param_metrics[name][0]
self.parameters[slice_] = value
elif any([tuple(b) != (None, None) for b in self.bounds.values()]):
for name, par in zip(self.param_names, _fit_parameters):
if self.bounds[name] != (None, None):
par = max(par, self.bounds[name][0])
par = min(par, self.bounds[name][1])
setattr(self, name, par)
else:
self.parameters = fps

def _initialize_parameters(self, params):
"""
Initialize the _parameters array that stores raw parameter values for
Expand Down Expand Up @@ -595,6 +634,41 @@ def _initialize_parameters(self, params):
for name, value in params.items():
setattr(self, name, value)

def _wrap_deriv(self, params, model, x, y, z=None):
"""
Wraps the method calculating the Jacobian of the function to account
for model constraints.
Currently the only fitter that uses a derivative is the
`NonLinearLSQFitter`. This wrapper may need to be revised when other
fitters using function derivative are added or when the statistic is
separated from the fitting routines.
`~scipy.optimize.leastsq` expects the function derivative to have the
above signature (parlist, (argtuple)). In order to accomodate model
constraints, instead of using p directly, we set the parameter list in
this function.
"""
if any(self.fixed.values()) or any(self.tied.values()):

if z is None:
full_deriv = np.array(self.deriv(x, *self.parameters))
else:
full_deriv = np.array(self.deriv(x, y, *self.parameters))
pars = [getattr(self, name) for name in self.param_names]
fixed = [par.fixed for par in pars]
tied = [par.tied for par in pars]
tied = list(np.where([par.tied != False for par in pars], True, tied))
fix_and_tie = np.logical_or(fixed, tied)
ind = np.logical_not(fix_and_tie)
res = full_deriv[np.nonzero(ind)]
result = [np.ravel(_) for _ in res]
return result
else:
if z is None:
return self.deriv(x, *params)
else:
return [np.ravel(_) for _ in self.deriv(x, y, *params)]

class LabeledInput(dict):
"""
Expand Down Expand Up @@ -957,8 +1031,6 @@ class Parametric1DModel(ParametricModel):
{'parameter_name': 'parameter_value'}
"""

deriv = None

@format_input
def __call__(self, x):
"""
Expand Down Expand Up @@ -987,7 +1059,6 @@ class Parametric2DModel(ParametricModel):
{'parameter_name': 'parameter_value'}
"""

deriv = None
n_inputs = 2
n_outputs = 1

Expand Down

0 comments on commit ade64da

Please sign in to comment.