/
regressioncorrector.py
336 lines (272 loc) · 13.9 KB
/
regressioncorrector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""Defines RegressionCorrector."""
from __future__ import division # necessary for math in `_fit_coefficients`
import logging
import warnings
import matplotlib.pyplot as plt
import numpy as np
from astropy.stats import sigma_clip
from .corrector import Corrector
from .designmatrix import DesignMatrix, DesignMatrixCollection
from ..lightcurve import LightCurve, MPLSTYLE
# Names exported by `from <this module> import *`.
__all__ = ['RegressionCorrector']
# Module-level logger; `RegressionCorrector.correct()` emits debug messages to it.
log = logging.getLogger(__name__)
class RegressionCorrector(Corrector):
    r"""Remove noise using linear regression against a `.DesignMatrix`.

    .. math::

        \newcommand{\y}{\mathbf{y}}
        \newcommand{\cov}{\boldsymbol\Sigma_\y}
        \newcommand{\w}{\mathbf{w}}
        \newcommand{\covw}{\boldsymbol\Sigma_\w}
        \newcommand{\muw}{\boldsymbol\mu_\w}
        \newcommand{\sigw}{\boldsymbol\sigma_\w}
        \newcommand{\varw}{\boldsymbol\sigma^2_\w}

    Given a column vector of data :math:`\y`
    and a design matrix of regressors :math:`X`,
    we will find the vector of coefficients :math:`\w`
    such that:

    .. math::

        \mathbf{y} = X\mathbf{w} + \mathrm{noise}

    We will assume that the model fits the data within Gaussian uncertainties:

    .. math::

        p(\y | \w) = \mathcal{N}(X\w, \cov)

    We make the regression robust by placing Gaussian priors on :math:`\w`:

    .. math::

        p(\w) = \mathcal{N}(\muw, \sigw)

    We can then find the maximum a posteriori solution, i.e. the mode of the
    posterior distribution :math:`p(\w | \y) \propto p(\y | \w) p(\w)`,
    by solving the matrix equation:

    .. math::

        \w = \covw (X^\top \cov^{-1} \y + \boldsymbol\sigma^{-2}_\w I \muw)

    Where :math:`\covw` is the covariance matrix of the coefficients:

    .. math::

        \covw = (X^\top \cov^{-1} X + \boldsymbol\sigma^{-2}_\w I)^{-1}

    Parameters
    ----------
    lc : `.LightCurve`
        The light curve that needs to be corrected.

    Raises
    ------
    ValueError
        If `lc` contains non-finite values in time or flux, if only some of
        its `flux_err` values are non-finite, or if any finite `flux_err`
        value is smaller than or equal to zero.
    """

    def __init__(self, lc):
        # We don't accept NaN in time or flux.
        if np.any([~np.isfinite(lc.time), ~np.isfinite(lc.flux)]):
            raise ValueError('Input light curve has NaNs in time or flux. '
                             'Please remove NaNs before correction '
                             '(e.g. using `lc = lc.remove_nans()`).')

        # We don't accept NaN in flux_err, unless all values are NaN
        # (in which case `_fit_coefficients` falls back to unit errors).
        finite_err = np.isfinite(lc.flux_err)
        if np.any(~finite_err) and not np.all(~finite_err):
            raise ValueError('Input light curve has NaNs in `flux_err`. '
                             'Please remove NaNs before correction '
                             '(e.g. using `lc = lc.remove_nans()`).')

        # Errors are used as 1/sigma^2 weights, so they must be positive.
        if np.any(lc.flux_err[finite_err] <= 0):
            raise ValueError('Input light curve contains flux uncertainties '
                             'smaller than or equal to zero. Please remove '
                             'these (e.g. using `lc = lc[lc.flux_err > 0]`).')

        self.lc = lc

        # The following properties will be set when correct() is called.
        # We're setting them here so that reading them before correct()
        # yields None rather than an AttributeError.
        self.design_matrix_collection = None
        self.cadence_mask = None
        self.coefficients = None
        self.coefficients_err = None
        self.corrected_lc = None
        self.model_lc = None
        self.diagnostic_lightcurves = None

    def __repr__(self):
        return 'RegressionCorrector (ID: {})'.format(self.lc.targetid)

    @property
    def X(self):
        """Shorthand for self.design_matrix_collection."""
        return self.design_matrix_collection

    def _fit_coefficients(self, cadence_mask=None, prior_mu=None,
                          prior_sigma=None, propagate_errors=False):
        """Fit the linear regression coefficients.

        This function will solve a linear regression with Gaussian priors
        on the coefficients.

        Parameters
        ----------
        cadence_mask : np.ndarray of bool
            Mask, where True indicates a cadence that should be used.
            Defaults to using every cadence.
        prior_mu : array-like or None
            Mean of the Gaussian prior on each coefficient; one value per
            column of the design matrix. Must be given together with
            `prior_sigma`.
        prior_sigma : array-like or None
            Standard deviation of the Gaussian prior on each coefficient;
            one strictly positive value per column of the design matrix.
            Must be given together with `prior_mu`.
        propagate_errors : bool
            If True, also compute the covariance matrix of the coefficients.

        Returns
        -------
        coefficients : np.ndarray
            The best fit model coefficients to the data.
        coefficients_err : np.ndarray
            Covariance matrix of the coefficients if `propagate_errors` is
            True, otherwise a 1D array of NaN with one entry per coefficient.

        Raises
        ------
        ValueError
            If a prior has the wrong length, if `prior_sigma` contains values
            smaller than or equal to zero, or if only one of `prior_mu` and
            `prior_sigma` is specified.
        """
        n_regressors = len(self.X.values.T)

        if prior_mu is not None:
            # Accept any sequence, not just ndarrays.
            prior_mu = np.asarray(prior_mu)
            if len(prior_mu) != n_regressors:
                raise ValueError('`prior_mu` must have shape {}'
                                 ''.format(n_regressors))
        if prior_sigma is not None:
            prior_sigma = np.asarray(prior_sigma)
            if len(prior_sigma) != n_regressors:
                raise ValueError('`prior_sigma` must have shape {}'
                                 ''.format(n_regressors))
            if np.any(prior_sigma <= 0):
                raise ValueError('`prior_sigma` values cannot be smaller than '
                                 'or equal to zero')

        # `prior_mu` and `prior_sigma` must be specified together or not at all.
        if (prior_mu is None) != (prior_sigma is None):
            raise ValueError("Please specify both `prior_mu` and `prior_sigma`")

        # Default cadence mask
        if cadence_mask is None:
            cadence_mask = np.ones(len(self.lc.flux), bool)

        # If flux errors are not all finite numbers, then default to array of ones
        if np.all(~np.isfinite(self.lc.flux_err)):
            flux_err = np.ones(cadence_mask.sum())
        else:
            flux_err = self.lc.flux_err[cadence_mask]

        # Retrieve the design matrix (X) as a numpy array
        X = self.X.values[cadence_mask]

        # Compute `X^T cov^-1 X + 1/prior_sigma^2`
        sigma_w_inv = np.dot(X.T, X / flux_err[:, None]**2)
        if prior_sigma is not None:
            sigma_w_inv += np.diag(1. / prior_sigma**2)

        # Compute `X^T cov^-1 y + prior_mu/prior_sigma^2`
        B = np.dot(X.T, self.lc.flux[cadence_mask] / flux_err**2)
        if prior_sigma is not None:
            B += (prior_mu / prior_sigma**2)

        # Solve for weights w
        w = np.linalg.solve(sigma_w_inv, B).T

        if propagate_errors:
            # The coefficient covariance is the inverse of `sigma_w_inv`.
            w_err = np.linalg.inv(sigma_w_inv)
        else:
            w_err = np.full(len(w), np.nan)

        return w, w_err

    def correct(self, design_matrix_collection, cadence_mask=None, sigma=5,
                niters=5, propagate_errors=False):
        """Find the best fit correction for the light curve.

        Parameters
        ----------
        design_matrix_collection : `.DesignMatrix` or `.DesignMatrixCollection`
            One or more design matrices. Each matrix must have a shape of
            (time, regressors). The columns contained in each matrix must be
            known to correlate with additive noise components we want to remove
            from the light curve.
        cadence_mask : np.ndarray of bools (optional)
            Mask, where True indicates a cadence that should be used.
        sigma : int (default 5)
            Standard deviation at which to remove outliers from fitting
        niters : int (default 5)
            Number of iterations to fit and remove outliers. Must be at
            least 1.
        propagate_errors : bool (default False)
            Whether to propagate the uncertainties from the regression. Default is False.
            Setting to True will increase run time, but will sample from multivariate normal
            distribution of weights.

        Returns
        -------
        `.LightCurve`
            Corrected light curve, with noise removed.

        Raises
        ------
        ValueError
            If `niters` is smaller than 1.
        """
        # At least one fit is required to obtain coefficients.
        if niters < 1:
            raise ValueError('`niters` must be at least 1')

        if isinstance(design_matrix_collection, DesignMatrix):
            design_matrix_collection = DesignMatrixCollection([design_matrix_collection])
        design_matrix_collection._validate()
        self.design_matrix_collection = design_matrix_collection

        if cadence_mask is None:
            cadence_mask = np.ones(len(self.lc.time), bool)
        else:
            # Work on a copy so the caller's mask is never modified.
            cadence_mask = np.copy(cadence_mask)

        # Prepare for iterative masking of residuals
        clean_cadences = np.ones_like(cadence_mask)

        # Iterative sigma clipping
        for count in range(niters):
            fit_mask = cadence_mask & clean_cadences
            coefficients, coefficients_err = self._fit_coefficients(
                cadence_mask=fit_mask,
                prior_mu=self.X.prior_mu,
                prior_sigma=self.X.prior_sigma,
                propagate_errors=propagate_errors)
            # The model is masked at every cadence that was excluded from the
            # fit, so the residuals handed to `sigma_clip` are masked there too.
            model = np.ma.masked_array(data=np.dot(self.X.values, coefficients),
                                       mask=~fit_mask)
            residuals = self.lc.flux - model
            clean_cadences = ~sigma_clip(residuals, sigma=sigma).mask
            log.debug("correct(): iteration %s: clipped %s cadences",
                      count, (~clean_cadences).sum())

        self.cadence_mask = cadence_mask & clean_cadences
        self.coefficients = coefficients
        self.coefficients_err = coefficients_err

        # The model is evaluated at every cadence (including masked ones) and
        # median-subtracted before it is removed from the flux.
        model_flux = np.dot(self.X.values, coefficients)
        model_flux -= np.median(model_flux)

        if propagate_errors:
            n_samples = 100
            with warnings.catch_warnings():
                # ignore "RuntimeWarning: covariance is not symmetric positive-semidefinite."
                warnings.simplefilter("ignore", RuntimeWarning)
                # Draw all coefficient vectors in one call so the covariance
                # matrix is only factorised once.
                draws = np.random.multivariate_normal(coefficients,
                                                      coefficients_err,
                                                      size=n_samples)
            # `samples` has shape (time, n_samples).
            samples = np.dot(self.X.values, draws.T)
            # Mean absolute distance of the 16th and 84th percentiles from
            # the median of the sampled models, per cadence.
            bounds = np.percentile(samples, [16, 84], axis=1)
            center = np.median(samples, axis=1)
            model_err = np.abs(bounds - center).mean(axis=0)
        else:
            model_err = np.zeros(len(model_flux))

        self.model_lc = LightCurve(self.lc.time, model_flux, model_err)
        self.corrected_lc = self.lc.copy()
        self.corrected_lc.flux = self.lc.flux - self.model_lc.flux
        # Add the model uncertainty in quadrature to the flux uncertainty.
        self.corrected_lc.flux_err = (self.lc.flux_err**2 + model_err**2)**0.5
        self.diagnostic_lightcurves = self._create_diagnostic_lightcurves()
        return self.corrected_lc

    def _create_diagnostic_lightcurves(self):
        """Returns a dictionary containing all diagnostic light curves.

        The dictionary will provide a light curve for each matrix in the
        design matrix collection, keyed by the matrix name.

        Raises
        ------
        ValueError
            If `correct()` has not been called yet.
        """
        if self.coefficients is None:
            raise ValueError("you need to call `correct()` first")

        lcs = {}
        # Index of the first coefficient belonging to the current submatrix.
        firstcol_idx = 0
        for submatrix in self.X.matrices:
            ncols = submatrix.shape[1]
            submatrix_coefficients = self.coefficients[firstcol_idx:firstcol_idx + ncols]
            firstcol_idx += ncols
            model_flux = np.dot(submatrix.values, submatrix_coefficients)
            # NOTE: uncertainties are not propagated per submatrix; the
            # diagnostic light curves carry zero flux errors.
            lcs[submatrix.name] = LightCurve(self.lc.time, model_flux,
                                             np.zeros(len(model_flux)),
                                             label=submatrix.name)
        return lcs

    def _diagnostic_plot(self):
        """Produce diagnostic plots to assess the effectiveness of the correction.

        Note: We need a hidden function so that other correctors can alter the plot.

        Returns
        -------
        axs : array of `~matplotlib.axes.Axes`
            The two axes created by `plt.subplots`: the top panel shows the
            original light curve and each diagnostic light curve, the bottom
            panel shows the original, the outliers and the corrected light
            curve.

        Raises
        ------
        ValueError
            If `correct()` has not been called yet.
        """
        # `corrected_lc` is initialised to None in `__init__`, so testing for
        # the attribute's existence alone would never detect a missing
        # `correct()` call.
        if getattr(self, 'corrected_lc', None) is None:
            raise ValueError('Please call the `correct()` method before trying to diagnose.')

        with plt.style.context(MPLSTYLE):
            _, axs = plt.subplots(2, figsize=(10, 6), sharex=True)
            ax = axs[0]
            self.lc.plot(ax=ax, normalize=False, label='original', alpha=0.4)
            # Shift each component to the median level of the original flux
            # so that the curves overlap in the plot.
            original_median = np.median(self.lc.flux)
            for diagnostic_lc in self.diagnostic_lightcurves.values():
                shifted = diagnostic_lc - np.median(diagnostic_lc.flux) + original_median
                shifted.plot(ax=ax)
            ax.set_xlabel('')

            ax = axs[1]
            self.lc.plot(ax=ax, normalize=False, alpha=0.2, label='Original')
            # Cadences that were masked or clipped during the fit.
            self.corrected_lc[~self.cadence_mask].scatter(
                normalize=False, c='r', marker='x',
                s=10, label='Outliers', ax=ax)
            self.corrected_lc.plot(normalize=False, label='Corrected', ax=ax, c='k')
        return axs

    def diagnose(self):
        """Returns diagnostic plots to assess the most recent call to `correct()`.

        If `correct()` has not yet been called, a ``ValueError`` will be raised.

        Returns
        -------
        `~matplotlib.axes.Axes`
            The matplotlib axes object.
        """
        return self._diagnostic_plot()

    def diagnose_priors(self):
        """Returns a diagnostic plot visualizing how the best-fit coefficients
        compare against the priors.

        The method will show the results obtained during the most recent call
        to `correct()`. If `correct()` has not yet been called, a
        ``ValueError`` will be raised.

        Returns
        -------
        `~matplotlib.axes.Axes`
            The matplotlib axes object.
        """
        # See `_diagnostic_plot`: the attribute always exists, so compare
        # against None instead of using `hasattr`.
        if getattr(self, 'corrected_lc', None) is None:
            raise ValueError('Please call the `correct()` method before trying to diagnose.')

        names = [X.name for X in self.X]
        with plt.style.context(MPLSTYLE):
            _, axs = plt.subplots(1, len(names), figsize=(len(names)*4, 4),
                                  sharey=True)
            # `plt.subplots` returns a bare Axes object for a single panel.
            if not hasattr(axs, '__iter__'):
                axs = [axs]
            for idx, (ax, X) in enumerate(zip(axs, self.X)):
                X.plot_priors(ax=ax)
                # Index of the first coefficient belonging to this matrix.
                firstcol_idx = sum(m.shape[1] for m in self.X.matrices[:idx])
                submatrix_coefficients = self.coefficients[firstcol_idx:firstcol_idx + X.shape[1]]
                # Mark each best-fit coefficient on top of the prior.
                for coefficient in submatrix_coefficients:
                    ax.axvline(coefficient, color='red', zorder=-1)
        return axs