Skip to content

Commit

Permalink
fixed for new lightkurve units
Browse files Browse the repository at this point in the history
  • Loading branch information
christinahedges committed Nov 30, 2020
1 parent ab7ac16 commit a9b17d4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 27 deletions.
71 changes: 54 additions & 17 deletions docs/TESS-SIP.ipynb

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions src/tess_sip/tess_sip.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ def SIP(tpfs, sigma=5, min_period=10, max_period=100, nperiods=300):
"""

# Get the un-background subtracted data
tpfs_uncorr = [(tpf + np.nan_to_num(tpf.flux_bkg))[np.isfinite(tpf.flux_bkg.sum(axis=(1, 2)))] for tpf in tpfs]
tpfs_uncorr = [(tpf + np.nan_to_num(tpf.flux_bkg.value))[np.isfinite(tpf.flux_bkg.value.sum(axis=(1, 2)))] for tpf in tpfs]
apers = [tpf.pipeline_mask for tpf in tpfs_uncorr]
lc = lk.LightCurveCollection([tpf.to_lightcurve(aperture_mask=aper) for tpf, aper in zip(tpfs_uncorr, apers)]).stitch()
lc.flux_err[~np.isfinite(lc.flux_err)] = np.nanmedian(lc.flux_err)
lc.flux_err.value[~np.isfinite(lc.flux_err.value)] = np.nanmedian(lc.flux_err.value)

with warnings.catch_warnings():
warnings.simplefilter('ignore')
bkgs = [lk.DesignMatrix(tpf.flux[:, ~aper], name='bkg').pca(3).append_constant().to_sparse() for tpf, aper in zip(tpfs_uncorr, apers)]
bkgs = [lk.DesignMatrix(tpf.flux.value[:, ~aper], name='bkg').pca(3).append_constant().to_sparse() for tpf, aper in zip(tpfs_uncorr, apers)]
for bkg in bkgs:
bkg.prior_mu[-1] = 1
bkg.prior_sigma[-1] = 0.1
Expand All @@ -85,17 +85,17 @@ def SIP(tpfs, sigma=5, min_period=10, max_period=100, nperiods=300):
bkg.prior_sigma[:-1] = 0.1

# Split at the datadownlink
bkgs = [bkg.split(list((np.where((np.diff(tpf.time) > 0.3))[0] + 1))) for bkg, tpf in zip(bkgs, tpfs_uncorr)]
bkgs = [bkg.split(list((np.where(np.diff(tpf.time.jd) > 0.3)[0] + 1))) for bkg, tpf in zip(bkgs, tpfs_uncorr)]
systematics_dm = vstack(bkgs)

sigma_f_inv = sparse.csr_matrix(1/lc.flux_err[:, None]**2)
sigma_f_inv = sparse.csr_matrix(1/lc.flux_err.value[:, None]**2)
def fit_model(mask=None, return_model=False):
if mask is None:
mask = np.ones(len(lc.flux), bool)
mask = np.ones(len(lc.flux.value), bool)
sigma_w_inv = dm.X[mask].T.dot(dm.X[mask].multiply(sigma_f_inv[mask])).toarray()
sigma_w_inv += np.diag(1. / dm.prior_sigma**2)

B = dm.X[mask].T.dot((lc.flux[mask]/lc.flux_err[mask]**2))
B = dm.X[mask].T.dot((lc.flux.value[mask]/lc.flux_err.value[mask]**2))
B += dm.prior_mu/dm.prior_sigma**2
w = np.linalg.solve(sigma_w_inv, B)
werr = ((np.linalg.inv(sigma_w_inv))**0.5).diagonal()
Expand All @@ -105,7 +105,7 @@ def fit_model(mask=None, return_model=False):

# Make a dummy design matrix
period = 27
ls_dm = lk.DesignMatrix(lombscargle.implementations.mle.design_matrix(lc.time, frequency=1/period, bias=False, nterms=1), name='LS').to_sparse()
ls_dm = lk.DesignMatrix(lombscargle.implementations.mle.design_matrix(lc.time.jd, frequency=1/period, bias=False, nterms=1), name='LS').to_sparse()
dm = lk.SparseDesignMatrixCollection([systematics_dm, ls_dm]).to_designmatrix(name='design_matrix')

# Do a first pass at 50 days, just to find ridiculous outliers
Expand All @@ -118,12 +118,12 @@ def fit_model(mask=None, return_model=False):
ws_err = np.zeros((len(periods), dm.X.shape[1]))

for idx, period in enumerate(tqdm(periods)):
dm.X[:, -ls_dm.shape[1]:] = lombscargle.implementations.mle.design_matrix(lc.time, frequency=1/period, bias=False, nterms=1)
dm.X[:, -ls_dm.shape[1]:] = lombscargle.implementations.mle.design_matrix(lc.time.jd, frequency=1/period, bias=False, nterms=1)
ws[idx], ws_err[idx] = fit_model(mask=mask)
power = (ws[:, -2]**2 + ws[:, -1]**2)**0.5

am = np.argmax(power)
dm.X[:, -ls_dm.shape[1]:] = lombscargle.implementations.mle.design_matrix(lc.time, frequency=1/periods[am], bias=False, nterms=1)
dm.X[:, -ls_dm.shape[1]:] = lombscargle.implementations.mle.design_matrix(lc.time.jd, frequency=1/periods[am], bias=False, nterms=1)
mod = dm.X[:, :-2].dot(ws[am][:-2])


Expand Down

0 comments on commit a9b17d4

Please sign in to comment.