-
Notifications
You must be signed in to change notification settings - Fork 35
/
continuum.py
361 lines (296 loc) · 11.5 KB
/
continuum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
""" Module for fitting a QSO continuum
"""
from __future__ import print_function, absolute_import, division, \
unicode_literals
import warnings
import sys, os
import numpy as np
from ..utils import between
from .interp import AkimaSpline
def make_chunks_qso(wa, redshift, divmult=1, forest_divmult=1, debug=False):
    """ Generate a series of wavelength chunks for use by
    prepare_knots, assuming a QSO spectrum.

    Parameters
    ----------
    wa : array of float, shape (N,)
        Observed-frame wavelengths, assumed sorted in increasing order.
        NaN values are ignored (with a warning).
    redshift : float
        QSO emission redshift. The rest-frame chunk table is multiplied
        by ``1 + redshift``.
    divmult : float
        Multiplier for the number of chunks in rest-frame regions redward
        of 1190 Angstroms.
    forest_divmult : float
        Multiplier for the number of chunks in rest-frame regions at or
        blueward of 1190 Angstroms (the Lyman forest and below).
    debug : bool
        If True, print the chunk table and the search indices.

    Returns
    -------
    edges : array of float
        Observed-frame chunk edges lying within the range of `wa`.
    """
    wa = np.asarray(wa)
    isnan = np.isnan(wa)
    if np.any(isnan):
        warnings.warn('Some wavelengths are NaN, ignoring these pixels.')
        wa = wa[~isnan]
    assert len(wa) > 0
    zp1 = 1 + redshift
    # Rest-frame chunk table: (left edge, right edge, number of chunks).
    # These edges and divisions were generated by trial and error
    # for S/N = 15ish and resolution = 2000ish.
    div = np.rec.fromrecords([(200. , 500. , 25),
                              (500. , 800. , 25),
                              (800. , 1190., 25),
                              (1190., 1213., 4),
                              (1213., 1230., 6),
                              (1230., 1263., 6),
                              (1263., 1290., 5),
                              (1290., 1340., 5),
                              (1340., 1370., 2),
                              (1370., 1410., 5),
                              (1410., 1515., 5),
                              (1515., 1600., 15),
                              (1600., 1800., 8),
                              (1800., 1900., 5),
                              (1900., 1940., 5),
                              (1940., 2240., 15),
                              (2240., 3000., 25),
                              (3000., 6000., 80),
                              (6000., 20000., 100),
                              ], names=str('left,right,num'))
    # Select the forest rows by wavelength rather than by row position,
    # so that adding rows to the table cannot silently move the split.
    # This must be done before the table is shifted to the observed frame.
    forest = div.right <= 1190.
    div.num[forest] = np.ceil(div.num[forest] * forest_divmult)
    div.num[~forest] = np.ceil(div.num[~forest] * divmult)
    div.left *= zp1
    div.right *= zp1
    if debug:
        print(div.tolist())
    # Each row contributes its left edge plus n-1 interior edges.
    edges = np.concatenate([np.linspace(left, right, n + 1)[:-1]
                            for left, right, n in div])
    # i1 (the edge index at rest-frame 1210 A) is only reported in debug mode.
    i0, i1, i2 = edges.searchsorted([wa[0], 1210*zp1, wa[-1]])
    if debug:
        print(i0, i1, i2)
    return edges[i0:i2]
def update_knots(knots, indices, fl, masked):
    """ Set the y value of every unflagged knot from its chunk, in place.

    Parameters
    ----------
    knots : list of [xpos, ypos, flag] lists, length N
        Knots whose ``flag`` is True are left unchanged; for all other
        knots ``ypos`` is overwritten.
    indices : list of (i0, i1) index pairs, length N
        Start and end indices into `fl` and `masked` of each spectrum
        chunk (the xpos of each knot is the chunk centre).
    fl, masked : arrays of shape (M,)
        The flux, and a boolean array that is True for masked pixels.

    Notes
    -----
    ``ypos`` becomes the median of the unmasked flux in the chunk. If a
    chunk has no unmasked pixels, ``np.median`` yields NaN.
    """
    for iknot, (start, stop) in enumerate(indices):
        knot = knots[iknot]
        if knot[2]:
            # Already judged a good fit; keep its current y value.
            continue
        keep = ~masked[start:stop]
        knot[1] = np.median(fl[start:stop][keep])
