/
sffcorrector.py
461 lines (388 loc) · 17.8 KB
/
sffcorrector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
"""Defines the `SFFCorrector` class.
`SFFCorrector` enables systematics to be removed from light curves using the
Self Flat-Fielding (SFF) method described in Vanderburg and Johnson (2014).
"""
import logging
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from astropy.modeling import models, fitting
from . import DesignMatrix, DesignMatrixCollection
from .regressioncorrector import RegressionCorrector
from .. import MPLSTYLE
from ..utils import LightkurveWarning
log = logging.getLogger(__name__)
__all__ = ['SFFCorrector']
class SFFCorrector(RegressionCorrector):
    """Special case of `.RegressionCorrector` where the `.DesignMatrix` includes
    the target's centroid positions.

    The design matrix also contains columns representing a spline in time
    designed to capture the intrinsic, long-term variability of the target.

    Parameters
    ----------
    lc : `.LightCurve`
        The light curve that needs to be corrected.
    """
    def __init__(self, lc):
        # SFF relies on the roll motion tracing a consistent arc, which is a
        # property of K2; warn (but do not refuse) if the data look like TESS.
        if getattr(lc, 'mission', '') == 'TESS':
            warnings.warn("The SFF correction method is not suitable for use "
                          "with TESS data, because the spacecraft motion does "
                          "not proceed along a consistent arc.",
                          LightkurveWarning)
        # Keep a reference to the uncorrected input so `correct()` can rescale
        # the result back to the original flux level.
        self.raw_lc = lc
        # Work on a normalized copy unless the flux is already dimensionless
        # (no unit, or an empty/dimensionless unit string).
        if hasattr(lc, 'flux_unit'):
            if lc.flux_unit is None:
                lc = lc.copy()
            elif lc.flux_unit.to_string() == '':
                lc = lc.copy()
            else:
                lc = lc.copy().normalize()
        else:
            lc = lc.copy().normalize()
        # Setting these values as None so we don't get a value error if the
        # user tries to access them before "correct()"
        self.window_points = None
        self.windows = None
        self.bins = None
        self.timescale = None
        self.breakindex = None
        self.centroid_col = None
        self.centroid_row = None
        super(SFFCorrector, self).__init__(lc=lc)
    def __repr__(self):
        return 'SFFCorrector (LC: {})'.format(self.lc.targetid)
    def correct(self, centroid_col=None, centroid_row=None, windows=20, bins=5,
                timescale=1.5, breakindex=None, degree=3, restore_trend=False,
                additional_design_matrix=None, polyorder=None, **kwargs):
        """Find the best fit correction for the light curve.

        Parameters
        ----------
        centroid_col : np.ndarray of floats (optional)
            Array of centroid column positions. If ``None``, will use the
            `centroid_col` attribute of the input light curve by default.
        centroid_row : np.ndarray of floats (optional)
            Array of centroid row positions. If ``None``, will use the
            `centroid_row` attribute of the input light curve by default.
        windows : int
            Number of windows to split the data into to perform the correction.
            Default 20.
        bins : int
            Number of "knots" to place on the arclength spline. More bins will
            increase the number of knots, making the spline smoother in arclength.
            Default 5.
        timescale: float
            Time scale of the b-spline fit to the light curve in time, in units
            of input light curve time.
        breakindex : None, int or list of ints (optional)
            Optionally the user can break the light curve into sections. Set
            break index to either an index at which to break, or list of indicies.
        degree : int
            The degree of polynomials in the splines in time and arclength. Higher
            values will create smoother splines. Default 3.
        cadence_mask : np.ndarray of bools (optional)
            Mask, where True indicates a cadence that should be used.
        sigma : int (default 5)
            Standard deviation at which to remove outliers from fitting
        niters : int (default 5)
            Number of iterations to fit and remove outliers
        restore_trend : bool (default False)
            Whether to restore the long term spline trend to the light curve
        propagate_errors : bool (default False)
            Whether to propagate the uncertainties from the regression. Default is False.
            Setting to True will increase run time, but will sample from multivariate normal
            distribution of weights.
        additional_design_matrix : `~lightkurve.lightcurve.Correctors.DesignMatrix` (optional)
            Additional design matrix to remove, e.g. containing background vectors.
        polyorder : int
            Deprecated as of Lightkurve v1.4. Use ``degree`` instead.

        Returns
        -------
        corrected_lc : `~lightkurve.lightcurve.LightCurve`
            Corrected light curve, with noise removed.
        """
        from patsy import dmatrix  # local import because it's rarely-used
        if polyorder is not None:
            warnings.warn("`polyorder` is deprecated and no longer used, "
                          "please use the `degree` keyword instead.",
                          LightkurveWarning)
        if centroid_col is None:
            centroid_col = self.lc.centroid_col
        if centroid_row is None:
            centroid_row = self.lc.centroid_row
        # NaN centroids would poison arclength and the spline bases downstream.
        if np.any([~np.isfinite(centroid_row), ~np.isfinite(centroid_col)]):
            raise ValueError('Centroids contain NaN values.')
        # Indices where the light curve is split into windows (thruster firings
        # and user-supplied break points).
        self.window_points = _get_window_points(centroid_col, centroid_row,
                                                windows, breakindex=breakindex)
        self.windows = windows
        self.bins = bins
        self.timescale = timescale
        self.breakindex = breakindex
        self.arclength = _estimate_arclength(centroid_col, centroid_row)
        # Window boundaries: [0, wp...] paired with [wp..., len(time)].
        lower_idx = np.asarray(np.append(0, self.window_points), int)
        upper_idx = np.asarray(np.append(self.window_points, len(self.lc.time)), int)
        # Build one arclength spline design matrix per window; outside the
        # window the arclength vector is zeroed so its columns only act inside.
        stack = []
        columns = []
        prior_sigmas = []
        for idx, a, b in zip(range(len(lower_idx)), lower_idx, upper_idx):
            # Interior knots at evenly spaced percentiles of this window's arclength.
            knots = list(np.percentile(self.arclength[a:b], np.linspace(0, 100, bins+1)[1:-1]))
            ar = np.copy(self.arclength)
            # Zero out cadences whose arclength value does not occur in this
            # window, so the spline columns vanish outside [a:b].
            ar[~np.in1d(ar, ar[a:b])] = 0
            dm = np.asarray(dmatrix("bs(x, knots={}, degree={}, include_intercept={}) - 1"
                                    "".format(knots, degree, True), {"x": ar}))
            stack.append(dm)
            columns.append(['window{}_bin{}'.format(idx+1, jdx+1)
                            for jdx in range(len(dm.T))])
            # I'm putting VERY weak priors on the SFF motion vectors
            # (1e-6 is being added to prevent sigma from being zero)
            ps = np.ones(len(dm.T)) * 10000 * self.lc[a:b].flux.std() + 1e-6
            prior_sigmas.append(ps)
        sff_dm = DesignMatrix(pd.DataFrame(np.hstack(stack)),
                              columns=np.hstack(columns),
                              name='sff',
                              prior_sigma=np.hstack(prior_sigmas))
        # Long-term trend: a b-spline in time with roughly one knot per
        # `timescale` units of time.
        n_knots = int((self.lc.time[-1] - self.lc.time[0])/timescale)
        s_dm = _get_spline_dm(self.lc.time, n_knots=n_knots, include_intercept=True)
        # Prior means: flux averaged with each spline basis column as weights,
        # so the spline prior is centered on the local flux level.
        means = [np.average(self.lc.flux, weights=s_dm.values[:, idx]) for idx in range(s_dm.shape[1])]
        s_dm.prior_mu = np.asarray(means)
        # I'm putting WEAK priors on the spline that it must be around 1
        s_dm.prior_sigma = np.ones(len(s_dm.prior_mu)) * 1000 * self.lc.flux.std() + 1e-6
        # Optional user-supplied design matrix (e.g. background vectors).
        if additional_design_matrix is not None:
            if not isinstance(additional_design_matrix, DesignMatrix):
                raise ValueError('`additional_design_matrix` must be a DesignMatrix object.')
            self.additional_design_matrix = additional_design_matrix
            dm = DesignMatrixCollection([s_dm,
                                         sff_dm,
                                         additional_design_matrix])
        else:
            dm = DesignMatrixCollection([s_dm, sff_dm])
        # Delegate the regression to the parent RegressionCorrector.
        clc = super(SFFCorrector, self).correct(dm, **kwargs)
        # Optionally put the long-term spline trend back (median-subtracted so
        # the flux level is unchanged), then rescale to the raw flux level.
        if restore_trend:
            trend = self.diagnostic_lightcurves['spline'].flux
            clc += trend - np.nanmedian(trend)
        clc *= self.raw_lc.flux.mean()
        return clc
    def diagnose(self):
        """Returns a diagnostic plot which visualizes what happened during the
        most recent call to `correct()`.

        Window boundaries are overplotted as dashed red vertical lines.
        """
        axs = self._diagnostic_plot()
        for t in self.window_points:
            axs[0].axvline(self.lc.time[t], color='r', ls='--', alpha=0.3)
    def diagnose_arclength(self):
        """Returns a diagnostic plot which visualizes arclength vs flux
        from most recent call to `correct()`."""
        # At most `max_plot` window panels per figure row.
        max_plot = 5
        with plt.style.context(MPLSTYLE):
            _, axs = plt.subplots(int(np.ceil(self.windows/max_plot)), max_plot,
                                  figsize=(10, int(np.ceil(self.windows/max_plot)*2)),
                                  sharex=True, sharey=True)
            # Ensure 2D indexing works even when there is a single row of axes.
            axs = np.atleast_2d(axs)
            axs[0, 2].set_title('Arclength Plot/Window')
            plt.subplots_adjust(hspace=0, wspace=0)
            # Reconstruct the window boundaries used by correct().
            lower_idx = np.asarray(np.append(0, self.window_points), int)
            upper_idx = np.asarray(np.append(self.window_points, len(self.lc.time)), int)
            # Residual flux after removing the spline (and, if present, the
            # additional design matrix component).
            if hasattr(self, 'additional_design_matrix'):
                name = self.additional_design_matrix.name
                f = (self.lc.flux - self.diagnostic_lightcurves['spline'].flux
                     - self.diagnostic_lightcurves[name].flux)
            else:
                f = (self.lc.flux - self.diagnostic_lightcurves['spline'].flux)
            # The fitted SFF motion model.
            m = self.diagnostic_lightcurves['sff'].flux
            idx, jdx = 0, 0
            for a, b in zip(lower_idx, upper_idx):
                ax = axs[idx, jdx]
                if jdx == 0:
                    ax.set_ylabel('Flux')
                ax.scatter(self.arclength[a:b], f[a:b], s=1, label='Data')
                # cadence_mask is set by the parent correct(); False = outlier.
                ax.scatter(self.arclength[a:b][~self.cadence_mask[a:b]],
                           f[a:b][~self.cadence_mask[a:b]],
                           s=10, marker='x', c='r', label='Outliers')
                s = np.argsort(self.arclength[a:b])
                # Offset the model to the data's median level for plotting.
                ax.scatter(self.arclength[a:b][s],
                           (m[a:b] - np.median(m[a:b]) + np.median(f[a:b]))[s],
                           c='C2', s=0.5, label='Model')
                jdx += 1
                if jdx >= max_plot:
                    jdx = 0
                    idx += 1
                # Add the legend on the final window's panel only.
                if b == len(self.lc.time):
                    ax.legend()
