Skip to content

Commit

Permalink
Merge pull request astropy#10623 from nden/model-sets-linear1d
Browse files Browse the repository at this point in the history
Linear1D and Planar2D should be able to fit as model sets
  • Loading branch information
nden authored and eteq committed Oct 12, 2020
1 parent f4905f5 commit 4ced0c6
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 89 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Expand Up @@ -703,6 +703,7 @@ astropy.wcs
Bug fixes
---------


astropy.config
^^^^^^^^^^^^^^

Expand Down Expand Up @@ -809,6 +810,8 @@ astropy.modeling
without units when the expression involves operators other than addition
and subtraction. [#10415]

- Fixed a problem with fitting ``Linear1D`` and ``Planar2D`` in model sets. [#10623]

astropy.nddata
^^^^^^^^^^^^^^

Expand Down
29 changes: 13 additions & 16 deletions astropy/modeling/fitting.py
Expand Up @@ -403,14 +403,12 @@ def __call__(self, model, x, y, z=None, weights=None, rcond=None):
if hasattr(model_copy, 'domain'):
x = self._map_domain_window(model_copy, x)
if has_fixed:
lhs = self._deriv_with_constraints(model_copy,
fitparam_indices,
x=x)
fixderivs = self._deriv_with_constraints(model_copy,
fixparam_indices,
x=x)
lhs = np.asarray(self._deriv_with_constraints(model_copy,
fitparam_indices,
x=x))
fixderivs = self._deriv_with_constraints(model_copy, fixparam_indices, x=x)
else:
lhs = model_copy.fit_deriv(x, *model_copy.parameters)
lhs = np.asarray(model_copy.fit_deriv(x, *model_copy.parameters))
sum_of_implicit_terms = model_copy.sum_of_implicit_terms(x)
rhs = y
else:
Expand All @@ -421,12 +419,13 @@ def __call__(self, model, x, y, z=None, weights=None, rcond=None):
x, y = self._map_domain_window(model_copy, x, y)

if has_fixed:
lhs = self._deriv_with_constraints(model_copy,
fitparam_indices, x=x, y=y)
lhs = np.asarray(self._deriv_with_constraints(model_copy,
fitparam_indices, x=x, y=y))
fixderivs = self._deriv_with_constraints(model_copy,
fixparam_indices, x=x, y=y)
fixparam_indices,
x=x, y=y)
else:
lhs = model_copy.fit_deriv(x, y, *model_copy.parameters)
lhs = np.asanyarray(model_copy.fit_deriv(x, y, *model_copy.parameters))
sum_of_implicit_terms = model_copy.sum_of_implicit_terms(x, y)

if len(model_copy) > 1:
Expand Down Expand Up @@ -458,7 +457,7 @@ def __call__(self, model, x, y, z=None, weights=None, rcond=None):
# when constructing their Vandermonde matrix, which can lead to obscure
# failures below. Ultimately, np.linalg.lstsq can't handle >2D matrices,
# so just raise a slightly more informative error when this happens:
if lhs.ndim > 2:
if np.asanyarray(lhs).ndim > 2:
raise ValueError('{} gives unsupported >2D derivative matrix for '
'this x/y'.format(type(model_copy).__name__))

Expand Down Expand Up @@ -495,9 +494,6 @@ def __call__(self, model, x, y, z=None, weights=None, rcond=None):
lhs *= weights[:, np.newaxis]
rhs = rhs * weights

if rcond is None:
rcond = len(x) * np.finfo(x.dtype).eps

scl = (lhs * lhs).sum(0)
lhs /= scl

Expand Down Expand Up @@ -551,7 +547,8 @@ def __call__(self, model, x, y, z=None, weights=None, rcond=None):

# TODO: Only Polynomial models currently have an _order attribute;
# maybe change this to read isinstance(model, PolynomialBase)
if hasattr(model_copy, '_order') and rank != model_copy._order:
if hasattr(model_copy, '_order') and len(model_copy) == 1 \
and not has_fixed and rank != model_copy._order:
warnings.warn("The fit may be poorly conditioned\n",
AstropyUserWarning)

Expand Down
5 changes: 2 additions & 3 deletions astropy/modeling/functional_models.py
Expand Up @@ -823,7 +823,6 @@ class Linear1D(Fittable1DModel):
.. math:: f(x) = a x + b
"""

slope = Parameter(default=1)
intercept = Parameter(default=0)
linear = True
Expand All @@ -835,7 +834,7 @@ def evaluate(x, slope, intercept):
return slope * x + intercept

@staticmethod
def fit_deriv(x, slope, intercept):
def fit_deriv(x, *params):
"""One dimensional Line model derivative with respect to parameters"""

d_slope = x
Expand Down Expand Up @@ -893,7 +892,7 @@ def evaluate(x, y, slope_x, slope_y, intercept):
return slope_x * x + slope_y * y + intercept

@staticmethod
def fit_deriv(x, y, slope_x, slope_y, intercept):
def fit_deriv(x, y, *params):
"""Two dimensional Plane model derivative with respect to parameters"""

d_slope_x = x
Expand Down
4 changes: 1 addition & 3 deletions astropy/modeling/tests/test_constraints.py
Expand Up @@ -232,9 +232,7 @@ def test(self):
self.p1.c0.fixed = True
self.p1.c1.fixed = True
pfit = fitting.LinearLSQFitter()
with pytest.warns(AstropyUserWarning,
match=r'The fit may be poorly conditioned'):
model = pfit(self.p1, self.x, self.y)
model = pfit(self.p1, self.x, self.y)
assert_allclose(self.y, model(self.x))

# Test constraints as parameter properties
Expand Down
23 changes: 8 additions & 15 deletions astropy/modeling/tests/test_fitters.py
Expand Up @@ -54,16 +54,17 @@ def setup_class(self):
def poly2(x, y):
return 1 + 2 * x + 3 * x ** 2 + 4 * y + 5 * y ** 2 + 6 * x * y
self.z = poly2(self.x, self.y)
self.fitter = LinearLSQFitter()

def test_poly2D_fitting(self):
fitter = LinearLSQFitter()
v = self.model.fit_deriv(x=self.x, y=self.y)
p = linalg.lstsq(v, self.z.flatten(), rcond=-1)[0]
new_model = self.fitter(self.model, self.x, self.y, self.z)
new_model = fitter(self.model, self.x, self.y, self.z)
assert_allclose(new_model.parameters, p)

def test_eval(self):
new_model = self.fitter(self.model, self.x, self.y, self.z)
fitter = LinearLSQFitter()
new_model = fitter(self.model, self.x, self.y, self.z)
assert_allclose(new_model(self.x, self.y), self.z)

@pytest.mark.skipif('not HAS_SCIPY')
Expand Down Expand Up @@ -256,9 +257,7 @@ def test_linear_fit_2d_model_set(self):
z = z_expected + np.random.normal(0, 0.01, size=z_expected.shape)

fitter = LinearLSQFitter()
with pytest.warns(AstropyUserWarning,
match=r'The fit may be poorly conditioned'):
fitted_model = fitter(init_model, x, y, z)
fitted_model = fitter(init_model, x, y, z)
assert_allclose(fitted_model(x, y, model_set_axis=False), z_expected,
rtol=1e-1)

Expand All @@ -273,9 +272,7 @@ def test_linear_fit_fixed_parameter(self):
y = 2 + x + 0.5*x*x

fitter = LinearLSQFitter()
with pytest.warns(AstropyUserWarning,
match=r'The fit may be poorly conditioned'):
fitted_model = fitter(init_model, x, y)
fitted_model = fitter(init_model, x, y)
assert_allclose(fitted_model.parameters, [2., 1., 0.5], atol=1e-14)

def test_linear_fit_model_set_fixed_parameter(self):
Expand All @@ -289,9 +286,7 @@ def test_linear_fit_model_set_fixed_parameter(self):
yy = np.array([2 + x + 0.5*x*x, -2*x])

fitter = LinearLSQFitter()
with pytest.warns(AstropyUserWarning,
match=r'The fit may be poorly conditioned'):
fitted_model = fitter(init_model, x, yy)
fitted_model = fitter(init_model, x, yy)

assert_allclose(fitted_model.c0, [2., 0.], atol=1e-14)
assert_allclose(fitted_model.c1, [1., -2.], atol=1e-14)
Expand All @@ -309,9 +304,7 @@ def test_linear_fit_2d_model_set_fixed_parameters(self):
zz = np.array([1+x-0.5*y+0.1*x*x, 2*x+y-0.2*y*y])

fitter = LinearLSQFitter()
with pytest.warns(AstropyUserWarning,
match=r'The fit may be poorly conditioned'):
fitted_model = fitter(init_model, x, y, zz)
fitted_model = fitter(init_model, x, y, zz)

assert_allclose(fitted_model(x, y, model_set_axis=False), zz,
atol=1e-14)
Expand Down

0 comments on commit 4ced0c6

Please sign in to comment.