def linear_co(wa, knots):
    """Evaluate a piecewise-linear curve through the knots at `wa`.

    One extra point is extrapolated beyond each end knot, continuing the
    slope of the outermost knot pair, so that the curve keeps that slope
    near the end points instead of flattening there. Beyond those
    extrapolated points ``np.interp`` holds the value constant.
    """
    x, y = list(zip(*knots))[:2]
    # Step sizes of the first and last knot pairs.
    dx_first, dx_last = x[1] - x[0], x[-1] - x[-2]
    dy_first, dy_last = y[1] - y[0], y[-1] - y[-2]
    padded_x = [x[0] - dx_first] + list(x) + [x[-1] + dx_last]
    padded_y = [y[0] - dy_first] + list(y) + [y[-1] + dy_last]
    return np.interp(wa, padded_x, padded_y)
def Akima_co(wa, knots):
    """Evaluate an Akima spline through the knots at wavelengths `wa`.

    Only the x and y entries of each [x, y, flag] knot are used; the
    flag is ignored.
    """
    knot_x, knot_y, _flags = zip(*knots)
    spline = AkimaSpline(knot_x, knot_y)
    return spline(wa)
def remove_bad_knots(knots, indices, masked, fl, er, debug=False):
    """ Drop knots whose chunk cannot constrain the continuum.

    A chunk is rejected when every pixel in it is masked (an empty chunk
    counts as fully masked), or when its median flux is at most twice
    its median error. The medians are taken over all pixels in the
    chunk, masked or not. `knots` and `indices` are modified in place.
    """
    doomed = []
    for iknot, (lo, hi) in enumerate(indices):
        if not np.all(masked[lo:hi]):
            low_snr = np.median(fl[lo:hi]) <= 2*np.median(er[lo:hi])
            if not low_snr:
                continue
        if debug:
            print('Deleting knot', iknot, 'near {:.1f} Angstroms'.format(
                knots[iknot][0]))
        doomed.append(iknot)
    # Delete from the end so earlier positions remain valid.
    for iknot in doomed[::-1]:
        del knots[iknot]
        del indices[iknot]
def chisq_chunk(model, fl, er, masked, indices, knots, chithresh=1.5):
    """ Flag knots whose chunk is already well fitted by `model`.

    For every knot not yet flagged, compute the reduced chi-squared of
    `model` against the unmasked pixels of that knot's chunk. If it is
    below `chithresh`, the knot's flag (element 2) is set to True.
    Modifies `knots` in place and returns None.

    Parameters
    ----------
    model, fl, er : arrays of shape (M,)
        Model continuum, flux and one-sigma error.
    masked : boolean array of shape (M,)
        True for pixels to exclude.
    indices : list of (i1, i2) index pairs
        Start and end index of the chunk belonging to each knot.
    knots : list of [x, y, flag] lists
        Knots to update; already-flagged knots are skipped.
    chithresh : float
        Reduced chi-squared threshold below which a knot is flagged.
    """
    FLAG = 2
    for iknot, (i1, i2) in enumerate(indices):
        if knots[iknot][FLAG]:
            continue
        good = ~masked[i1:i2]
        npix = np.count_nonzero(good)
        if npix == 0:
            # No usable pixels: the reduced chi-squared is undefined (0/0).
            # Leave the knot unflagged rather than divide by zero.
            continue
        resid = (model[i1:i2][good] - fl[i1:i2][good]) / er[i1:i2][good]
        rchisq = np.sum(resid * resid) / npix
        if rchisq < chithresh:
            knots[iknot][FLAG] = True