######################
# Helper functions #
######################
def _get_spline_dm(x, n_knots=20, degree=3, name='spline',
                   include_intercept=False):
    """Returns a `.DesignMatrix` which models splines using `patsy.dmatrix`.

    Parameters
    ----------
    x : np.ndarray
        Vector to spline.
    n_knots : int
        Number of knots (default: 20); passed to patsy as the spline's
        degrees of freedom (``df``).
    degree : int
        Polynomial degree of the B-spline basis.
    name : string
        Name to pass to `.DesignMatrix` (default: 'spline').
    include_intercept : bool
        Whether to include row of ones to find intercept. Default False.

    Returns
    -------
    dm : `.DesignMatrix`
        Design matrix whose columns are the spline basis functions evaluated
        at ``x``; columns are labeled ``knot1`` ... ``knot{n_knots}``.
    """
    from patsy import dmatrix  # local import because it's rarely-used
    formula = ("bs(x, df={}, degree={}, include_intercept={}) - 1"
               "".format(n_knots, degree, include_intercept))
    basis = np.asarray(dmatrix(formula, {"x": x}))
    labels = ['knot{}'.format(i + 1) for i in range(n_knots)]
    frame = pd.DataFrame(basis, columns=labels)
    return DesignMatrix(frame, name=name)
