In [3]:
import pytest

from funcs.notebook import *

from altaipony.flarelc import FlareLightCurve
from altaipony.altai import find_iterative_median
from altaipony.utils import sigma_clip

import copy
from scipy.interpolate import UnivariateSpline
from scipy import optimize
from funcs.detrend import search_gaps_for_window_length, fit_spline, estimate_detrended_noise

import astropy.units as u

  return f(*args, **kwds)
  return f(*args, **kwds)


In [15]:
def custom_detrending(lc, spline_coarseness=30, spline_order=3,
                      savgol1=6., savgol2=3., pad=3):
    """Custom de-trending for TESS and Kepler 
    short cadence light curves, including TESS Cycle 3 20s
    cadence.

    The pipeline is: coarse spline fit -> iterative sine removal ->
    two Savitzky-Golay passes (wide, then narrow window) ->
    iterative median -> removal of exponential fringes at chunk edges.
    
    Parameters:
    ------------
    lc : FlareLightCurve
        light curve that has at least time, flux and flux_err
    spline_coarseness : float
        time scale in hours for spline points. 
        See fit_spline for details.
    spline_order: int
        Spline order for the coarse spline fit.
        Default is cubic spline.
    savgol1 : float
        Window size for first Savitzky-Golay filter application.
        Unit is hours, defaults to 6 hours.
    savgol2 : float
        Window size for second Savitzky-Golay filter application.
        Unit is hours, defaults to 3 hours.
    pad : int
        Outliers in Savitzky-Golay filter are padded with this
        number of data points. Defaults to 3.
        
    Return:
    -------
    FlareLightCurve with detrended_flux attribute
    """

    def _odd_window(hours, time):
        # Convert a window given in hours into a number of cadences and
        # force it to be odd. The division by 24 assumes `time` is in days.
        cadence = np.mean(np.diff(time))
        return int((np.rint(hours / 24. / cadence) // 2) * 2 + 1)

    # fit a spline to the general trends
    lc_spline, model = fit_spline(lc, spline_order=spline_order,
                                  spline_coarseness=spline_coarseness)

    # replace for next step
    lc_spline.flux = lc_spline.detrended_flux

    # removes strong and fast variability between the default
    # minimum_frequency (0.2/d, i.e. 5 days) and maximum_frequency
    # (12/d, i.e. 2 hours) of iteratively_remove_sines.
    # simple sines are probably because rotational variability is 
    # either weak and transient or strong and persistent on the timescales
    lc_sines = iteratively_remove_sines(lc_spline)

    # first Savitzky-Golay pass with the wide (default 6 hour) window
    w = _odd_window(savgol1, lc_sines.time)
    lc_savgol1 = lc_sines.detrend("savgol", window_length=w, pad=pad)

    # second Savitzky-Golay pass with the narrow (default 3 hour) window,
    # applied to the result of the first pass
    w = _odd_window(savgol2, lc_savgol1.time)
    lc_savgol2 = lc_savgol1.detrend("savgol", window_length=w, pad=pad)

    # find median value (remove_exponential_fringes reads it_med)
    lc_savgol2 = find_iterative_median(lc_savgol2)

    # replace for next step
    lc_savgol2.flux = lc_savgol2.detrended_flux

    # remove exponential fringes that neither spline, 
    # nor sines, nor SavGol 
    # can remove.
    return remove_exponential_fringes(lc_savgol2)

In [16]:
def iteratively_remove_sines(flcd, niter=10, freq_unit=1/u.day, 
                             maximum_frequency=12., 
                             minimum_frequency=0.2,
                             max_sigma=3.5, longdecay=2):
    
    """Iteratively remove strong sinusoidal signal
    from light curve. Each iteration calculates a Lomb-Scargle 
    periodogram and LSQ-fits a cosine function using the dominant
    frequency as starting point. 

    The light curve is processed chunk by chunk (chunks come from
    find_gaps). The input object is modified: the cleaned flux of each
    processed chunk is written into flcd.detrended_flux, and flcd
    itself is returned. Chunks shorter than one full period of
    maximum_frequency are left untouched.
    
    
    Parameters:
    ------------
    flcd : FlareLightCurve
        light curve from which to remove sinusoidal signals.
        Its detrended_flux array must already exist, because
        slices of it are overwritten.
    niter : int
        Maximum number of iterations. 
    freq_unit : astropy.units
        unit in which maximum_frequency and minimum_frequency
        are given
    maximum_frequency: float
        highest frequency to calculate the Lomb-Scargle periodogram
    minimum_frequency: float
        lowest frequency to calculate the Lomb-Scargle periodogram
    max_sigma : float
        Passed to altaipony.utils.sigma_clip. 
        Above this value data points
        are flagged as outliers.
    longdecay : int
        altaipony.utils.sigma_clip expands the mask for series
        of outliers by sqrt(length of series). Longdecay doubles
        the mask expansion in the decay phase of what may be flares.
        
    Return:
    -------
    FlareLightCurve with detrended_flux attribute
            
    """
    
    # define cosine function
    def cosine(x, a, b, c, d):
        return a * np.cos(b * x + c) + d

    # make a copy of the original LC
    flct = copy.deepcopy(flcd)
    
    # iterate over chunks
    for le, ri in flct.find_gaps().gaps:
        
        # again make a copy of the chunk to manipulate safely
        flc = copy.deepcopy(flct[le:ri])
        
        # find median of LC
        flc = find_iterative_median(flc)
        
        # mask flares, honouring the caller-supplied clipping parameters
        mask = sigma_clip(flc.flux, max_sigma=max_sigma, longdecay=longdecay)

        # how many data points comprise the fastest period at maximum_frequency?
        # NOTE(review): this mixes maximum_frequency with the time axis
        # directly, so it assumes freq_unit is the inverse of the time unit.
        full_fastest_period = 1. / maximum_frequency / np.nanmin(np.diff(flc.remove_nans().time))
        
        # only remove sines if LC chunk is larger than one full period of the fastest frequency
        if flc.flux.shape[0] > full_fastest_period:

            n = 0 # start counter
            snr = 3 # go into while loop at least once
            
            # iterate while there is signal, but not more than niter times
            while (snr > 1) & (n < niter):
                
                # mask NaNs and outliers
                cond = np.invert(np.isnan(flc.time)) & np.invert(np.isnan(flc.flux)) & mask
                
                # calculate periodogram
                pg = flc[cond].to_periodogram(freq_unit=freq_unit,
                                              maximum_frequency=maximum_frequency,
                                              minimum_frequency=minimum_frequency)

                # fit sinusoidal, seeded with the dominant periodogram frequency
                p, p_cov = optimize.curve_fit(cosine, flc.time[cond], flc.flux[cond],
                                              p0=[np.nanstd(flc.flux),
                                                  2*np.pi*pg.frequency_at_max_power.value,
                                                  0, np.nanmean(flc.flux)], ftol=1e-6)
                
                # replace with de-trended flux but without subtracting the median
                flc.flux = flc.flux - cosine(flc.time, p[0], p[1], p[2], 0.)

                # update SNR (taken from the periodogram computed before
                # this iteration's subtraction)
                snr = pg.flatten().max_power
                
                # bump iterator
                n += 1
      
            # overwrite this chunk of the input's detrended flux with the
            # sine-subtracted flux
            flcd.detrended_flux[le:ri] = flc.flux
        
    return flcd

In [18]:
# Parameter tuples, in the order (a1, a2, period1, period2, quad, cube);
# they are forwarded unchanged to generate_lightcurve.
cases = [
    (.05, 0.005, 1.5, 24.4, 1.5, 0.1),
    (.1, 0.005, 1.5, 14.4, 1.5, 0.5),
    (.1, 0.05, 1.5, 8, 1.5, 0.5),
    (.01, .1, 1.5, 8, -.5, 0.25),
    (.3, .05, .5, 30, -.5, 0.25),
]


@pytest.mark.parametrize("a1,a2,period1,period2,quad,cube", cases)
def test_custom_detrending(a1, a2, period1, period2, quad, cube,):
    """Run custom_detrending on a synthetic light curve and check the
    recovered noise level and the recovered flares.

    The light curve comes from generate_lightcurve (defined outside this
    cell) with a fixed random seed. The detrended, noise-estimated light
    curve is returned so it can be inspected when the function is called
    directly.
    """
    # fixed uncertainty and seed for a reproducible light curve
    noise_level = 15.
    np.random.seed(40)
    synthetic = generate_lightcurve(noise_level, a1, a2,
                                    period1, period2, quad, cube)
    synthetic.plot()

    # de-trend, estimate the noise, and refresh the iterative median
    detrended = custom_detrending(synthetic)
    result = estimate_detrended_noise(detrended,
                                      mask_pos_outliers_sigma=2.5,
                                      std_window=100)
    result = find_iterative_median(result)

    recovered = result.find_flares(addtail=True).flares

    # check that the recovered uncertainty is within 2 of the input value
    median_err = np.nanmedian(result.detrended_flux_err)
    assert median_err == pytest.approx(noise_level, abs=2)

    # start/stop indices expected for the recovered flares
    compare = pd.DataFrame({'istart': {0: 5280, 1: 13160, 2: 23160},
                            'istop': {0: 5346, 1: 13163, 2: 23175}})
    assert (recovered[["istart","istop"]] == compare[["istart","istop"]]).all().all()

    # recovered equivalent durations, to within 15 percent
    expected_ed = np.array([802.25, 4.7907, 40.325])
    assert (recovered.ed_rec.values ==
            pytest.approx(expected_ed, rel=0.15))

    # recovered amplitudes, to within 15 percent
    expected_ampl = np.array([0.28757, 0.03004, 0.064365])
    assert (recovered.ampl_rec.values ==
            pytest.approx(expected_ampl, rel=0.15))
    return result

In [4]:
def remove_exponential_fringes(flcd, demask=10, max_sigma=3.5, longdecay=2):
    """Remove exponential fringes from light curve chunks.

    For every chunk returned by find_gaps, a constant plus up to two
    exponentials (one anchored at the chunk start, one at the chunk end)
    is least-squares fitted to the sigma-clipped flux and subtracted.
    The input object is modified: flcd.detrended_flux is replaced by a
    fresh array that receives the result, and flcd itself is returned.
    
    Parameters:
    -----------
    flcd : FlareLightCurve
        Mostly de-trended light curves 
        with possibly fringy fringes that need a haircut.
        Must carry an it_med array (see find_iterative_median),
        which is used as the initial guess for the baseline.
    demask : int
        fraction of light curve to keep in the fit even it
        deviates from the median, applied to the end and start
        of each light curve chunk (1/demask of the chunk length).
    max_sigma : float
        Passed to altaipony.utils.sigma_clip. 
        Above this value data points
        are flagged as outliers.
    longdecay : int
        altaipony.utils.sigma_clip expands the mask for series
        of outliers by sqrt(length of series). Longdecay doubles
        the mask expansion in the decay phase of what may be flares.

    Return:
    -------
    FlareLightCurve with detrended_flux attribute
        
    """
   
    def twoexps(x, a, b, c, d, e, f, g):
        return a * np.exp(b * (c - x)) + d * np.exp(e * (f - x)) + g
    
    # work on a copy so that the chunks we fit are not affected by the
    # writes into flcd below
    flct = copy.deepcopy(flcd)
    
    # initiate a fresh detrended flux array on the object that is written
    # to and returned; this also breaks any aliasing between
    # flcd.flux and flcd.detrended_flux
    flcd.detrended_flux = np.full_like(flcd.flux, np.nan)
    
    for le, ri in flct.find_gaps().gaps:
 
        f_ = copy.deepcopy(flct[le:ri])
    
        # mask outliers 
        mask = sigma_clip(f_.flux, max_sigma=max_sigma, longdecay=longdecay)
        ff = f_[mask]

        # get the median as a guess for the least square fit
        median = np.nanmedian(ff.it_med)
        
        # get noise level from the outlier-free chunk
        std = np.nanstd(ff.flux)
        
        # demask the fringes because they are otherwise likely to be 
        # masked by sigma clipping
        mask[:len(f_.flux) // demask] = True
        mask[-len(f_.flux) // demask:] = True
        
        ff = f_[mask]
        
        # get the amplitude of the fringes
        sta, fin = ff.flux[0] - median, ff.flux[-1] - median
        
        # don't fit the fringes if they not even there
        # i.e. smaller than global noise of outlier-free LC
        noleft = np.abs(sta) < std
        noright = np.abs(fin) < std
        
        # adjust the LSQ function to number of fringes
        # also fix time offset 
        if noleft and noright:
            # no fringes at all: fit the constant baseline only
            def texp(x, g):
                return twoexps(x, 0., 0., 0., 0., 0., 0., g)
            p0 = [median]

        elif noleft:
            # only end of LC chunk fringes
            def texp(x, d, e, g):
                return twoexps(x, 0., 0., 0., d, e, ff.time[-1], g)
            p0 = [fin, -10., median]

        elif noright:
            # only start of LC chunk fringes
            def texp(x, a, b, g):
                return twoexps(x, a, b, ff.time[0], 0., 0., 0., g)
            p0 = [sta, 10., median]

        else:
            # both sides of LC chunk fringe
            def texp(x, a, b, d, e, g):
                return twoexps(x, a, b, ff.time[0], d, e, ff.time[-1], g)
            p0 = [sta, 10., fin, -10., median]

        # do the LSQ fit
        p, p_cov = optimize.curve_fit(texp, ff.time, ff.flux,
                                      p0=p0, sigma=np.full_like(ff.flux, std),
                                      absolute_sigma=True, ftol=1e-6)
        # Remove the fit from the LC
        # median + full flux - model
        nflux = p[-1] + ff.flux - texp(ff.time, *p)

        # replace NaNs in detrended flux with solution
        flcd.detrended_flux[le:ri][mask] = nflux
        
        # re-introduce outliers and flare candidates
        flcd.detrended_flux[le:ri][~mask] = flcd.flux[le:ri][~mask]
        
    return flcd

In [None]:
# test remove_exponential_fringes on a synthetic, noise-free light curve
# with an exponential fringe at each end on top of a flat baseline of 20


def twoexps(x, a, b, c, d, e, f, g):
    """Two exponentials plus a constant; same form as the helper that is
    nested inside remove_exponential_fringes (and therefore not visible
    at cell scope)."""
    return a * np.exp(b * (c - x)) + d * np.exp(e * (f - x)) + g


x = np.linspace(10, 40, 1200)
y1 = 1 * np.exp(-1 * (40 - x) * 2) + 10   # fringe rising towards the end
y2 = 1 * np.exp((10 - x) * 2) + 10        # fringe decaying from the start
flc = FlareLightCurve(time=x, flux=y1 + y2, flux_err=np.full_like(y1, .5))
flc.plot()

# overlay a rough initial guess built from the edge amplitudes
median = 20
sta, fin = flc.flux[0] - median, flc.flux[-1] - median
print(sta, fin)
plt.plot(flc.time, twoexps(flc.time, sta, np.sign(sta), flc.time[0],
                           fin, -np.sign(fin), flc.time[-1], median))

# remove_exponential_fringes reads it_med, so compute it first
# (custom_detrending does the same before calling it)
flc = find_iterative_median(flc)
flcd = remove_exponential_fringes(flc)

In [None]:
# TODO: add a test for iteratively_remove_sines
# TODO: add a test for fit_spline

In [26]:
# Scratch arithmetic; evaluates to ~13.33 (see the cell output below).
# NOTE(review): the meaning and units of 2200, 37.5 and 4.4 are not stated
# anywhere in the visible notebook — add context or delete this cell.
2200 / 37.5 / 4.4

13.333333333333332