def prepare_knots(wa, fl, er, edges, ax=None, debug=False):
    """ Build the initial spline knots for continuum estimation.

    Parameters
    ----------
    wa, fl, er : arrays
        Wavelength, flux, error. `wa` must be sorted, since chunk
        boundaries are located with ``wa.searchsorted``.
    edges : sequence of float
        The edges of the wavelength chunks. One knot is placed at the
        centre of each pair of consecutive edges.
    ax : Matplotlib Axes
        If not None, chunk edges and the initial knots are drawn on it.
    debug : bool
        Passed to `remove_bad_knots`, which then reports dropped knots.

    Returns
    -------
    knots, indices, masked
      * knots: A list of [x, y, flag] lists, one per surviving chunk,
        giving the x and y position of each knot (flag starts False).
      * indices: A list of tuples (i, j) giving the start and end index
        into `wa` of each surviving chunk.
      * masked: A boolean array of length len(wa); True where
        ``er > 0`` fails (non-positive or NaN error).
    """
    bounds = wa.searchsorted(edges)
    indices = list(zip(bounds[:-1], bounds[1:]))
    lo_edges, hi_edges = edges[:-1], edges[1:]
    knots = [[0.5 * (lo + hi), 0, False]
             for lo, hi in zip(lo_edges, hi_edges)]

    masked = np.zeros(len(wa), bool)
    masked[np.logical_not(er > 0)] = True

    # Drop chunks that are fully masked or have low S/N.
    remove_bad_knots(knots, indices, masked, fl, er, debug=debug)

    if ax is not None:
        ax.vlines(edges, 0, np.interp(edges, wa, fl) + 100,
                  color='c', zorder=10)

    # Give each knot its starting flux value.
    update_knots(knots, indices, fl, masked)

    if ax is not None:
        knot_x, knot_y = list(zip(*knots))[:2]
        ax.plot(knot_x, knot_y, 'o', mfc='none', mec='c', ms=10, mew=1,
                zorder=10)
    return knots, indices, masked
def unmask(masked, indices, wa, fl, er, minpix=3):
    """ Force each chunk to use at least `minpix` pixels.

    Iterative rejection can mask every pixel in a chunk. For any chunk
    with fewer than `minpix` unmasked pixels, the `minpix` highest-flux
    pixels that have a positive error are unmasked (all of them if fewer
    exist). `masked` is modified in place. `wa` is accepted for
    interface compatibility but not used.
    """
    for start, stop in indices:
        if np.count_nonzero(~masked[start:stop]) >= minpix:
            continue
        positions = np.arange(start, stop)
        usable = er[start:stop] > 0
        # argsort is ascending, so the brightest pixels are at the end.
        order = np.argsort(fl[start:stop][usable])
        masked[positions[usable][order[-minpix:]]] = False
def estimate_continuum(s, knots, indices, masked, ax=None, maxiter=100,
                       nsig=1.5, debug=False):
    """ Iterate to estimate the continuum.

    Each iteration:
      1. sets the y value of every unflagged knot to the median unmasked
         flux in its chunk (`update_knots`);
      2. builds a linear model (used for outlier rejection) and an Akima
         model (used for the goodness-of-fit test) through the knots;
      3. flags knots whose chunk has reduced chi-squared < 1 against the
         Akima model (`chisq_chunk`);
      4. masks pixels with (model - flux) / error > nsig, i.e. pixels
         significantly below the linear model. It then re-enables the
         brightest pixels in any chunk left with too few unmasked pixels
         (`unmask`).
    Iteration stops when every knot is flagged, when the mask stops
    changing, or after `maxiter` iterations (with a warning).

    Parameters
    ----------
    s : object with array attributes wa, fl, er
        Wavelength, flux and one-sigma error (e.g. the record array
        built by `find_continuum`).
    knots, indices, masked :
        As returned by `prepare_knots`. `knots` and `masked` are
        modified in place.
    ax : matplotlib Axes, optional
        If not None, the final linear and Akima continua and the knots
        are plotted on it.
    maxiter : int
        Maximum number of iterations.
    nsig : float
        Rejection threshold in units of the error.
    debug : bool
        Print progress messages.

    Returns
    -------
    co : array
        The Akima-spline continuum evaluated at ``s.wa``, with
        non-positive values set to 0.
    """
    count = 0
    while True:
        if debug:
            print('iteration', count)
        update_knots(knots, indices, s.fl, masked)
        model = linear_co(s.wa, knots)
        model_a = Akima_co(s.wa, knots)
        chisq_chunk(model_a, s.fl, s.er, masked,
                    indices, knots, chithresh=1)
        flags = list(zip(*knots))[-1]
        if np.all(flags):
            if debug:
                print('All regions have satisfactory fit, stopping')
            break
        # Reject pixels lying well below the current model.
        resid = (model - s.fl) / s.er
        oldmasked = masked.copy()
        masked[(resid > nsig) & ~masked] = True
        unmask(masked, indices, s.wa, s.fl, s.er)
        if np.all(oldmasked == masked):
            if debug:
                print('No further points masked, stopping')
            break
        count += 1
        # `count` is now the number of completed iterations. Checking it
        # after the increment with `>=` caps the loop at exactly maxiter
        # passes (the previous `>` check allowed maxiter + 2).
        if count >= maxiter:
            warnings.warn('Exceeded maximum iterations. Continue at your own risk..')
            break
    co = Akima_co(s.wa, knots)
    co[co <= 0] = 0
    if ax is not None:
        ax.plot(s.wa, linear_co(s.wa, knots), color='0.7', lw=2)
        ax.plot(s.wa, co, 'k', lw=2, zorder=10)
        x, y = list(zip(*knots))[:2]
        ax.plot(x, y, 'o', mfc='none', mec='k', ms=10, mew=1, zorder=10)
    return co