def _get_centroid_dm(col, row, name='centroids'):
    """Returns a `.DesignMatrix` containing (col, row) centroid positions
    and transformations thereof.

    Parameters
    ----------
    col : np.ndarray
        centroid column
    row : np.ndarray
        centroid row
    name : str
        Name to pass to `.DesignMatrix` (default: 'centroids').

    Returns
    -------
    dm : `.DesignMatrix`
        Design matrix with 10 columns: the centroids themselves plus their
        polynomial cross-terms up to combined degree four.
    """
    # Pair each column label with its vector; order defines the matrix columns.
    terms = [
        (r'col', col), (r'row', row),
        (r'col^2', col**2), (r'row^2', row**2),
        (r'col^3', col**3), (r'row^3', row**3),
        (r'col \times row', col*row),
        (r'col^2 \times row', col**2 * row),
        (r'col \times row^2', col * row**2),
        (r'col^2 \times row^2', col**2 * row**2),
    ]
    labels = [label for label, _ in terms]
    values = np.asarray([vec for _, vec in terms]).T
    return DesignMatrix(pd.DataFrame(values, columns=labels), name=name)
def _get_thruster_firings(arclength):
    """Find locations where K2 fired thrusters.

    Thruster firings appear as large excursions in the second derivative of
    arclength with respect to time; the bulk of cadences form a narrow
    Gaussian core around zero, and firings are flagged as >5-sigma outliers
    of a Gaussian fitted to that core.

    Parameters
    ----------
    arclength : np.ndarray
        arclength as a function of time

    Returns
    -------
    thrusters: np.ndarray of bools
        True at times where thrusters were fired.
    """
    arc = np.copy(arclength)
    # Rate of change of rate of change of arclength wrt time
    d2adt2 = (np.gradient(np.gradient(arc)))
    # Fit a Gaussian, most points lie in a tight region, thruster firings are outliers
    g = models.Gaussian1D(amplitude=100, mean=0, stddev=0.01)
    fitter = fitting.LevMarLSQFitter()
    # Histogram of the finite second-derivative values; density=True makes the
    # Gaussian amplitude scale-independent of the number of cadences.
    h = np.histogram(d2adt2[np.isfinite(d2adt2)], np.arange(-0.5, 0.5, 0.0001), density=True)
    # Bin centers (bin edges shifted back by half the median bin width).
    xbins = h[1][1:] - np.median(np.diff(h[1]))
    # Weight by sqrt(density), approximating Poisson errors on the counts.
    g = fitter(g, xbins, h[0], weights=h[0]**0.5)
    # Depending on the orientation of the roll, it is hard to return
    # the point before the firing or the point after the firing.
    # This makes sure we always return the same value, no matter the roll orientation.
    def _start_and_end(start_or_end):
        """Find points at the start or end of a roll."""
        # Firings at roll start are large negative outliers; at roll end,
        # large positive outliers (5x the fitted core stddev).
        if start_or_end == 'start':
            thrusters = (d2adt2 < (g.stddev * -5)) & np.isfinite(d2adt2)
        if start_or_end == 'end':
            thrusters = (d2adt2 > (g.stddev * 5)) & np.isfinite(d2adt2)
        # Pick the best thruster in each cluster
        # Clusters are consecutive runs of flagged cadences; split wherever
        # the gradient of the boolean mask (as int) is zero.
        idx = np.array_split(np.arange(len(thrusters)),
                             np.where(np.gradient(np.asarray(thrusters, int)) == 0)[0])
        m = np.array_split(thrusters, np.where(np.gradient(np.asarray(thrusters, int)) == 0)[0])
        th = []
        for jdx, _ in enumerate(idx):
            if m[jdx].sum() == 0:
                # No firing flagged in this segment; keep it all-False.
                th.append(m[jdx])
            else:
                # Keep only the cadence with the largest |d(arclength)/dt|
                # within the cluster.
                th.append((np.abs(np.gradient(arc)[idx[jdx]]) == np.abs(np.gradient(arc)[idx[jdx]][m[jdx]]).max()) & m[jdx])
        thrusters = np.hstack(th)
        return thrusters
    # Get the start and end points
    thrusters = np.asarray([_start_and_end('start'), _start_and_end('end')])
    thrusters = thrusters.any(axis=0)
    # Take just the first point.
    # Where start and end flags are adjacent, keep only the leading cadence.
    thrusters = (np.gradient(np.asarray(thrusters, int)) >= 0) & thrusters
    return thrusters
def _get_window_points(centroid_col, centroid_row, windows, arclength=None, breakindex=None):
"""Returns indices where thrusters are fired.
Parameters
----------
lc : `.LightCurve` object
Input light curve
windows: int
Number of windows to split the light curve into
arc: np.ndarray
Arclength for the roll motion
breakindex: int
Cadence where there is a natural break. Windows will be automatically put here.
"""
if arclength is None:
arclength = _estimate_arclength(centroid_col, centroid_row)
# Validate break indices
if isinstance(breakindex, int):
breakindexes = [breakindex]
if breakindex is None:
breakindexes = []
elif (breakindex[0] == 0) & (len(breakindex) == 1):
breakindexes = []
else:
breakindexes = breakindex
if not isinstance(breakindexes, list):
raise ValueError('`breakindex` must be an int or a list')
# If the user asks for break indices we should still return them,
# even if there is only 1 window.
if windows == 1:
return breakindexes
# Find evenly spaced window points
dt = len(centroid_col) / windows
lower_idx = np.append(0, breakindexes)
upper_idx = np.append(breakindexes, len(centroid_col))
window_points = np.hstack([np.asarray(np.arange(a, b, dt), int)
for a, b in zip(lower_idx, upper_idx)])
# Get thruster firings
thrusters = _get_thruster_firings(arclength)
for b in breakindexes:
thrusters[b] = True
thrusters = np.where(thrusters)[0]
# Find the nearest point to each thruster firing, unless it's a user supplied break point
if len(thrusters) > 0:
window_points = [thrusters[np.argmin(np.abs(thrusters - wp))] + 1
for wp in window_points
if wp not in breakindexes]
window_points = np.unique(np.hstack([window_points, breakindexes]))
# If the first or last windows are very short (<40% median window length),
# then we add them to the second or penultimate window, respectively,
# by removing their break points.
median_length = np.median(np.diff(window_points))
if window_points[0] < 0.4*median_length:
window_points = window_points[1:]
if window_points[-1] > (len(centroid_col) - 0.4*median_length):
window_points = window_points[:-1]
return np.asarray(window_points, dtype=int)
def _estimate_arclength(centroid_col, centroid_row):
"""Estimate the arclength given column and row centroid positions.
We use the approximation that the arclength equals
(row**2 + col**2)**0.5
For this to work, row and column must be correlated not anticorrelated.
"""
col = centroid_col - np.nanmin(centroid_col)
row = centroid_row - np.nanmin(centroid_row)
# Force c to be correlated not anticorrelated
if np.polyfit(col, row, 1)[0] < 0:
col = np.nanmax(col) - col
return (col**2 + row**2)**0.5