def find_continuum(spec, edges=None, ax=None, debug=False, kind='QSO',
                   **kwargs):
    """ Estimate a continuum for a spectrum.

    Parameters
    ----------
    spec : XSpectrum1D object
        Wavelength, flux and one sigma error.
    edges : array of float, optional
        A list of wavelengths giving the edges of chunks where a spline
        knot will be fitted. If this is given, the `kind` keyword is
        ignored.
    ax : matplotlib Axes, optional
        If this is not None, use ax to make diagnostic plots.
    debug : bool
        Print diagnostic messages while fitting.
    kind : {'QSO'}
        Which kind of continuum to fit. This is used to generate a list
        of wavelength chunks where spline knots will be placed. Only
        'QSO' (case-insensitive) is currently implemented.

    Additional keywords for kind = 'QSO':

    redshift : float
        QSO emission redshift. If not given, ``spec.meta['redshift']``
        is used.
    forest_divmult : float
        Multiplier for the number of spline knots at wavelengths shorter
        than Lya. The default (2) is suitable for UVES/HIRES resolution
        spectra - experiment with smaller values for lower resolution
        spectra.
    divmult : float
        Multiplier for the number of knots at wavelengths longer than
        Lya. Default 2.

    Returns
    -------
    co, contpoints: array of shape (N,) and a list of (x,y) pairs.
        co is an estimate for the continuum.
        contpoints is a list of (x,y) pairs, giving the position of
        spline knots used to generate the continuum. Use
        linetools.analysis.interp.AkimaSpline to re-generate the
        continuum from these knots.

    Raises
    ------
    RuntimeError
        If `edges` is None, kind is 'QSO' and no redshift is available
        from the keywords or ``spec.meta``.
    NotImplementedError
        If `edges` is None and `kind` is not 'QSO'.
    """
    s = np.rec.fromarrays([spec.wavelength.value,
                           spec.flux.value,
                           spec.sig], names=str('wa,fl,er'))
    if edges is not None:
        edges = list(edges)
    elif kind.upper() == 'QSO':
        if 'redshift' in kwargs:
            z = kwargs['redshift']
        elif 'redshift' in spec.meta:
            z = spec.meta['redshift']
        else:
            # Implicit string concatenation, so the message contains
            # exactly one space between the two halves.
            raise RuntimeError(
                "I need the emission redshift for kind='qso'; please "
                "provide redshift using `redshift` keyword.")
        divmult = kwargs.get('divmult', 2)
        forest_divmult = kwargs.get('forest_divmult', 2)
        edges = make_chunks_qso(
            s.wa, z, debug=debug, divmult=divmult,
            forest_divmult=forest_divmult)
    else:
        # Use a separate name so the spectrum record array `s` is not
        # shadowed.
        msg = "Kind keyword {:s} unknown. ".format(kind)
        msg += "Currently only kind='QSO' is supported"
        raise NotImplementedError(msg)

    if ax is not None:
        ax.plot(s.wa, s.fl, '-', color='0.4', drawstyle='steps-mid')
        ax.plot(s.wa, s.er, 'g')

    knots, indices, masked = prepare_knots(s.wa, s.fl, s.er, edges,
                                           ax=ax, debug=debug)

    # Note this modifies knots and masked inplace
    co = estimate_continuum(s, knots, indices, masked, ax=ax, debug=debug)

    if ax is not None:
        ax.plot(s.wa[~masked], s.fl[~masked], '.y')
        ymax = np.percentile(s.fl[~np.isnan(s.fl)], 95)
        ax.set_ylim(-0.02*ymax, 1.1*ymax)

    return co, [k[:2] for k in knots]